summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.codespellrc3
-rw-r--r--.flake87
-rw-r--r--.github/FUNDING.yml4
-rw-r--r--.github/workflows/docs.yml22
-rw-r--r--.github/workflows/lint.yml48
-rw-r--r--.github/workflows/packages-pool.yml66
-rw-r--r--.github/workflows/packages.yml256
-rw-r--r--.github/workflows/tests.yml314
-rw-r--r--.gitignore23
-rw-r--r--.yamllint.yaml8
-rw-r--r--BACKERS.yaml140
-rw-r--r--LICENSE.txt165
-rw-r--r--README.rst75
-rw-r--r--docs/Makefile30
-rw-r--r--docs/_static/psycopg.css11
-rw-r--r--docs/_static/psycopg.svg1
-rw-r--r--docs/_templates/.keep0
-rw-r--r--docs/advanced/adapt.rst269
-rw-r--r--docs/advanced/async.rst360
-rw-r--r--docs/advanced/cursors.rst192
-rw-r--r--docs/advanced/index.rst21
-rw-r--r--docs/advanced/pipeline.rst324
-rw-r--r--docs/advanced/pool.rst332
-rw-r--r--docs/advanced/prepare.rst57
-rw-r--r--docs/advanced/rows.rst116
-rw-r--r--docs/advanced/typing.rst180
-rw-r--r--docs/api/abc.rst75
-rw-r--r--docs/api/adapt.rst91
-rw-r--r--docs/api/connections.rst489
-rw-r--r--docs/api/conninfo.rst24
-rw-r--r--docs/api/copy.rst117
-rw-r--r--docs/api/crdb.rst120
-rw-r--r--docs/api/cursors.rst517
-rw-r--r--docs/api/dns.rst145
-rw-r--r--docs/api/errors.rst540
-rw-r--r--docs/api/index.rst29
-rw-r--r--docs/api/module.rst59
-rw-r--r--docs/api/objects.rst256
-rw-r--r--docs/api/pool.rst331
-rw-r--r--docs/api/pq.rst218
-rw-r--r--docs/api/rows.rst74
-rw-r--r--docs/api/sql.rst151
-rw-r--r--docs/api/types.rst168
-rw-r--r--docs/basic/adapt.rst522
-rw-r--r--docs/basic/copy.rst212
-rw-r--r--docs/basic/from_pg2.rst359
-rw-r--r--docs/basic/index.rst26
-rw-r--r--docs/basic/install.rst172
-rw-r--r--docs/basic/params.rst242
-rw-r--r--docs/basic/pgtypes.rst389
-rw-r--r--docs/basic/transactions.rst388
-rw-r--r--docs/basic/usage.rst232
-rw-r--r--docs/conf.py110
-rw-r--r--docs/index.rst52
-rw-r--r--docs/lib/libpq_docs.py182
-rw-r--r--docs/lib/pg3_docs.py197
-rw-r--r--docs/lib/sql_role.py23
-rw-r--r--docs/lib/ticket_role.py50
-rw-r--r--docs/news.rst285
-rw-r--r--docs/news_pool.rst81
-rw-r--r--docs/pictures/adapt.drawio107
-rw-r--r--docs/pictures/adapt.svg3
-rw-r--r--docs/release.rst39
-rw-r--r--psycopg/.flake86
-rw-r--r--psycopg/LICENSE.txt165
-rw-r--r--psycopg/README.rst31
-rw-r--r--psycopg/psycopg/__init__.py110
-rw-r--r--psycopg/psycopg/_adapters_map.py289
-rw-r--r--psycopg/psycopg/_cmodule.py24
-rw-r--r--psycopg/psycopg/_column.py143
-rw-r--r--psycopg/psycopg/_compat.py72
-rw-r--r--psycopg/psycopg/_dns.py223
-rw-r--r--psycopg/psycopg/_encodings.py170
-rw-r--r--psycopg/psycopg/_enums.py79
-rw-r--r--psycopg/psycopg/_pipeline.py288
-rw-r--r--psycopg/psycopg/_preparing.py198
-rw-r--r--psycopg/psycopg/_queries.py375
-rw-r--r--psycopg/psycopg/_struct.py57
-rw-r--r--psycopg/psycopg/_tpc.py116
-rw-r--r--psycopg/psycopg/_transform.py350
-rw-r--r--psycopg/psycopg/_typeinfo.py461
-rw-r--r--psycopg/psycopg/_tz.py44
-rw-r--r--psycopg/psycopg/_wrappers.py137
-rw-r--r--psycopg/psycopg/abc.py266
-rw-r--r--psycopg/psycopg/adapt.py162
-rw-r--r--psycopg/psycopg/client_cursor.py95
-rw-r--r--psycopg/psycopg/connection.py1031
-rw-r--r--psycopg/psycopg/connection_async.py436
-rw-r--r--psycopg/psycopg/conninfo.py378
-rw-r--r--psycopg/psycopg/copy.py904
-rw-r--r--psycopg/psycopg/crdb/__init__.py19
-rw-r--r--psycopg/psycopg/crdb/_types.py163
-rw-r--r--psycopg/psycopg/crdb/connection.py186
-rw-r--r--psycopg/psycopg/cursor.py921
-rw-r--r--psycopg/psycopg/cursor_async.py250
-rw-r--r--psycopg/psycopg/dbapi20.py112
-rw-r--r--psycopg/psycopg/errors.py1535
-rw-r--r--psycopg/psycopg/generators.py320
-rw-r--r--psycopg/psycopg/postgres.py125
-rw-r--r--psycopg/psycopg/pq/__init__.py133
-rw-r--r--psycopg/psycopg/pq/_debug.py106
-rw-r--r--psycopg/psycopg/pq/_enums.py249
-rw-r--r--psycopg/psycopg/pq/_pq_ctypes.py804
-rw-r--r--psycopg/psycopg/pq/_pq_ctypes.pyi216
-rw-r--r--psycopg/psycopg/pq/abc.py385
-rw-r--r--psycopg/psycopg/pq/misc.py146
-rw-r--r--psycopg/psycopg/pq/pq_ctypes.py1086
-rw-r--r--psycopg/psycopg/py.typed0
-rw-r--r--psycopg/psycopg/rows.py256
-rw-r--r--psycopg/psycopg/server_cursor.py479
-rw-r--r--psycopg/psycopg/sql.py467
-rw-r--r--psycopg/psycopg/transaction.py290
-rw-r--r--psycopg/psycopg/types/__init__.py11
-rw-r--r--psycopg/psycopg/types/array.py464
-rw-r--r--psycopg/psycopg/types/bool.py51
-rw-r--r--psycopg/psycopg/types/composite.py290
-rw-r--r--psycopg/psycopg/types/datetime.py754
-rw-r--r--psycopg/psycopg/types/enum.py177
-rw-r--r--psycopg/psycopg/types/hstore.py131
-rw-r--r--psycopg/psycopg/types/json.py232
-rw-r--r--psycopg/psycopg/types/multirange.py514
-rw-r--r--psycopg/psycopg/types/net.py206
-rw-r--r--psycopg/psycopg/types/none.py25
-rw-r--r--psycopg/psycopg/types/numeric.py515
-rw-r--r--psycopg/psycopg/types/range.py700
-rw-r--r--psycopg/psycopg/types/shapely.py75
-rw-r--r--psycopg/psycopg/types/string.py239
-rw-r--r--psycopg/psycopg/types/uuid.py65
-rw-r--r--psycopg/psycopg/version.py14
-rw-r--r--psycopg/psycopg/waiting.py331
-rw-r--r--psycopg/pyproject.toml3
-rw-r--r--psycopg/setup.cfg47
-rw-r--r--psycopg/setup.py66
-rw-r--r--psycopg_c/.flake83
-rw-r--r--psycopg_c/LICENSE.txt165
-rw-r--r--psycopg_c/README-binary.rst29
-rw-r--r--psycopg_c/README.rst33
-rw-r--r--psycopg_c/psycopg_c/.gitignore4
-rw-r--r--psycopg_c/psycopg_c/__init__.py14
-rw-r--r--psycopg_c/psycopg_c/_psycopg.pyi84
-rw-r--r--psycopg_c/psycopg_c/_psycopg.pyx48
-rw-r--r--psycopg_c/psycopg_c/_psycopg/__init__.pxd9
-rw-r--r--psycopg_c/psycopg_c/_psycopg/adapt.pyx171
-rw-r--r--psycopg_c/psycopg_c/_psycopg/copy.pyx340
-rw-r--r--psycopg_c/psycopg_c/_psycopg/endian.pxd155
-rw-r--r--psycopg_c/psycopg_c/_psycopg/generators.pyx276
-rw-r--r--psycopg_c/psycopg_c/_psycopg/oids.pxd92
-rw-r--r--psycopg_c/psycopg_c/_psycopg/transform.pyx640
-rw-r--r--psycopg_c/psycopg_c/_psycopg/waiting.pyx197
-rw-r--r--psycopg_c/psycopg_c/pq.pxd78
-rw-r--r--psycopg_c/psycopg_c/pq.pyx38
-rw-r--r--psycopg_c/psycopg_c/pq/__init__.pxd9
-rw-r--r--psycopg_c/psycopg_c/pq/conninfo.pyx61
-rw-r--r--psycopg_c/psycopg_c/pq/escaping.pyx132
-rw-r--r--psycopg_c/psycopg_c/pq/libpq.pxd321
-rw-r--r--psycopg_c/psycopg_c/pq/pgcancel.pyx32
-rw-r--r--psycopg_c/psycopg_c/pq/pgconn.pyx733
-rw-r--r--psycopg_c/psycopg_c/pq/pgresult.pyx157
-rw-r--r--psycopg_c/psycopg_c/pq/pqbuffer.pyx111
-rw-r--r--psycopg_c/psycopg_c/py.typed0
-rw-r--r--psycopg_c/psycopg_c/types/array.pyx276
-rw-r--r--psycopg_c/psycopg_c/types/bool.pyx78
-rw-r--r--psycopg_c/psycopg_c/types/datetime.pyx1136
-rw-r--r--psycopg_c/psycopg_c/types/numeric.pyx715
-rw-r--r--psycopg_c/psycopg_c/types/numutils.c243
-rw-r--r--psycopg_c/psycopg_c/types/string.pyx315
-rw-r--r--psycopg_c/psycopg_c/version.py11
-rw-r--r--psycopg_c/pyproject.toml3
-rw-r--r--psycopg_c/setup.cfg57
-rw-r--r--psycopg_c/setup.py110
-rw-r--r--psycopg_pool/.flake83
-rw-r--r--psycopg_pool/LICENSE.txt165
-rw-r--r--psycopg_pool/README.rst24
-rw-r--r--psycopg_pool/psycopg_pool/__init__.py22
-rw-r--r--psycopg_pool/psycopg_pool/_compat.py51
-rw-r--r--psycopg_pool/psycopg_pool/base.py230
-rw-r--r--psycopg_pool/psycopg_pool/errors.py25
-rw-r--r--psycopg_pool/psycopg_pool/null_pool.py159
-rw-r--r--psycopg_pool/psycopg_pool/null_pool_async.py122
-rw-r--r--psycopg_pool/psycopg_pool/pool.py839
-rw-r--r--psycopg_pool/psycopg_pool/pool_async.py784
-rw-r--r--psycopg_pool/psycopg_pool/py.typed0
-rw-r--r--psycopg_pool/psycopg_pool/sched.py177
-rw-r--r--psycopg_pool/psycopg_pool/version.py13
-rw-r--r--psycopg_pool/setup.cfg45
-rw-r--r--psycopg_pool/setup.py26
-rw-r--r--pyproject.toml55
-rw-r--r--tests/README.rst94
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/adapters_example.py45
-rw-r--r--tests/conftest.py92
-rw-r--r--tests/constraints.txt32
-rw-r--r--tests/crdb/__init__.py0
-rw-r--r--tests/crdb/test_adapt.py78
-rw-r--r--tests/crdb/test_connection.py86
-rw-r--r--tests/crdb/test_connection_async.py85
-rw-r--r--tests/crdb/test_conninfo.py21
-rw-r--r--tests/crdb/test_copy.py233
-rw-r--r--tests/crdb/test_copy_async.py235
-rw-r--r--tests/crdb/test_cursor.py65
-rw-r--r--tests/crdb/test_cursor_async.py61
-rw-r--r--tests/crdb/test_no_crdb.py34
-rw-r--r--tests/crdb/test_typing.py49
-rw-r--r--tests/dbapi20.py870
-rw-r--r--tests/dbapi20_tpc.py151
-rw-r--r--tests/fix_crdb.py131
-rw-r--r--tests/fix_db.py358
-rw-r--r--tests/fix_faker.py868
-rw-r--r--tests/fix_mypy.py54
-rw-r--r--tests/fix_pq.py141
-rw-r--r--tests/fix_proxy.py127
-rw-r--r--tests/fix_psycopg.py98
-rw-r--r--tests/pool/__init__.py0
-rw-r--r--tests/pool/fix_pool.py12
-rw-r--r--tests/pool/test_null_pool.py896
-rw-r--r--tests/pool/test_null_pool_async.py844
-rw-r--r--tests/pool/test_pool.py1265
-rw-r--r--tests/pool/test_pool_async.py1198
-rw-r--r--tests/pool/test_pool_async_noasyncio.py78
-rw-r--r--tests/pool/test_sched.py154
-rw-r--r--tests/pool/test_sched_async.py159
-rw-r--r--tests/pq/__init__.py0
-rw-r--r--tests/pq/test_async.py210
-rw-r--r--tests/pq/test_conninfo.py48
-rw-r--r--tests/pq/test_copy.py174
-rw-r--r--tests/pq/test_escaping.py188
-rw-r--r--tests/pq/test_exec.py146
-rw-r--r--tests/pq/test_misc.py83
-rw-r--r--tests/pq/test_pgconn.py585
-rw-r--r--tests/pq/test_pgresult.py207
-rw-r--r--tests/pq/test_pipeline.py161
-rw-r--r--tests/pq/test_pq.py57
-rw-r--r--tests/scripts/bench-411.py300
-rw-r--r--tests/scripts/dectest.py51
-rw-r--r--tests/scripts/pipeline-demo.py340
-rw-r--r--tests/scripts/spiketest.py156
-rw-r--r--tests/test_adapt.py530
-rw-r--r--tests/test_client_cursor.py855
-rw-r--r--tests/test_client_cursor_async.py727
-rw-r--r--tests/test_concurrency.py327
-rw-r--r--tests/test_concurrency_async.py242
-rw-r--r--tests/test_connection.py790
-rw-r--r--tests/test_connection_async.py751
-rw-r--r--tests/test_conninfo.py450
-rw-r--r--tests/test_copy.py889
-rw-r--r--tests/test_copy_async.py892
-rw-r--r--tests/test_cursor.py942
-rw-r--r--tests/test_cursor_async.py802
-rw-r--r--tests/test_dns.py27
-rw-r--r--tests/test_dns_srv.py149
-rw-r--r--tests/test_encodings.py57
-rw-r--r--tests/test_errors.py309
-rw-r--r--tests/test_generators.py156
-rw-r--r--tests/test_module.py57
-rw-r--r--tests/test_pipeline.py577
-rw-r--r--tests/test_pipeline_async.py586
-rw-r--r--tests/test_prepared.py277
-rw-r--r--tests/test_prepared_async.py207
-rw-r--r--tests/test_psycopg_dbapi20.py164
-rw-r--r--tests/test_query.py162
-rw-r--r--tests/test_rows.py167
-rw-r--r--tests/test_server_cursor.py525
-rw-r--r--tests/test_server_cursor_async.py543
-rw-r--r--tests/test_sql.py604
-rw-r--r--tests/test_tpc.py325
-rw-r--r--tests/test_tpc_async.py310
-rw-r--r--tests/test_transaction.py796
-rw-r--r--tests/test_transaction_async.py743
-rw-r--r--tests/test_typeinfo.py145
-rw-r--r--tests/test_typing.py449
-rw-r--r--tests/test_waiting.py159
-rw-r--r--tests/test_windows.py23
-rw-r--r--tests/types/__init__.py0
-rw-r--r--tests/types/test_array.py338
-rw-r--r--tests/types/test_bool.py47
-rw-r--r--tests/types/test_composite.py396
-rw-r--r--tests/types/test_datetime.py813
-rw-r--r--tests/types/test_enum.py363
-rw-r--r--tests/types/test_hstore.py107
-rw-r--r--tests/types/test_json.py182
-rw-r--r--tests/types/test_multirange.py434
-rw-r--r--tests/types/test_net.py135
-rw-r--r--tests/types/test_none.py12
-rw-r--r--tests/types/test_numeric.py625
-rw-r--r--tests/types/test_range.py677
-rw-r--r--tests/types/test_shapely.py152
-rw-r--r--tests/types/test_string.py307
-rw-r--r--tests/types/test_uuid.py56
-rw-r--r--tests/typing_example.py176
-rw-r--r--tests/utils.py179
-rwxr-xr-xtools/build/build_libpq.sh173
-rwxr-xr-xtools/build/build_macos_arm64.sh93
-rwxr-xr-xtools/build/ci_test.sh29
-rwxr-xr-xtools/build/copy_to_binary.py39
-rwxr-xr-xtools/build/print_so_versions.sh37
-rwxr-xr-xtools/build/run_build_macos_arm64.sh40
-rwxr-xr-xtools/build/strip_wheel.sh48
-rwxr-xr-xtools/build/wheel_linux_before_all.sh48
-rwxr-xr-xtools/build/wheel_macos_before_all.sh28
-rw-r--r--tools/build/wheel_win32_before_build.bat3
-rwxr-xr-xtools/bump_version.py310
-rwxr-xr-xtools/update_backer.py134
-rwxr-xr-xtools/update_errors.py217
-rwxr-xr-xtools/update_oids.py217
304 files changed, 72932 insertions, 0 deletions
diff --git a/.codespellrc b/.codespellrc
new file mode 100644
index 0000000..33ec5a6
--- /dev/null
+++ b/.codespellrc
@@ -0,0 +1,3 @@
+[codespell]
+ignore-words-list = alot,ans,ba,fo,te
+skip = docs/_build,.tox,.mypy_cache,.venv,pq.c
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..ec4053f
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,7 @@
+[flake8]
+max-line-length = 88
+ignore = W503, E203
+extend-exclude = .venv build
+per-file-ignores =
+ # Autogenerated section
+ psycopg/psycopg/errors.py: E125, E128, E302
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 0000000..b648a1e
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,4 @@
+github:
+ - dvarrazzo
+custom:
+ - "https://www.paypal.me/dvarrazzo"
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 0000000..1dd1e94
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,22 @@
+name: Build documentation
+
+on:
+ push:
+ branches:
+ # This should match the DOC3_BRANCH value in the psycopg-website Makefile
+ - master
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref_name }}
+ cancel-in-progress: true
+
+jobs:
+ docs:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Trigger docs build
+ uses: peter-evans/repository-dispatch@v1
+ with:
+ repository: psycopg/psycopg-website
+ event-type: psycopg3-commit
+ token: ${{ secrets.ACCESS_TOKEN }}
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 0000000..4527551
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,48 @@
+name: Lint
+
+on:
+ push:
+ # This should disable running the workflow on tags, according to the
+ # on.<push|pull_request>.<branches|tags> GitHub Actions docs.
+ branches:
+ - "*"
+ pull_request:
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref_name }}
+ cancel-in-progress: true
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+
+ - name: install packages to tests
+ run: pip install ./psycopg[dev,test] codespell
+
+ - name: Run black
+ run: black --check --diff .
+
+ - name: Run flake8
+ run: flake8
+
+ - name: Run mypy
+ run: mypy
+
+ - name: Check spelling
+ run: codespell
+
+ - name: Install requirements to generate docs
+ run: sudo apt-get install -y libgeos-dev
+
+ - name: Install Python packages to generate docs
+ run: pip install ./psycopg[docs] ./psycopg_pool
+
+ - name: Check documentation
+ run: sphinx-build -W -T -b html docs docs/_build/html
diff --git a/.github/workflows/packages-pool.yml b/.github/workflows/packages-pool.yml
new file mode 100644
index 0000000..e9624e7
--- /dev/null
+++ b/.github/workflows/packages-pool.yml
@@ -0,0 +1,66 @@
+name: Build pool packages
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '28 6 * * sun'
+
+jobs:
+
+ sdist:
+ runs-on: ubuntu-latest
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - {package: psycopg_pool, format: sdist, impl: python}
+ - {package: psycopg_pool, format: wheel, impl: python}
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: 3.9
+
+ - name: Create the sdist packages
+ run: |-
+ python ${{ matrix.package }}/setup.py sdist -d `pwd`/dist/
+ if: ${{ matrix.format == 'sdist' }}
+
+ - name: Create the wheel packages
+ run: |-
+ pip install wheel
+ python ${{ matrix.package }}/setup.py bdist_wheel -d `pwd`/dist/
+ if: ${{ matrix.format == 'wheel' }}
+
+ - name: Install the Python pool package and test requirements
+ run: |-
+ pip install dist/*
+ pip install ./psycopg[test]
+
+ - name: Test the sdist package
+ run: pytest -m 'not slow and not flakey' --color yes
+ env:
+ PSYCOPG_IMPL: ${{ matrix.impl }}
+ PSYCOPG_TEST_DSN: "host=127.0.0.1 user=postgres"
+ PGPASSWORD: password
+
+ - uses: actions/upload-artifact@v3
+ with:
+ path: ./dist/*
+
+ services:
+ postgresql:
+ image: postgres:14
+ env:
+ POSTGRES_PASSWORD: password
+ ports:
+ - 5432:5432
+ # Set health checks to wait until postgres has started
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
diff --git a/.github/workflows/packages.yml b/.github/workflows/packages.yml
new file mode 100644
index 0000000..18a2817
--- /dev/null
+++ b/.github/workflows/packages.yml
@@ -0,0 +1,256 @@
+name: Build packages
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '28 7 * * sun'
+
+jobs:
+
+ sdist: # {{{
+ runs-on: ubuntu-latest
+ if: true
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - {package: psycopg, format: sdist, impl: python}
+ - {package: psycopg, format: wheel, impl: python}
+ - {package: psycopg_c, format: sdist, impl: c}
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: 3.9
+
+ - name: Create the sdist packages
+ run: |-
+ python ${{ matrix.package }}/setup.py sdist -d `pwd`/dist/
+ if: ${{ matrix.format == 'sdist' }}
+
+ - name: Create the wheel packages
+ run: |-
+ pip install wheel
+ python ${{ matrix.package }}/setup.py bdist_wheel -d `pwd`/dist/
+ if: ${{ matrix.format == 'wheel' }}
+
+ - name: Install the Python package and test requirements
+ run: |-
+ pip install `ls dist/*`[test]
+ pip install ./psycopg_pool
+ if: ${{ matrix.package == 'psycopg' }}
+
+ - name: Install the C package and test requirements
+ run: |-
+ pip install dist/*
+ pip install ./psycopg[test]
+ pip install ./psycopg_pool
+ if: ${{ matrix.package == 'psycopg_c' }}
+
+ - name: Test the sdist package
+ run: pytest -m 'not slow and not flakey' --color yes
+ env:
+ PSYCOPG_IMPL: ${{ matrix.impl }}
+ PSYCOPG_TEST_DSN: "host=127.0.0.1 user=postgres"
+ PGPASSWORD: password
+
+ - uses: actions/upload-artifact@v3
+ with:
+ path: ./dist/*
+
+ services:
+ postgresql:
+ image: postgres:14
+ env:
+ POSTGRES_PASSWORD: password
+ ports:
+ - 5432:5432
+ # Set health checks to wait until postgres has started
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+
+
+ # }}}
+
+ linux: # {{{
+ runs-on: ubuntu-latest
+ if: true
+
+ env:
+ LIBPQ_VERSION: "15.1"
+ OPENSSL_VERSION: "1.1.1s"
+
+ strategy:
+ fail-fast: false
+ matrix:
+ arch: [x86_64, i686, ppc64le, aarch64]
+ pyver: [cp37, cp38, cp39, cp310, cp311]
+ platform: [manylinux, musllinux]
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Set up QEMU for multi-arch build
+ # Check https://github.com/docker/setup-qemu-action for newer versions.
+ uses: docker/setup-qemu-action@v2
+ with:
+ # Note: 6.2.0 is buggy: make sure to avoid it.
+ # See https://github.com/pypa/cibuildwheel/issues/1250
+ image: tonistiigi/binfmt:qemu-v7.0.0
+
+ - name: Cache libpq build
+ uses: actions/cache@v3
+ with:
+ path: /tmp/libpq.build
+ key: libpq-${{ env.LIBPQ_VERSION }}-${{ matrix.platform }}-${{ matrix.arch }}-2
+
+ - name: Create the binary package source tree
+ run: python3 ./tools/build/copy_to_binary.py
+
+ - name: Build wheels
+ uses: pypa/cibuildwheel@v2.9.0
+ with:
+ package-dir: psycopg_binary
+ env:
+ CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014
+ CIBW_MANYLINUX_I686_IMAGE: manylinux2014
+ CIBW_MANYLINUX_AARCH64_IMAGE: manylinux2014
+ CIBW_MANYLINUX_PPC64LE_IMAGE: manylinux2014
+ CIBW_BUILD: ${{matrix.pyver}}-${{matrix.platform}}_${{matrix.arch}}
+ CIBW_ARCHS_LINUX: auto aarch64 ppc64le
+ CIBW_BEFORE_ALL_LINUX: ./tools/build/wheel_linux_before_all.sh
+ CIBW_REPAIR_WHEEL_COMMAND: >-
+ ./tools/build/strip_wheel.sh {wheel}
+ && auditwheel repair -w {dest_dir} {wheel}
+ CIBW_TEST_REQUIRES: ./psycopg[test] ./psycopg_pool
+ CIBW_TEST_COMMAND: >-
+ pytest {project}/tests -m 'not slow and not flakey' --color yes
+ CIBW_ENVIRONMENT_PASS_LINUX: LIBPQ_VERSION OPENSSL_VERSION
+ CIBW_ENVIRONMENT: >-
+ PSYCOPG_IMPL=binary
+ PSYCOPG_TEST_DSN='host=172.17.0.1 user=postgres'
+ PGPASSWORD=password
+ LIBPQ_BUILD_PREFIX=/host/tmp/libpq.build
+ PATH="$LIBPQ_BUILD_PREFIX/bin:$PATH"
+ LD_LIBRARY_PATH="$LIBPQ_BUILD_PREFIX/lib"
+ PSYCOPG_TEST_WANT_LIBPQ_BUILD=${{ env.LIBPQ_VERSION }}
+ PSYCOPG_TEST_WANT_LIBPQ_IMPORT=${{ env.LIBPQ_VERSION }}
+
+ - uses: actions/upload-artifact@v3
+ with:
+ path: ./wheelhouse/*.whl
+
+ services:
+ postgresql:
+ image: postgres:14
+ env:
+ POSTGRES_PASSWORD: password
+ ports:
+ - 5432:5432
+ # Set health checks to wait until postgres has started
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+
+
+ # }}}
+
+ macos: # {{{
+ runs-on: macos-latest
+ if: true
+
+ strategy:
+ fail-fast: false
+ matrix:
+ # These archs require an Apple M1 runner: [arm64, universal2]
+ arch: [x86_64]
+ pyver: [cp37, cp38, cp39, cp310, cp311]
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Create the binary package source tree
+ run: python3 ./tools/build/copy_to_binary.py
+
+ - name: Build wheels
+ uses: pypa/cibuildwheel@v2.9.0
+ with:
+ package-dir: psycopg_binary
+ env:
+ CIBW_BUILD: ${{matrix.pyver}}-macosx_${{matrix.arch}}
+ CIBW_ARCHS_MACOS: x86_64
+ CIBW_BEFORE_ALL_MACOS: ./tools/build/wheel_macos_before_all.sh
+ CIBW_TEST_REQUIRES: ./psycopg[test] ./psycopg_pool
+ CIBW_TEST_COMMAND: >-
+ pytest {project}/tests -m 'not slow and not flakey' --color yes
+ CIBW_ENVIRONMENT: >-
+ PSYCOPG_IMPL=binary
+ PSYCOPG_TEST_DSN='dbname=postgres'
+ PSYCOPG_TEST_WANT_LIBPQ_BUILD=">= 14"
+ PSYCOPG_TEST_WANT_LIBPQ_IMPORT=">= 14"
+
+ - uses: actions/upload-artifact@v3
+ with:
+ path: ./wheelhouse/*.whl
+
+
+ # }}}
+
+ windows: # {{{
+ runs-on: windows-latest
+ if: true
+
+ strategy:
+ fail-fast: false
+ matrix:
+ # Might want to add win32, untested at the moment.
+ arch: [win_amd64]
+ pyver: [cp37, cp38, cp39, cp310, cp311]
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Start PostgreSQL service for test
+ run: |
+ $PgSvc = Get-Service "postgresql*"
+ Set-Service $PgSvc.Name -StartupType manual
+ $PgSvc.Start()
+
+ - name: Create the binary package source tree
+ run: python3 ./tools/build/copy_to_binary.py
+
+ - name: Build wheels
+ uses: pypa/cibuildwheel@v2.9.0
+ with:
+ package-dir: psycopg_binary
+ env:
+ CIBW_BUILD: ${{matrix.pyver}}-${{matrix.arch}}
+ CIBW_ARCHS_WINDOWS: AMD64 x86
+ CIBW_BEFORE_BUILD_WINDOWS: '.\tools\build\wheel_win32_before_build.bat'
+ CIBW_REPAIR_WHEEL_COMMAND_WINDOWS: >-
+ delvewheel repair -w {dest_dir}
+ --no-mangle "libiconv-2.dll;libwinpthread-1.dll" {wheel}
+ CIBW_TEST_REQUIRES: ./psycopg[test] ./psycopg_pool
+ CIBW_TEST_COMMAND: >-
+ pytest {project}/tests -m "not slow and not flakey" --color yes
+ CIBW_ENVIRONMENT_WINDOWS: >-
+ PSYCOPG_IMPL=binary
+ PATH="C:\\Program Files\\PostgreSQL\\14\\bin;$PATH"
+ PSYCOPG_TEST_DSN="host=127.0.0.1 user=postgres"
+ PSYCOPG_TEST_WANT_LIBPQ_BUILD=">= 14"
+ PSYCOPG_TEST_WANT_LIBPQ_IMPORT=">= 14"
+
+ - uses: actions/upload-artifact@v3
+ with:
+ path: ./wheelhouse/*.whl
+
+
+ # }}}
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000..9f6a7f5
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,314 @@
+name: Tests
+
+on:
+ push:
+ # This should disable running the workflow on tags, according to the
+ # on.<push|pull_request>.<branches|tags> GitHub Actions docs.
+ branches:
+ - "*"
+ pull_request:
+ schedule:
+ - cron: '48 6 * * *'
+
+concurrency:
+ # Cancel older requests of the same workflow in the same branch.
+ group: ${{ github.workflow }}-${{ github.ref_name }}
+ cancel-in-progress: true
+
+jobs:
+
+ linux: # {{{
+ runs-on: ubuntu-latest
+ if: true
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ # Test different combinations of Python, Postgres, libpq.
+ - {impl: python, python: "3.7", postgres: "postgres:10", libpq: newest}
+ - {impl: python, python: "3.8", postgres: "postgres:12"}
+ - {impl: python, python: "3.9", postgres: "postgres:13"}
+ - {impl: python, python: "3.10", postgres: "postgres:14"}
+ - {impl: python, python: "3.11", postgres: "postgres:15", libpq: oldest}
+
+ - {impl: c, python: "3.7", postgres: "postgres:15", libpq: newest}
+ - {impl: c, python: "3.8", postgres: "postgres:13"}
+ - {impl: c, python: "3.9", postgres: "postgres:14"}
+ - {impl: c, python: "3.10", postgres: "postgres:13", libpq: oldest}
+ - {impl: c, python: "3.11", postgres: "postgres:10", libpq: newest}
+
+ - {impl: python, python: "3.9", ext: dns, postgres: "postgres:14"}
+ - {impl: python, python: "3.9", ext: postgis, postgres: "postgis/postgis"}
+
+ env:
+ PSYCOPG_IMPL: ${{ matrix.impl }}
+ DEPS: ./psycopg[test] ./psycopg_pool
+ PSYCOPG_TEST_DSN: "host=127.0.0.1 user=postgres"
+ PGPASSWORD: password
+ MARKERS: ""
+
+ # Enable to run tests using the minimum version of dependencies.
+ # PIP_CONSTRAINT: ${{ github.workspace }}/tests/constraints.txt
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python }}
+
+ - name: Install the newest libpq version available
+ if: ${{ matrix.libpq == 'newest' }}
+ run: |
+ set -x
+
+ curl -sL https://www.postgresql.org/media/keys/ACCC4CF8.asc \
+ | gpg --dearmor \
+ | sudo tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg > /dev/null
+
+ # NOTE: in order to test with a preview release, add its number to
+ # the deb entry. For instance, to test on preview Postgres 16, use:
+ # "deb http://apt.postgresql.org/pub/repos/apt ${rel}-pgdg main 16"
+ rel=$(lsb_release -c -s)
+ echo "deb http://apt.postgresql.org/pub/repos/apt ${rel}-pgdg main" \
+ | sudo tee -a /etc/apt/sources.list.d/pgdg.list > /dev/null
+ sudo apt-get -qq update
+
+ pqver=$(apt-cache show libpq5 | grep ^Version: | head -1 \
+ | awk '{print $2}')
+ sudo apt-get -qq -y install "libpq-dev=${pqver}" "libpq5=${pqver}"
+
+ - name: Install the oldest libpq version available
+ if: ${{ matrix.libpq == 'oldest' }}
+ run: |
+ set -x
+ pqver=$(apt-cache show libpq5 | grep ^Version: | tail -1 \
+ | awk '{print $2}')
+ sudo apt-get -qq -y --allow-downgrades install \
+ "libpq-dev=${pqver}" "libpq5=${pqver}"
+
+ - if: ${{ matrix.ext == 'dns' }}
+ run: |
+ echo "DEPS=$DEPS dnspython" >> $GITHUB_ENV
+ echo "MARKERS=$MARKERS dns" >> $GITHUB_ENV
+
+ - if: ${{ matrix.ext == 'postgis' }}
+ run: |
+ echo "DEPS=$DEPS shapely" >> $GITHUB_ENV
+ echo "MARKERS=$MARKERS postgis" >> $GITHUB_ENV
+
+ - if: ${{ matrix.impl == 'c' }}
+ run: |
+ echo "DEPS=$DEPS ./psycopg_c" >> $GITHUB_ENV
+
+ - name: Install Python dependencies
+ run: pip install $DEPS
+
+ - name: Run tests
+ run: ./tools/build/ci_test.sh
+
+ services:
+ postgresql:
+ image: ${{ matrix.postgres }}
+ env:
+ POSTGRES_PASSWORD: password
+ ports:
+ - 5432:5432
+ # Set health checks to wait until postgres has started
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+
+
+ # }}}
+
+ macos: # {{{
+ runs-on: macos-latest
+ if: true
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - {impl: python, python: "3.7"}
+ - {impl: python, python: "3.8"}
+ - {impl: python, python: "3.9"}
+ - {impl: python, python: "3.10"}
+ - {impl: python, python: "3.11"}
+ - {impl: c, python: "3.7"}
+ - {impl: c, python: "3.8"}
+ - {impl: c, python: "3.9"}
+ - {impl: c, python: "3.10"}
+ - {impl: c, python: "3.11"}
+
+ env:
+ PSYCOPG_IMPL: ${{ matrix.impl }}
+ DEPS: ./psycopg[test] ./psycopg_pool
+ PSYCOPG_TEST_DSN: "host=127.0.0.1 user=runner dbname=postgres"
+ # MacOS on GitHub Actions seems particularly slow.
+ # Don't run timing-based tests as they regularly fail.
+ # pproxy-based tests fail too, with the proxy not coming up in 2s.
+ NOT_MARKERS: "timing proxy mypy"
+ # PIP_CONSTRAINT: ${{ github.workspace }}/tests/constraints.txt
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Install PostgreSQL on the runner
+ run: brew install postgresql@14
+
+ - name: Start PostgreSQL service for test
+ run: brew services start postgresql
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python }}
+
+ - if: ${{ matrix.impl == 'c' }}
+ # skip tests failing on importing psycopg_c.pq on subprocess
+ # they only fail on Travis, work ok locally under tox too.
+ # TODO: check the same on GitHub Actions
+ run: |
+ echo "DEPS=$DEPS ./psycopg_c" >> $GITHUB_ENV
+
+ - name: Install Python dependencies
+ run: pip install $DEPS
+
+ - name: Run tests
+ run: ./tools/build/ci_test.sh
+
+
+ # }}}
+
+ windows: # {{{
+ runs-on: windows-latest
+ if: true
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - {impl: python, python: "3.7"}
+ - {impl: python, python: "3.8"}
+ - {impl: python, python: "3.9"}
+ - {impl: python, python: "3.10"}
+ - {impl: python, python: "3.11"}
+ - {impl: c, python: "3.7"}
+ - {impl: c, python: "3.8"}
+ - {impl: c, python: "3.9"}
+ - {impl: c, python: "3.10"}
+ - {impl: c, python: "3.11"}
+
+ env:
+ PSYCOPG_IMPL: ${{ matrix.impl }}
+ DEPS: ./psycopg[test] ./psycopg_pool
+ PSYCOPG_TEST_DSN: "host=127.0.0.1 dbname=postgres"
+ # On windows pproxy doesn't seem very happy. Also a few timing test fail.
+ NOT_MARKERS: "timing proxy mypy"
+ # PIP_CONSTRAINT: ${{ github.workspace }}/tests/constraints.txt
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Start PostgreSQL service for test
+ run: |
+ $PgSvc = Get-Service "postgresql*"
+ Set-Service $PgSvc.Name -StartupType manual
+ $PgSvc.Start()
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python }}
+
+ # Build a wheel package of the C extensions.
+ # If the wheel is not delocated, import fails with some dll not found
+ # (but it won't tell which one).
+ - name: Build the C wheel
+ if: ${{ matrix.impl == 'c' }}
+ run: |
+ pip install delvewheel wheel
+ $env:Path = "C:\Program Files\PostgreSQL\14\bin\;$env:Path"
+ python ./psycopg_c/setup.py bdist_wheel
+ &"delvewheel" repair `
+ --no-mangle "libiconv-2.dll;libwinpthread-1.dll" `
+ @(Get-ChildItem psycopg_c\dist\*.whl)
+ &"pip" install @(Get-ChildItem wheelhouse\*.whl)
+
+ - name: Run tests
+ run: |
+ pip install $DEPS
+ ./tools/build/ci_test.sh
+ shell: bash
+
+
+ # }}}
+
+ crdb: # {{{
+ runs-on: ubuntu-latest
+ if: true
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - {impl: c, crdb: "latest-v22.1", python: "3.10", libpq: newest}
+ - {impl: python, crdb: "latest-v22.2", python: "3.11"}
+ env:
+ PSYCOPG_IMPL: ${{ matrix.impl }}
+ DEPS: ./psycopg[test] ./psycopg_pool
+ PSYCOPG_TEST_DSN: "host=127.0.0.1 port=26257 user=root dbname=defaultdb"
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python }}
+
+ - name: Run CockroachDB
+ # Note: this would love to be a service, but I don't see a way to pass
+ # the args to the docker run command line.
+ run: |
+ docker pull cockroachdb/cockroach:${{ matrix.crdb }}
+ docker run --rm -d --name crdb -p 26257:26257 \
+ cockroachdb/cockroach:${{ matrix.crdb }} start-single-node --insecure
+
+ - name: Install the newest libpq version available
+ if: ${{ matrix.libpq == 'newest' }}
+ run: |
+ set -x
+
+ curl -sL https://www.postgresql.org/media/keys/ACCC4CF8.asc \
+ | gpg --dearmor \
+ | sudo tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg > /dev/null
+
+ # NOTE: in order to test with a preview release, add its number to
+ # the deb entry. For instance, to test on preview Postgres 16, use:
+ # "deb http://apt.postgresql.org/pub/repos/apt ${rel}-pgdg main 16"
+ rel=$(lsb_release -c -s)
+ echo "deb http://apt.postgresql.org/pub/repos/apt ${rel}-pgdg main" \
+ | sudo tee -a /etc/apt/sources.list.d/pgdg.list > /dev/null
+ sudo apt-get -qq update
+
+ pqver=$(apt-cache show libpq5 | grep ^Version: | head -1 \
+ | awk '{print $2}')
+ sudo apt-get -qq -y install "libpq-dev=${pqver}" "libpq5=${pqver}"
+
+ - if: ${{ matrix.impl == 'c' }}
+ run: |
+ echo "DEPS=$DEPS ./psycopg_c" >> $GITHUB_ENV
+
+ - name: Install Python dependencies
+ run: pip install $DEPS
+
+ - name: Run tests
+ run: ./tools/build/ci_test.sh
+
+ - name: Stop CockroachDB
+ run: docker kill crdb
+
+
+ # }}}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2d8c58c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,23 @@
+*.egg-info/
+.tox
+*.pstats
+*.swp
+.mypy_cache
+__pycache__/
+/docs/_build/
+*.html
+/psycopg_binary/
+.vscode
+.venv
+.coverage
+htmlcov
+
+.eggs/
+dist/
+wheelhouse/
+# Spelling these explicitly because we have /scripts/build/ to not ignore
+# but I still want 'ag' to avoid looking here.
+/build/
+/psycopg/build/
+/psycopg_c/build/
+/psycopg_pool/build/
diff --git a/.yamllint.yaml b/.yamllint.yaml
new file mode 100644
index 0000000..bcfb4d2
--- /dev/null
+++ b/.yamllint.yaml
@@ -0,0 +1,8 @@
+extends: default
+
+rules:
+ truthy:
+ check-keys: false
+ document-start: disable
+ line-length:
+ max: 85
diff --git a/BACKERS.yaml b/BACKERS.yaml
new file mode 100644
index 0000000..b8bf830
--- /dev/null
+++ b/BACKERS.yaml
@@ -0,0 +1,140 @@
+---
+# You can find our sponsors at https://www.psycopg.org/sponsors/ Thank you!
+
+- username: postgrespro
+ tier: top
+ avatar: https://avatars.githubusercontent.com/u/12005770?v=4
+ name: Postgres Professional
+ website: https://postgrespro.com/
+
+- username: commandprompt
+ by: jdatcmd
+ tier: top
+ avatar: https://avatars.githubusercontent.com/u/339156?v=4
+ name: Command Prompt, Inc.
+ website: https://www.commandprompt.com
+
+- username: bitdotioinc
+ tier: top
+ avatar: https://avatars.githubusercontent.com/u/56135630?v=4
+ name: bit.io
+ website: https://bit.io/?utm_campaign=sponsorship&utm_source=psycopg2&utm_medium=web
+ keep_website: true
+
+
+- username: yougov
+ tier: mid
+ avatar: https://avatars.githubusercontent.com/u/378494?v=4
+ name: YouGov
+ website: https://www.yougov.com
+
+- username: phenopolis
+ by: pontikos
+ tier: mid
+ avatar: https://avatars.githubusercontent.com/u/20042742?v=4
+ name: Phenopolis
+ website: http://www.phenopolis.co.uk
+
+- username: MaterializeInc
+ tier: mid
+ avatar: https://avatars.githubusercontent.com/u/47674186?v=4
+ name: Materialize, Inc.
+ website: http://materialize.com
+
+- username: getsentry
+ tier: mid
+ avatar: https://avatars.githubusercontent.com/u/1396951?v=4
+ name: Sentry
+ website: https://sentry.io
+
+- username: 20tab
+ tier: mid
+ avatar: https://avatars.githubusercontent.com/u/1843159?v=4
+ name: 20tab srl
+ website: http://www.20tab.com
+
+- username: genropy
+ by: gporcari
+ tier: mid
+ avatar: https://avatars.githubusercontent.com/u/7373189?v=4
+ name: genropy
+ website: http://www.genropy.org
+
+- username: svennek
+ tier: mid
+ avatar: https://avatars.githubusercontent.com/u/37837?v=4
+ name: Svenne Krap
+ website: http://www.svenne.dk
+
+- username: mailupinc
+ tier: mid
+ avatar: https://avatars.githubusercontent.com/u/72260631?v=4
+ name: BEE
+ website: https://beefree.io
+
+
+- username: taifu
+ avatar: https://avatars.githubusercontent.com/u/115712?v=4
+ name: Marco Beri
+ website: http:/beri.it
+
+- username: la-mar
+ avatar: https://avatars.githubusercontent.com/u/16618300?v=4
+ name: Brock Friedrich
+
+- username: xarg
+ avatar: https://avatars.githubusercontent.com/u/94721?v=4
+ name: Alex Plugaru
+ website: https://plugaru.org
+
+- username: dalibo
+ avatar: https://avatars.githubusercontent.com/u/182275?v=4
+ name: Dalibo
+ website: http://www.dalibo.com
+
+- username: rafmagns-skepa-dreag
+ avatar: https://avatars.githubusercontent.com/u/7447491?v=4
+ name: Richard H
+
+- username: rustprooflabs
+ avatar: https://avatars.githubusercontent.com/u/3085224?v=4
+ name: Ryan Lambert
+ website: https://www.rustprooflabs.com
+
+- username: logilab
+ avatar: https://avatars.githubusercontent.com/u/446566?v=4
+ name: Logilab
+ website: http://www.logilab.org
+
+- username: asqui
+ avatar: https://avatars.githubusercontent.com/u/174182?v=4
+ name: Daniel Fortunov
+
+- username: iqbalabd
+ avatar: https://avatars.githubusercontent.com/u/14254614?v=4
+ name: Iqbal Abdullah
+ website: https://info.xoxzo.com/
+
+- username: healthchecks
+ avatar: https://avatars.githubusercontent.com/u/13053880?v=4
+ name: Healthchecks
+ website: https://healthchecks.io
+
+- username: c-rindi
+ avatar: https://avatars.githubusercontent.com/u/7826876?v=4
+ name: C~+
+
+- username: Intevation
+ by: bernhardreiter
+ avatar: https://avatars.githubusercontent.com/u/2050405?v=4
+ name: Intevation
+ website: https://www.intevation.de/
+
+- username: abegerho
+ avatar: https://avatars.githubusercontent.com/u/5734243?v=4
+ name: Abhishek Begerhotta
+
+- username: ferpection
+ avatar: https://avatars.githubusercontent.com/u/6997008?v=4
+ name: Ferpection
+ website: https://ferpection.com
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/README.rst b/README.rst
new file mode 100644
index 0000000..50934d8
--- /dev/null
+++ b/README.rst
@@ -0,0 +1,75 @@
+Psycopg 3 -- PostgreSQL database adapter for Python
+===================================================
+
+Psycopg 3 is a modern implementation of a PostgreSQL adapter for Python.
+
+
+Installation
+------------
+
+Quick version::
+
+ pip install --upgrade pip # upgrade pip to at least 20.3
+ pip install "psycopg[binary,pool]" # install binary dependencies
+
+For further information about installation please check `the documentation`__.
+
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+
+
+Hacking
+-------
+
+In order to work on the Psycopg source code you need to have the ``libpq``
+PostgreSQL client library installed in the system. For instance, on Debian
+systems, you can obtain it by running::
+
+ sudo apt install libpq5
+
+After which you can clone this repository::
+
+ git clone https://github.com/psycopg/psycopg.git
+ cd psycopg
+
+Please note that the repository contains the source code of several Python
+packages: that's why you don't see a ``setup.py`` here. The packages may have
+different requirements:
+
+- The ``psycopg`` directory contains the pure python implementation of
+ ``psycopg``. The package has only a runtime dependency on the ``libpq``, the
+ PostgreSQL client library, which should be installed in your system.
+
+- The ``psycopg_c`` directory contains an optimization module written in
+ C/Cython. In order to build it you will need a few development tools: please
+ look at `Local installation`__ in the docs for the details.
+
+ .. __: https://www.psycopg.org/psycopg3/docs/basic/install.html#local-installation
+
+- The ``psycopg_pool`` directory contains the `connection pools`__
+ implementations. This is kept as a separate package to allow a different
+ release cycle.
+
+ .. __: https://www.psycopg.org/psycopg3/docs/advanced/pool.html
+
+You can create a local virtualenv and install there the packages `in
+development mode`__, together with their development and testing
+requirements::
+
+ python -m venv .venv
+ source .venv/bin/activate
+ pip install -e "./psycopg[dev,test]" # for the base Python package
+ pip install -e ./psycopg_pool # for the connection pool
+ pip install ./psycopg_c # for the C speedup module
+
+.. __: https://pip.pypa.io/en/stable/reference/pip_install/#install-editable
+
+Please add ``--config-settings editable_mode=strict`` to the ``pip install
+-e`` above if you experience `editable mode broken`__.
+
+.. __: https://github.com/pypa/setuptools/issues/3557
+
+Now hack away! You can run the tests using::
+
+ psql -c 'create database psycopg_test'
+ export PSYCOPG_TEST_DSN="dbname=psycopg_test"
+ pytest
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..e86cbd4
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,30 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+PYTHON ?= python3
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) || true
+
+serve:
+ PSYCOPG_IMPL=python sphinx-autobuild . _build/html/
+
+.PHONY: help serve env Makefile
+
+env: .venv
+
+.venv:
+ $(PYTHON) -m venv .venv
+ ./.venv/bin/pip install -e "../psycopg[docs]" -e ../psycopg_pool
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/_static/psycopg.css b/docs/_static/psycopg.css
new file mode 100644
index 0000000..de9d779
--- /dev/null
+++ b/docs/_static/psycopg.css
@@ -0,0 +1,11 @@
+/* style rubric in furo (too small IMO) */
+p.rubric {
+ font-size: 1.2rem;
+ font-weight: bold;
+}
+
+/* override a silly default */
+table.align-default td,
+table.align-default th {
+ text-align: left;
+}
diff --git a/docs/_static/psycopg.svg b/docs/_static/psycopg.svg
new file mode 100644
index 0000000..0e9ee32
--- /dev/null
+++ b/docs/_static/psycopg.svg
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 228 148"><path fill="#ffc836" stroke="#000" stroke-width="7.415493" d="M142.3 67.6c.6-4.7.4-5.4 4-4.6h.8c2.7.2 6.2-.4 8.3-1.3 4.4-2.1 7-5.5 2.7-4.6-10 2-10.7-1.4-10.7-1.4 10.5-15.6 15-35.5 11.1-40.3-10.3-13.3-28.3-7-28.6-6.8h-.1c-2-.4-4.2-.7-6.7-.7-4.5 0-8 1.2-10.5 3.1 0 0-32-13.2-30.6 16.6.4 6.4 9.1 48 19.6 35.4 3.8-4.6 7.5-8.4 7.5-8.4 1.8 1.2 4 1.8 6.3 1.6l.2-.2a7 7 0 000 1.8c-2.6 3-1.8 3.5-7.2 4.7-5.4 1-2.2 3-.2 3.6 2.6.6 8.4 1.5 12.4-4l-.2.6c1.1.9 1 6.1 1.2 9.8.1 3.8.4 7.3 1.1 9.3.8 2 1.7 7.4 8.8 5.9 5.9-1.3 10.4-3.1 10.8-20.1"/><path fill="#ff0" d="M105.4 54.2a2.4 2.4 0 114.8 0c0 1.3-1.1 2.4-2.4 2.4-1.3 0-2.4-1-2.4-2.4z"/><g fill="#336791"><path stroke="#000" stroke-width="7.415493" d="M85.7 80.4c-.6 4.7-.4 5.4-4 4.6H81c-2.7-.2-6.2.4-8.3 1.3-4.4 2.1-7 5.5-2.7 4.6 10-2 10.7 1.4 10.7 1.4-10.5 15.6-15 35.5-11.1 40.3 10.3 13.3 28.3 7 28.6 6.8h.1c2 .4 4.2.7 6.7.7 4.5 0 8-1.2 10.5-3.1 0 0 32 13.2 30.6-16.6-.4-6.4-9.1-48-19.6-35.4-3.8 4.6-7.5 8.4-7.5 8.4a9.7 9.7 0 00-6.3-1.6l-.2.2a7 7 0 000-1.8c2.6-3 1.8-3.5 7.2-4.7 5.4-1 2.2-3 .2-3.6-2.6-.6-8.4-1.5-12.4 4l.2-.6c-1.1-.9-1-6.1-1.2-9.8-.1-3.8-.4-7.3-1.1-9.3-.8-2-1.7-7.4-8.8-5.9-5.9 1.3-10.4 3.1-10.8 20.1"/><path d="M70 91c10-2.1 10.6 1.3 10.6 1.3-10.5 15.6-15 35.5-11.1 40.3 10.3 13.3 28.3 7 28.6 6.8h.1c2 .4 4.2.7 6.7.7 4.5 0 8-1.2 10.5-3.1 0 0 32 13.2 30.6-16.6-.4-6.4-9.1-48-19.6-35.4-3.8 4.6-7.5 8.4-7.5 8.4a9.7 9.7 0 00-6.3-1.6l-.2.2a7 7 0 000-1.8c2.6-3 1.8-3.5 7.2-4.7 5.5-1 2.2-3 .2-3.6-2.6-.6-8.4-1.5-12.4 4l.2-.6c-1.1-.9-1.8-5.5-1.7-9.7.1-4.3.2-7.2-.6-9.4s-1.7-7.4-8.8-5.9c-5.9 1.3-9 4.6-9.4 10-.3 4-1 3.4-1 6.9l-.6 1.6c-.6 5.3 0 7-3.7 6.2h-.9c-2.7-.2-6.2.4-8.3 1.3-4.4 2.1-7 5.5-2.7 4.6z"/><g stroke="#ffc836" stroke-linecap="round" stroke-width="15.6"><g stroke-linejoin="round"><path stroke-width="2.477124" d="M107 88c.2-10-.1-19.8-1-22.2-1-2.4-3-7.1-10.2-5.6-5.9 1.3-8 3.7-9 9.1-.7 4-2 15.1-2.2 17.4M115.5 137.2s32 13 30.5-16.7c-.3-6.4-9-48-19.5-35.4-3.8 4.6-7.3 8.2-7.3 8.2M98.1 139.6c1.2-.4-17.8 6.9-28.6-6.9-3.8-4.8.6-24.7 11.2-40.3"/></g><path stroke-linejoin="bevel" stroke-width="2.477124" d="M80.7 92.4S80 89 70 91c-4.4.9-1.7-2.6 2.7-4.6 3.7-1.7 11.8-2.2 12 .2.3 6-4.3 4.2-4 5.7.3 1.3 2.4 2.7 3.7 6 1.2 2.9 16.5 25.2-4.2 21.9-.7.1 5.4 19.6 24.7 20 19.4.3 18.7-23.8 18.7-23.8"/><path stroke-linejoin="round" stroke-width="2.477124" d="M112.4 90.3c2.7-3 1.9-3.5 7.3-4.6 5.4-1.2 2.2-3.2.1-3.7-2.5-.6-8.4-1.5-12.3 4-1.2 1.7 0 4.4 1.6 5.1.8.3 2 .8 3.3-.8z"/><path stroke-linejoin="round" stroke-width="2.477124" d="M112.6 90.4c.2 1.7-.6 3.8-1.5 6.3-1.4 3.7-4.6 7.4-2 19.1 1.8 8.8 14.5 1.8 14.5.7 0-1.2-.5-6 .2-11.7 1-7.3-4.6-13.5-11.2-12.8"/></g></g><g stroke="#ffc836"><path fill="#ffc836" stroke-width=".825708" d="M115.6 116.6c0-.4-.8-1.4-1.8-1.6-1-.1-2 .7-2 1.1 0 .4.7.9 1.8 1 1 .2 2 0 2-.5z"/><path fill="#ffc836" stroke-width=".412854" d="M84 117.5c-.1-.4.7-1.5 1.7-1.7 1-.1 2 .7 2 1.1 0 .4-.7.9-1.8 1s-2 0-2-.4z"/><path fill="#336791" stroke-linecap="round" stroke-linejoin="round" stroke-width="2.477124" d="M80.2 120.3c-.2-3.2.7-5.4.8-8.7.2-5-2.3-10.6 1.4-16.2"/></g><g fill="#ffc836"><path d="M158 57c-10 2.1-10.6-1.3-10.6-1.3 10.5-15.6 15-35.5 11.1-40.3-10.3-13.3-28.3-7-28.6-6.8h-.1c-2-.4-4.2-.7-6.7-.7-4.5 0-8 1.2-10.5 3.1 0 0-32-13.2-30.6 16.6.4 6.4 9.1 48 19.6 35.4 3.8-4.6 7.5-8.5 7.5-8.5 1.8 1.3 4 1.9 6.3 1.7l.2-.2a7 7 0 000 1.8c-2.6 3-1.8 3.5-7.2 4.7-5.5 1-2.3 3-.2 3.6 2.6.6 8.4 1.5 12.4-4l-.2.6c1.1.9 1.8 5.5 1.7 9.7-.1 4.3-.2 7.2.6 9.4s1.7 7.4 8.8 5.9c5.9-1.3 9-4.6 9.4-10 .3-4 1-3.4 1-6.9l.6-1.6c.6-5.3 0-7 3.7-6.2h.9c2.7.2 6.2-.4 8.3-1.3 4.4-2.1 7-5.5 2.7-4.6z"/><path d="M142.3 67.6c.6-4.7.4-5.4 4-4.6h.8c2.7.2 6.2-.4 8.3-1.3 4.4-2.1 7-5.5 2.7-4.6-10 2-10.7-1.4-10.7-1.4 10.5-15.6 15-35.5 11.1-40.3-10.3-13.3-28.3-7-28.6-6.8h-.1c-2-.4-4.2-.7-6.7-.7-4.5 0-8 1.2-10.5 3.1 0 0-32-13.2-30.6 16.6.4 6.4 9.1 48 19.6 35.4 3.8-4.6 7.5-8.4 7.5-8.4 1.8 1.2 4 1.8 6.3 1.6l.2-.2a7 7 0 000 1.8c-2.6 3-1.8 3.5-7.2 4.7-5.4 1-2.2 3-.2 3.6 2.6.6 8.4 1.5 12.4-4l-.2.6c1.1.9 1 6.1 1.2 9.8.1 3.8.4 7.3 1.1 9.3.8 2 1.7 7.4 8.8 5.9 5.9-1.3 10.4-3.1 10.8-20.1"/><g stroke="#336791" stroke-linecap="round" stroke-width="15.6"><g stroke-linejoin="round"><path stroke-width="2.477124" d="M112.5 10.8s-32-13-30.5 16.7c.3 6.4 9 48 19.5 35.4 3.8-4.6 7.3-8.2 7.3-8.2M121 60c-.2 10 .1 19.8 1 22.2 1 2.4 3 7.1 10.2 5.6 5.9-1.3 8-3.7 9-9.2.7-4 2-15 2.1-17.3M129.9 8.4c-1.2.4 17.8-6.9 28.6 6.9 3.8 4.8-.6 24.7-11.2 40.3"/></g><path stroke-linejoin="bevel" stroke-width="2.477124" d="M147.3 55.6S148 59 158 57c4.4-.9 1.7 2.6-2.7 4.6-3.7 1.7-11.8 2.2-12-.2-.3-6 4.3-4.2 4-5.7-.3-1.3-2.4-2.7-3.7-6-1.2-3-16.5-25.2 4.2-21.9.7-.1-5.4-19.6-24.7-20-19.4-.3-18.7 23.8-18.7 23.8"/><path stroke-linejoin="round" stroke-width="2.477124" d="M115.6 57.7c-2.7 3-1.9 3.5-7.3 4.6-5.4 1.2-2.2 3.2-.1 3.7 2.5.6 8.4 1.5 12.3-4 1.2-1.7 0-4.4-1.6-5.1-.8-.3-2-.8-3.3.8z"/><path stroke-linejoin="round" stroke-width="2.477124" d="M115.4 57.6c-.2-1.7.6-3.8 1.5-6.3 1.4-3.7 4.6-7.4 2-19.1-1.8-8.8-14.5-1.9-14.5-.7s.5 6-.2 11.7c-1 7.3 4.6 13.5 11.2 12.8"/></g></g><g stroke="#336791"><path fill="#336791" stroke-width=".825708" d="M112.4 31.4c0 .4.8 1.4 1.8 1.6 1 .1 2-.7 2-1.1 0-.4-.7-.9-1.8-1-1-.2-2 0-2 .5z"/><path fill="#336791" stroke-width=".412854" d="M144 30.5c.1.4-.7 1.5-1.7 1.7-1 .1-2-.7-2-1.1 0-.4.7-.9 1.8-1s2 0 2 .4z"/><path fill="#ffc836" stroke-linecap="round" stroke-linejoin="round" stroke-width="2.477124" d="M147.8 27.7c.2 3.2-.7 5.4-.8 8.7-.2 5 2.3 10.6-1.4 16.2"/></g><path fill="#ffc836" stroke="#336791" stroke-linecap="round" stroke-linejoin="round" stroke-width=".6034019999999999" d="M103.8 51h6.6v6.4h-6.6z" color="#000"/><path fill="#336791" stroke="#ffc836" stroke-linecap="round" stroke-linejoin="round" stroke-width="2.477124" d="M107 88c.2-10-.1-19.8-1-22.2-1-2.4-3-7.1-10.2-5.6-5.9 1.3-8 3.7-9 9.1-.7 4-2 15.1-2.2 17.4"/><path fill="#336791" d="M111.7 82.9h22.1v14.4h-22.1z" color="#000"/><path fill="#ffc836" d="M95.8 56h20.1v9.2H95.8z" color="#000"/><g fill="none"><path stroke="#ffc836" stroke-width="1.5878999999999999" d="M113.7 47.6c-2.2 0-4.2.2-6 .5-5.3 1-6.3 3-6.3 6.5v4.8H114V61H96.7a7.8 7.8 0 00-7.8 6.3A23.4 23.4 0 0089 80c1 3.7 3 6.4 6.7 6.4h4.3v-5.7a8 8 0 017.8-7.8h12.5c3.5 0 6.2-2.9 6.2-6.4V54.6c0-3.4-2.8-5.9-6.2-6.5a39 39 0 00-6.5-.5z"/><path stroke="#336791" stroke-width="2.38185" d="M128 61v5.5a8 8 0 01-7.8 8h-12.5c-3.4 0-6.3 2.9-6.3 6.3v12c0 3.3 3 5.3 6.3 6.3a21 21 0 0012.5 0c3.1-1 6.2-2.8 6.2-6.4V88H114v-1.6h18.8c3.6 0 5-2.6 6.2-6.4 1.3-3.9 1.3-7.6 0-12.7-.9-3.6-2.6-6.3-6.2-6.3z"/><path stroke="#ffc836" stroke-width="2.38185" d="M113.7 47.6c-2.2 0-4.2.2-6 .5-5.3 1-6.3 3-6.3 6.5v4.8H114V61H96.7a7.8 7.8 0 00-7.8 6.3A23.4 23.4 0 0089 80c1 3.7 3 6.4 6.7 6.4h4.3v-5.7a8 8 0 017.8-7.8h12.5c3.5 0 6.2-2.9 6.2-6.4V54.6c0-3.4-2.8-5.9-6.2-6.5a39 39 0 00-6.5-.5z"/><path stroke="#336791" stroke-width="1.5878999999999999" d="M128 61v5.5a8 8 0 01-7.8 8h-12.5c-3.4 0-6.3 2.9-6.3 6.3v12c0 3.3 3 5.3 6.3 6.3a21 21 0 0012.5 0c3.1-1 6.2-2.8 6.2-6.4V88H114v-1.6h18.8c3.6 0 5-2.6 6.2-6.4 1.3-3.9 1.3-7.6 0-12.7-.9-3.6-2.6-6.3-6.2-6.3z"/></g><g><path fill="#336791" d="M113.7 47.6c-2.2 0-4.2.2-6 .5-5.3 1-6.3 3-6.3 6.5v4.8H114V61H96.7a7.8 7.8 0 00-7.8 6.3A23.4 23.4 0 0089 80c1 3.7 3 6.4 6.7 6.4h4.3v-5.7a8 8 0 017.8-7.8h12.5c3.5 0 6.2-2.9 6.2-6.4V54.6c0-3.4-2.8-5.9-6.2-6.5a39 39 0 00-6.5-.5zm-6.8 3.9c1.3 0 2.3 1 2.3 2.4 0 1.3-1 2.3-2.3 2.3-1.3 0-2.3-1-2.3-2.3 0-1.4 1-2.4 2.3-2.4z"/><path fill="#ffc836" d="M128 61v5.5a8 8 0 01-7.8 8h-12.5c-3.4 0-6.3 2.9-6.3 6.3v12c0 3.3 3 5.3 6.3 6.3a21 21 0 0012.5 0c3.1-1 6.2-2.8 6.2-6.4V88H114v-1.6h18.8c3.6 0 5-2.6 6.2-6.4 1.3-3.9 1.3-7.6 0-12.7-.9-3.6-2.6-6.3-6.2-6.3zM121 91c1.3 0 2.3 1.1 2.3 2.4 0 1.3-1 2.4-2.3 2.4-1.3 0-2.4-1-2.4-2.4 0-1.3 1-2.4 2.4-2.4z" color="#000"/><path fill="#336791" d="M127.2 59.8h.7v.6h-.7z" color="#000"/></g></svg> \ No newline at end of file
diff --git a/docs/_templates/.keep b/docs/_templates/.keep
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/docs/_templates/.keep
diff --git a/docs/advanced/adapt.rst b/docs/advanced/adapt.rst
new file mode 100644
index 0000000..4323b07
--- /dev/null
+++ b/docs/advanced/adapt.rst
@@ -0,0 +1,269 @@
+.. currentmodule:: psycopg.adapt
+
+.. _adaptation:
+
+Data adaptation configuration
+=============================
+
+The adaptation system is at the core of Psycopg and allows to customise the
+way Python objects are converted to PostgreSQL when a query is performed and
+how PostgreSQL values are converted to Python objects when query results are
+returned.
+
+.. note::
+ For a high-level view of the conversion of types between Python and
+ PostgreSQL please look at :ref:`query-parameters`. Using the objects
+ described in this page is useful if you intend to *customise* the
+ adaptation rules.
+
+- Adaptation configuration is performed by changing the
+ `~psycopg.abc.AdaptContext.adapters` object of objects implementing the
+ `~psycopg.abc.AdaptContext` protocol, for instance `~psycopg.Connection`
+ or `~psycopg.Cursor`.
+
+- Every context object derived from another context inherits its adapters
+ mapping: cursors created from a connection inherit the connection's
+ configuration.
+
+ By default, connections obtain an adapters map from the global map
+ exposed as `psycopg.adapters`: changing the content of this object will
+ affect every connection created afterwards. You may specify a different
+ template adapters map using the `!context` parameter on
+ `~psycopg.Connection.connect()`.
+
+ .. image:: ../pictures/adapt.svg
+ :align: center
+
+- The `!adapters` attributes are `AdaptersMap` instances, and contain the
+ mapping from Python types and `~psycopg.abc.Dumper` classes, and from
+ PostgreSQL OIDs to `~psycopg.abc.Loader` classes. Changing this mapping
+ (e.g. writing and registering your own adapters, or using a different
+ configuration of builtin adapters) affects how types are converted between
+ Python and PostgreSQL.
+
+ - Dumpers (objects implementing the `~psycopg.abc.Dumper` protocol) are
+ the objects used to perform the conversion from a Python object to a bytes
+ sequence in a format understood by PostgreSQL. The string returned
+ *shouldn't be quoted*: the value will be passed to the database using
+ functions such as :pq:`PQexecParams()` so quoting and quotes escaping is
+ not necessary. The dumper usually also suggests to the server what type to
+ use, via its `~psycopg.abc.Dumper.oid` attribute.
+
+ - Loaders (objects implementing the `~psycopg.abc.Loader` protocol) are
+ the objects used to perform the opposite operation: reading a bytes
+ sequence from PostgreSQL and creating a Python object out of it.
+
+ - Dumpers and loaders are instantiated on demand by a `~Transformer` object
+ when a query is executed.
+
+.. note::
+ Changing adapters in a context only affects that context and its children
+ objects created *afterwards*; the objects already created are not
+ affected. For instance, changing the global context will only change newly
+ created connections, not the ones already existing.
+
+
+.. _adapt-example-xml:
+
+Writing a custom adapter: XML
+-----------------------------
+
+Psycopg doesn't provide adapters for the XML data type, because there are just
+too many ways of handling XML in Python. Creating a loader to parse the
+`PostgreSQL xml type`__ to `~xml.etree.ElementTree` is very simple, using the
+`psycopg.adapt.Loader` base class and implementing the
+`~psycopg.abc.Loader.load()` method:
+
+.. __: https://www.postgresql.org/docs/current/datatype-xml.html
+
+.. code:: python
+
+ >>> import xml.etree.ElementTree as ET
+ >>> from psycopg.adapt import Loader
+
+ >>> # Create a class implementing the `load()` method.
+ >>> class XmlLoader(Loader):
+ ... def load(self, data):
+ ... return ET.fromstring(data)
+
+ >>> # Register the loader on the adapters of a context.
+ >>> conn.adapters.register_loader("xml", XmlLoader)
+
+ >>> # Now just query the database returning XML data.
+ >>> cur = conn.execute(
+ ... """select XMLPARSE (DOCUMENT '<?xml version="1.0"?>
+ ... <book><title>Manual</title><chapter>...</chapter></book>')
+ ... """)
+
+ >>> elem = cur.fetchone()[0]
+ >>> elem
+ <Element 'book' at 0x7ffb55142ef0>
+
+The opposite operation, converting Python objects to PostgreSQL, is performed
+by dumpers. The `psycopg.adapt.Dumper` base class makes it easy to implement one:
+you only need to implement the `~psycopg.abc.Dumper.dump()` method::
+
+ >>> from psycopg.adapt import Dumper
+
+ >>> class XmlDumper(Dumper):
+ ... # Setting an OID is not necessary but can be helpful
+ ... oid = psycopg.adapters.types["xml"].oid
+ ...
+ ... def dump(self, elem):
+ ... return ET.tostring(elem)
+
+ >>> # Register the dumper on the adapters of a context
+ >>> conn.adapters.register_dumper(ET.Element, XmlDumper)
+
+ >>> # Now, in that context, it is possible to use ET.Element objects as parameters
+ >>> conn.execute("SELECT xpath('//title/text()', %s)", [elem]).fetchone()[0]
+ ['Manual']
+
+Note that it is possible to use a `~psycopg.types.TypesRegistry`, exposed by
+any `~psycopg.abc.AdaptContext`, to obtain information on builtin types, or
+extension types if they have been registered on that context using the
+`~psycopg.types.TypeInfo`\.\ `~psycopg.types.TypeInfo.register()` method.
+
+
+.. _adapt-example-float:
+
+Example: PostgreSQL numeric to Python float
+-------------------------------------------
+
+Normally PostgreSQL :sql:`numeric` values are converted to Python
+`~decimal.Decimal` instances, because both the types allow fixed-precision
+arithmetic and are not subject to rounding.
+
+Sometimes, however, you may want to perform floating-point math on
+:sql:`numeric` values, and `!Decimal` may get in the way (maybe because it is
+slower, or maybe because mixing `!float` and `!Decimal` values causes Python
+errors).
+
+If you are fine with the potential loss of precision and you simply want to
+receive :sql:`numeric` values as Python `!float`, you can register on
+:sql:`numeric` the same `Loader` class used to load
+:sql:`float4`\/:sql:`float8` values. Because the PostgreSQL textual
+representation of both floats and decimal is the same, the two loaders are
+compatible.
+
+.. code:: python
+
+ conn = psycopg.connect()
+
+ conn.execute("SELECT 123.45").fetchone()[0]
+ # Decimal('123.45')
+
+ conn.adapters.register_loader("numeric", psycopg.types.numeric.FloatLoader)
+
+ conn.execute("SELECT 123.45").fetchone()[0]
+ # 123.45
+
+In this example the customised adaptation takes effect only on the connection
+`!conn` and on any cursor created from it, not on other connections.
+
+
+.. _adapt-example-inf-date:
+
+Example: handling infinity date
+-------------------------------
+
+Suppose you want to work with the "infinity" date which is available in
+PostgreSQL but not handled by Python:
+
+.. code:: python
+
+ >>> conn.execute("SELECT 'infinity'::date").fetchone()
+ Traceback (most recent call last):
+ ...
+ DataError: date too large (after year 10K): 'infinity'
+
+One possibility would be to store Python's `datetime.date.max` as PostgreSQL
+infinity. For this, let's create a subclass for the dumper and the loader and
+register them in the working scope (globally or just on a connection or
+cursor):
+
+.. code:: python
+
+ from datetime import date
+
+ # Subclass existing adapters so that the base case is handled normally.
+ from psycopg.types.datetime import DateLoader, DateDumper
+
+ class InfDateDumper(DateDumper):
+ def dump(self, obj):
+ if obj == date.max:
+ return b"infinity"
+ elif obj == date.min:
+ return b"-infinity"
+ else:
+ return super().dump(obj)
+
+ class InfDateLoader(DateLoader):
+ def load(self, data):
+ if data == b"infinity":
+ return date.max
+ elif data == b"-infinity":
+ return date.min
+ else:
+ return super().load(data)
+
+ # The new classes can be registered globally, on a connection, on a cursor
+ cur.adapters.register_dumper(date, InfDateDumper)
+ cur.adapters.register_loader("date", InfDateLoader)
+
+ cur.execute("SELECT %s::text, %s::text", [date(2020, 12, 31), date.max]).fetchone()
+ # ('2020-12-31', 'infinity')
+ cur.execute("SELECT '2020-12-31'::date, 'infinity'::date").fetchone()
+ # (datetime.date(2020, 12, 31), datetime.date(9999, 12, 31))
+
+
+Dumpers and loaders life cycle
+------------------------------
+
+Registering dumpers and loaders will instruct Psycopg to use them
+in the queries to follow, in the context where they have been registered.
+
+When a query is performed on a `~psycopg.Cursor`, a
+`~psycopg.adapt.Transformer` object is created as a local context to manage
+adaptation during the query, instantiating the required dumpers and loaders
+and dispatching the values to perform the wanted conversions from Python to
+Postgres and back.
+
+- The `!Transformer` copies the adapters configuration from the `!Cursor`,
+ thus inheriting all the changes made to the global `psycopg.adapters`
+ configuration, the current `!Connection`, the `!Cursor`.
+
+- For every Python type passed as query argument, the `!Transformer` will
+ instantiate a `!Dumper`. Usually all the objects of the same type will be
+ converted by the same dumper instance.
+
+ - According to the placeholder used (``%s``, ``%b``, ``%t``), Psycopg may
+ pick a binary or a text dumper. When using the ``%s`` "`~PyFormat.AUTO`"
+ format, if the same type has both a text and a binary dumper registered,
+ the last one registered by `~AdaptersMap.register_dumper()` will be used.
+
+ - Sometimes, just looking at the Python type is not enough to decide the
+ best PostgreSQL type to use (for instance the PostgreSQL type of a Python
+ list depends on the objects it contains, whether to use an :sql:`integer`
+ or :sql:`bigint` depends on the number size...) In these cases the
+ mechanism provided by `~psycopg.abc.Dumper.get_key()` and
+ `~psycopg.abc.Dumper.upgrade()` is used to create more specific dumpers.
+
+- The query is executed. Upon successful request, the result is received as a
+ `~psycopg.pq.PGresult`.
+
+- For every OID returned by the query, the `!Transformer` will instantiate a
+ `!Loader`. All the values with the same OID will be converted by the same
+ loader instance.
+
+- Recursive types (e.g. Python lists, PostgreSQL arrays and composite types)
+ will use the same adaptation rules.
+
+As a consequence it is possible to perform certain choices only once per query
+(e.g. looking up the connection encoding) and then call a fast-path operation
+for each value to convert.
+
+Querying will fail if a Python object for which there isn't a `!Dumper`
+registered (for the right `~psycopg.pq.Format`) is used as query parameter.
+If the query returns a data type whose OID doesn't have a `!Loader`, the
+value will be returned as a string (or bytes string for binary types).
diff --git a/docs/advanced/async.rst b/docs/advanced/async.rst
new file mode 100644
index 0000000..3620ab6
--- /dev/null
+++ b/docs/advanced/async.rst
@@ -0,0 +1,360 @@
+.. currentmodule:: psycopg
+
+.. index:: asyncio
+
+.. _async:
+
+Asynchronous operations
+=======================
+
+Psycopg `~Connection` and `~Cursor` have counterparts `~AsyncConnection` and
+`~AsyncCursor` supporting an `asyncio` interface.
+
+The design of the asynchronous objects is pretty much the same of the sync
+ones: in order to use them you will only have to scatter the `!await` keyword
+here and there.
+
+.. code:: python
+
+ async with await psycopg.AsyncConnection.connect(
+ "dbname=test user=postgres") as aconn:
+ async with aconn.cursor() as acur:
+ await acur.execute(
+ "INSERT INTO test (num, data) VALUES (%s, %s)",
+ (100, "abc'def"))
+ await acur.execute("SELECT * FROM test")
+ await acur.fetchone()
+ # will return (1, 100, "abc'def")
+ async for record in acur:
+ print(record)
+
+.. versionchanged:: 3.1
+
+ `AsyncConnection.connect()` performs DNS name resolution in a non-blocking
+ way.
+
+ .. warning::
+
+ Before version 3.1, `AsyncConnection.connect()` may still block on DNS
+ name resolution. To avoid that you should `set the hostaddr connection
+ parameter`__, or use the `~psycopg._dns.resolve_hostaddr_async()` to
+ do it automatically.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-PARAMKEYWORDS
+
+.. warning::
+
+ On Windows, Psycopg is not compatible with the default
+ `~asyncio.ProactorEventLoop`. Please use a different loop, for instance
+ the `~asyncio.SelectorEventLoop`.
+
+ For instance, you can use, early in your program:
+
+ .. parsed-literal::
+
+ `asyncio.set_event_loop_policy`\ (
+ `asyncio.WindowsSelectorEventLoopPolicy`\ ()
+ )
+
+
+
+.. index:: with
+
+.. _async-with:
+
+`!with` async connections
+-------------------------
+
+As seen in :ref:`the basic usage <usage>`, connections and cursors can act as
+context managers, so you can run:
+
+.. code:: python
+
+ with psycopg.connect("dbname=test user=postgres") as conn:
+ with conn.cursor() as cur:
+ cur.execute(...)
+ # the cursor is closed upon leaving the context
+ # the transaction is committed, the connection closed
+
+For asynchronous connections it's *almost* what you'd expect, but
+not quite. Please note that `~Connection.connect()` and `~Connection.cursor()`
+*don't return a context*: they are both factory methods which return *an
+object which can be used as a context*. That's because there are several use
+cases where it's useful to handle the objects manually and only `!close()` them
+when required.
+
+As a consequence you cannot use `!async with connect()`: you have to do it in
+two steps instead, as in
+
+.. code:: python
+
+ aconn = await psycopg.AsyncConnection.connect()
+ async with aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute(...)
+
+which can be condensed into `!async with await`:
+
+.. code:: python
+
+ async with await psycopg.AsyncConnection.connect() as aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute(...)
+
+...but no less than that: you still need to do the double async thing.
+
+Note that the `AsyncConnection.cursor()` function is not an `!async` function
+(it never performs I/O), so you don't need an `!await` on it; as a consequence
+you can use the normal `async with` context manager.
+
+
+.. index:: Ctrl-C
+
+.. _async-ctrl-c:
+
+Interrupting async operations using Ctrl-C
+------------------------------------------
+
+If a long running operation is interrupted by a Ctrl-C on a normal connection
+running in the main thread, the operation will be cancelled and the connection
+will be put in error state, from which can be recovered with a normal
+`~Connection.rollback()`.
+
+If the query is running in an async connection, a Ctrl-C will be likely
+intercepted by the async loop and interrupt the whole program. In order to
+emulate what normally happens with blocking connections, you can use
+`asyncio's add_signal_handler()`__, to call `Connection.cancel()`:
+
+.. code:: python
+
+ import asyncio
+ import signal
+
+ async with await psycopg.AsyncConnection.connect() as conn:
+ loop.add_signal_handler(signal.SIGINT, conn.cancel)
+ ...
+
+
+.. __: https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.add_signal_handler
+
+
+.. index::
+ pair: Asynchronous; Notifications
+ pair: LISTEN; SQL command
+ pair: NOTIFY; SQL command
+
+.. _async-messages:
+
+Server messages
+---------------
+
+PostgreSQL can send, together with the query results, `informative messages`__
+about the operation just performed, such as warnings or debug information.
+Notices may be raised even if the operations are successful and don't indicate
+an error. You are probably familiar with some of them, because they are
+reported by :program:`psql`::
+
+ $ psql
+ =# ROLLBACK;
+ WARNING: there is no transaction in progress
+ ROLLBACK
+
+.. __: https://www.postgresql.org/docs/current/runtime-config-logging.html
+ #RUNTIME-CONFIG-SEVERITY-LEVELS
+
+Messages can be also sent by the `PL/pgSQL 'RAISE' statement`__ (at a level
+lower than EXCEPTION, otherwise the appropriate `DatabaseError` will be
+raised). The level of the messages received can be controlled using the
+client_min_messages__ setting.
+
+.. __: https://www.postgresql.org/docs/current/plpgsql-errors-and-messages.html
+.. __: https://www.postgresql.org/docs/current/runtime-config-client.html
+ #GUC-CLIENT-MIN-MESSAGES
+
+
+By default, the messages received are ignored. If you want to process them on
+the client you can use the `Connection.add_notice_handler()` function to
+register a function that will be invoked whenever a message is received. The
+message is passed to the callback as a `~errors.Diagnostic` instance,
+containing all the information passed by the server, such as the message text
+and the severity. The object is the same found on the `~psycopg.Error.diag`
+attribute of the errors raised by the server:
+
+.. code:: python
+
+ >>> import psycopg
+
+ >>> def log_notice(diag):
+ ... print(f"The server says: {diag.severity} - {diag.message_primary}")
+
+ >>> conn = psycopg.connect(autocommit=True)
+ >>> conn.add_notice_handler(log_notice)
+
+ >>> cur = conn.execute("ROLLBACK")
+ The server says: WARNING - there is no transaction in progress
+ >>> print(cur.statusmessage)
+ ROLLBACK
+
+.. warning::
+
+ The `!Diagnostic` object received by the callback should not be used after
+ the callback function terminates, because its data is deallocated after
+ the callbacks have been processed. If you need to use the information
+ later please extract the attributes requested and forward them instead of
+ forwarding the whole `!Diagnostic` object.
+
+
+.. index::
+ pair: Asynchronous; Notifications
+ pair: LISTEN; SQL command
+ pair: NOTIFY; SQL command
+
+.. _async-notify:
+
+Asynchronous notifications
+--------------------------
+
+Psycopg allows asynchronous interaction with other database sessions using the
+facilities offered by PostgreSQL commands |LISTEN|_ and |NOTIFY|_. Please
+refer to the PostgreSQL documentation for examples about how to use this form
+of communication.
+
+.. |LISTEN| replace:: :sql:`LISTEN`
+.. _LISTEN: https://www.postgresql.org/docs/current/sql-listen.html
+.. |NOTIFY| replace:: :sql:`NOTIFY`
+.. _NOTIFY: https://www.postgresql.org/docs/current/sql-notify.html
+
+Because of the way sessions interact with notifications (see |NOTIFY|_
+documentation), you should keep the connection in `~Connection.autocommit`
+mode if you wish to receive or send notifications in a timely manner.
+
+Notifications are received as instances of `Notify`. If you are reserving a
+connection only to receive notifications, the simplest way is to consume the
+`Connection.notifies` generator. The generator can be stopped using
+`!close()`.
+
+.. note::
+
+ You don't need an `AsyncConnection` to handle notifications: a normal
+ blocking `Connection` is perfectly valid.
+
+The following example will print notifications and stop when one containing
+the ``"stop"`` message is received.
+
+.. code:: python
+
+ import psycopg
+ conn = psycopg.connect("", autocommit=True)
+ conn.execute("LISTEN mychan")
+ gen = conn.notifies()
+ for notify in gen:
+ print(notify)
+ if notify.payload == "stop":
+ gen.close()
+ print("there, I stopped")
+
+If you run some :sql:`NOTIFY` in a :program:`psql` session:
+
+.. code:: psql
+
+ =# NOTIFY mychan, 'hello';
+ NOTIFY
+ =# NOTIFY mychan, 'hey';
+ NOTIFY
+ =# NOTIFY mychan, 'stop';
+ NOTIFY
+
+You may get output from the Python process such as::
+
+ Notify(channel='mychan', payload='hello', pid=961823)
+ Notify(channel='mychan', payload='hey', pid=961823)
+ Notify(channel='mychan', payload='stop', pid=961823)
+ there, I stopped
+
+Alternatively, you can use `~Connection.add_notify_handler()` to register a
+callback function, which will be invoked whenever a notification is received,
+during the normal query processing; you will be then able to use the
+connection normally. Please note that in this case notifications will not be
+received immediately, but only during a connection operation, such as a query.
+
+.. code:: python
+
+ conn.add_notify_handler(lambda n: print(f"got this: {n}"))
+
+ # meanwhile in psql...
+ # =# NOTIFY mychan, 'hey';
+ # NOTIFY
+
+ print(conn.execute("SELECT 1").fetchone())
+ # got this: Notify(channel='mychan', payload='hey', pid=961823)
+ # (1,)
+
+
+.. index:: disconnections
+
+.. _disconnections:
+
+Detecting disconnections
+------------------------
+
+Sometimes it is useful to detect immediately when the connection with the
+database is lost. One brutal way to do so is to poll a connection in a loop
+running an endless stream of :sql:`SELECT 1`... *Don't* do so: polling is *so*
+out of fashion. Besides, it is inefficient (unless what you really want is a
+client-server generator of ones), it generates useless traffic and will only
+detect a disconnection with an average delay of half the polling time.
+
+A more efficient and timely way to detect a server disconnection is to create
+an additional connection and wait for a notification from the OS that this
+connection has something to say: only then you can run some checks. You
+can dedicate a thread (or an asyncio task) to wait on this connection: such
+thread will perform no activity until awaken by the OS.
+
+In a normal (non asyncio) program you can use the `selectors` module. Because
+the `!Connection` implements a `~Connection.fileno()` method you can just
+register it as a file-like object. You can run such code in a dedicated thread
+(and using a dedicated connection) if the rest of the program happens to have
+something else to do too.
+
+.. code:: python
+
+ import selectors
+
+ sel = selectors.DefaultSelector()
+ sel.register(conn, selectors.EVENT_READ)
+ while True:
+ if not sel.select(timeout=60.0):
+ continue # No FD activity detected in one minute
+
+ # Activity detected. Is the connection still ok?
+ try:
+ conn.execute("SELECT 1")
+ except psycopg.OperationalError:
+ # You were disconnected: do something useful such as panicking
+ logger.error("we lost our database!")
+ sys.exit(1)
+
+In an `asyncio` program you can dedicate a `~asyncio.Task` instead and do
+something similar using `~asyncio.loop.add_reader`:
+
+.. code:: python
+
+ import asyncio
+
+ ev = asyncio.Event()
+ loop = asyncio.get_event_loop()
+ loop.add_reader(conn.fileno(), ev.set)
+
+ while True:
+ try:
+ await asyncio.wait_for(ev.wait(), 60.0)
+ except asyncio.TimeoutError:
+ continue # No FD activity detected in one minute
+
+ # Activity detected. Is the connection still ok?
+ try:
+ await conn.execute("SELECT 1")
+ except psycopg.OperationalError:
+ # Guess what happened
+ ...
diff --git a/docs/advanced/cursors.rst b/docs/advanced/cursors.rst
new file mode 100644
index 0000000..954d665
--- /dev/null
+++ b/docs/advanced/cursors.rst
@@ -0,0 +1,192 @@
+.. currentmodule:: psycopg
+
+.. index::
+ single: Cursor
+
+.. _cursor-types:
+
+Cursor types
+============
+
+Psycopg can manage kinds of "cursors" which differ in where the state of a
+query being processed is stored: :ref:`client-side-cursors` and
+:ref:`server-side-cursors`.
+
+.. index::
+ double: Cursor; Client-side
+
+.. _client-side-cursors:
+
+Client-side cursors
+-------------------
+
+Client-side cursors are what Psycopg uses in its normal querying process.
+They are implemented by the `Cursor` and `AsyncCursor` classes. In such
+querying pattern, after a cursor sends a query to the server (usually calling
+`~Cursor.execute()`), the server replies transferring to the client the whole
+set of results requested, which is stored in the state of the same cursor and
+from where it can be read from Python code (using methods such as
+`~Cursor.fetchone()` and siblings).
+
+This querying process is very scalable because, after a query result has been
+transmitted to the client, the server doesn't keep any state. Because the
+results are already in the client memory, iterating its rows is very quick.
+
+The downside of this querying method is that the entire result has to be
+transmitted completely to the client (with a time proportional to its size)
+and the client needs enough memory to hold it, so it is only suitable for
+reasonably small result sets.
+
+
+.. index::
+ double: Cursor; Client-binding
+
+.. _client-side-binding-cursors:
+
+Client-side-binding cursors
+---------------------------
+
+.. versionadded:: 3.1
+
+The previously described :ref:`client-side cursors <client-side-cursors>` send
+the query and the parameters separately to the server. This is the most
+efficient way to process parametrised queries and allows to build several
+features and optimizations. However, not all types of queries can be bound
+server-side; in particular no Data Definition Language query can. See
+:ref:`server-side-binding` for the description of these problems.
+
+The `ClientCursor` (and its `AsyncClientCursor` async counterpart) merge the
+query on the client and send the query and the parameters merged together to
+the server. This allows to parametrize any type of PostgreSQL statement, not
+only queries (:sql:`SELECT`) and Data Manipulation statements (:sql:`INSERT`,
+:sql:`UPDATE`, :sql:`DELETE`).
+
+Using `!ClientCursor`, Psycopg 3 behaviour will be more similar to `psycopg2`
+(which only implements client-side binding) and could be useful to port
+Psycopg 2 programs more easily to Psycopg 3. The objects in the `sql` module
+allow for greater flexibility (for instance to parametrize a table name too,
+not only values); however, for simple cases, a `!ClientCursor` could be the
+right object.
+
+In order to obtain `!ClientCursor` from a connection, you can set its
+`~Connection.cursor_factory` (at init time or changing its attribute
+afterwards):
+
+.. code:: python
+
+ from psycopg import connect, ClientCursor
+
+ conn = psycopg.connect(DSN, cursor_factory=ClientCursor)
+ cur = conn.cursor()
+ # <psycopg.ClientCursor [no result] [IDLE] (database=piro) at 0x7fd977ae2880>
+
+If you need to create a one-off client-side-binding cursor out of a normal
+connection, you can just use the `~ClientCursor` class passing the connection
+as argument.
+
+.. code:: python
+
+ conn = psycopg.connect(DSN)
+ cur = psycopg.ClientCursor(conn)
+
+.. warning::
+
+ Client-side cursors don't support :ref:`binary parameters and return
+ values <binary-data>` and don't support :ref:`prepared statements
+ <prepared-statements>`.
+
+.. tip::
+
+ The best use for client-side binding cursors is probably to port large
+ Psycopg 2 code to Psycopg 3, especially for programs making wide use of
+ Data Definition Language statements.
+
+ The `psycopg.sql` module allows for more generic client-side query
+ composition, to mix client- and server-side parameters binding, and allows
+ to parametrize tables and fields names too, or entirely generic SQL
+ snippets.
+
+.. index::
+ double: Cursor; Server-side
+ single: Portal
+ double: Cursor; Named
+
+.. _server-side-cursors:
+
+Server-side cursors
+-------------------
+
+PostgreSQL has its own concept of *cursor* too (sometimes also called
+*portal*). When a database cursor is created, the query is not necessarily
+completely processed: the server might be able to produce results only as they
+are needed. Only the results requested are transmitted to the client: if the
+query result is very large but the client only needs the first few records it
+is possible to transmit only them.
+
+The downside is that the server needs to keep track of the partially
+processed results, so it uses more memory and resources on the server.
+
+Psycopg allows the use of server-side cursors using the classes `ServerCursor`
+and `AsyncServerCursor`. They are usually created by passing the `!name`
+parameter to the `~Connection.cursor()` method (reason for which, in
+`!psycopg2`, they are usually called *named cursors*). The use of these classes
+is similar to their client-side counterparts: their interface is the same, but
+behind the scene they send commands to control the state of the cursor on the
+server (for instance when fetching new records or when moving using
+`~Cursor.scroll()`).
+
+Using a server-side cursor it is possible to process datasets larger than what
+would fit in the client's memory. However for small queries they are less
+efficient because it takes more commands to receive their result, so you
+should use them only if you need to process huge results or if only a partial
+result is needed.
+
+.. seealso::
+
+ Server-side cursors are created and managed by `ServerCursor` using SQL
+ commands such as DECLARE_, FETCH_, MOVE_. The PostgreSQL documentation
+ gives a good idea of what is possible to do with them.
+
+ .. _DECLARE: https://www.postgresql.org/docs/current/sql-declare.html
+ .. _FETCH: https://www.postgresql.org/docs/current/sql-fetch.html
+ .. _MOVE: https://www.postgresql.org/docs/current/sql-move.html
+
+
+.. _cursor-steal:
+
+"Stealing" an existing cursor
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+A Psycopg `ServerCursor` can be also used to consume a cursor which was
+created in other ways than the :sql:`DECLARE` that `ServerCursor.execute()`
+runs behind the scene.
+
+For instance if you have a `PL/pgSQL function returning a cursor`__:
+
+.. __: https://www.postgresql.org/docs/current/plpgsql-cursors.html
+
+.. code:: postgres
+
+ CREATE FUNCTION reffunc(refcursor) RETURNS refcursor AS $$
+ BEGIN
+ OPEN $1 FOR SELECT col FROM test;
+ RETURN $1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+you can run a one-off command in the same connection to call it (e.g. using
+`Connection.execute()`) in order to create the cursor on the server:
+
+.. code:: python
+
+ conn.execute("SELECT reffunc('curname')")
+
+after which you can create a server-side cursor declared by the same name, and
+directly call the fetch methods, skipping the `~ServerCursor.execute()` call:
+
+.. code:: python
+
+ cur = conn.cursor('curname')
+ # no cur.execute()
+ for record in cur: # or cur.fetchone(), cur.fetchmany()...
+ # do something with record
diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst
new file mode 100644
index 0000000..6920bd7
--- /dev/null
+++ b/docs/advanced/index.rst
@@ -0,0 +1,21 @@
+.. _advanced:
+
+More advanced topics
+====================
+
+Once you have familiarised yourself with the :ref:`Psycopg basic operations
+<basic>`, you can take a look at the chapter of this section for more advanced
+usages.
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+ async
+ typing
+ rows
+ pool
+ cursors
+ adapt
+ prepare
+ pipeline
diff --git a/docs/advanced/pipeline.rst b/docs/advanced/pipeline.rst
new file mode 100644
index 0000000..980fea7
--- /dev/null
+++ b/docs/advanced/pipeline.rst
@@ -0,0 +1,324 @@
+.. currentmodule:: psycopg
+
+.. _pipeline-mode:
+
+Pipeline mode support
+=====================
+
+.. versionadded:: 3.1
+
+The *pipeline mode* allows PostgreSQL client applications to send a query
+without having to read the result of the previously sent query. Taking
+advantage of the pipeline mode, a client will wait less for the server, since
+multiple queries/results can be sent/received in a single network roundtrip.
+Pipeline mode can provide a significant performance boost to the application.
+
+Pipeline mode is most useful when the server is distant, i.e., network latency
+(“ping time”) is high, and also when many small operations are being performed
+in rapid succession. There is usually less benefit in using pipelined commands
+when each query takes many multiples of the client/server round-trip time to
+execute. A 100-statement operation run on a server 300 ms round-trip-time away
+would take 30 seconds in network latency alone without pipelining; with
+pipelining it may spend as little as 0.3 s waiting for results from the
+server.
+
+The server executes statements, and returns results, in the order the client
+sends them. The server will begin executing the commands in the pipeline
+immediately, not waiting for the end of the pipeline. Note that results are
+buffered on the server side; the server flushes that buffer when a
+:ref:`synchronization point <pipeline-sync>` is established.
+
+.. seealso::
+
+ The PostgreSQL documentation about:
+
+ - `pipeline mode`__
+ - `extended query message flow`__
+
+ contains many details around when it is most useful to use the pipeline
+ mode and about errors management and interaction with transactions.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-pipeline-mode.html
+ .. __: https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
+
+
+Client-server messages flow
+---------------------------
+
+In order to understand better how the pipeline mode works, we should take a
+closer look at the `PostgreSQL client-server message flow`__.
+
+During normal querying, each statement is transmitted by the client to the
+server as a stream of request messages, terminating with a **Sync** message to
+tell it that it should process the messages sent so far. The server will
+execute the statement and describe the results back as a stream of messages,
+terminating with a **ReadyForQuery**, telling the client that it may now send a
+new query.
+
+For example, the statement (returning no result):
+
+.. code:: python
+
+ conn.execute("INSERT INTO mytable (data) VALUES (%s)", ["hello"])
+
+results in the following two groups of messages:
+
+.. table::
+ :align: left
+
+ +---------------+-----------------------------------------------------------+
+ | Direction | Message |
+ +===============+===========================================================+
+ | Python | - Parse ``INSERT INTO ... (VALUE $1)`` (skipped if |
+ | | :ref:`the statement is prepared <prepared-statements>`) |
+ | |>| | - Bind ``'hello'`` |
+ | | - Describe |
+ | PostgreSQL | - Execute |
+ | | - Sync |
+ +---------------+-----------------------------------------------------------+
+ | PostgreSQL | - ParseComplete |
+ | | - BindComplete |
+ | |<| | - NoData |
+ | | - CommandComplete ``INSERT 0 1`` |
+ | Python | - ReadyForQuery |
+ +---------------+-----------------------------------------------------------+
+
+and the query:
+
+.. code:: python
+
+ conn.execute("SELECT data FROM mytable WHERE id = %s", [1])
+
+results in the two groups of messages:
+
+.. table::
+ :align: left
+
+ +---------------+-----------------------------------------------------------+
+ | Direction | Message |
+ +===============+===========================================================+
+ | Python | - Parse ``SELECT data FROM mytable WHERE id = $1`` |
+ | | - Bind ``1`` |
+ | |>| | - Describe |
+ | | - Execute |
+ | PostgreSQL | - Sync |
+ +---------------+-----------------------------------------------------------+
+ | PostgreSQL | - ParseComplete |
+ | | - BindComplete |
+ | |<| | - RowDescription ``data`` |
+ | | - DataRow ``hello`` |
+ | Python | - CommandComplete ``SELECT 1`` |
+ | | - ReadyForQuery |
+ +---------------+-----------------------------------------------------------+
+
+The two statements, sent consecutively, pay the communication overhead four
+times, once per leg.
+
+The pipeline mode allows the client to combine several operations in longer
+streams of messages to the server, then to receive more than one response in a
+single batch. If we execute the two operations above in a pipeline:
+
+.. code:: python
+
+ with conn.pipeline():
+ conn.execute("INSERT INTO mytable (data) VALUES (%s)", ["hello"])
+ conn.execute("SELECT data FROM mytable WHERE id = %s", [1])
+
+they will result in a single roundtrip between the client and the server:
+
+.. table::
+ :align: left
+
+ +---------------+-----------------------------------------------------------+
+ | Direction | Message |
+ +===============+===========================================================+
+ | Python | - Parse ``INSERT INTO ... (VALUE $1)`` |
+ | | - Bind ``'hello'`` |
+ | |>| | - Describe |
+ | | - Execute |
+ | PostgreSQL | - Parse ``SELECT data FROM mytable WHERE id = $1`` |
+ | | - Bind ``1`` |
+ | | - Describe |
+ | | - Execute |
+ | | - Sync (sent only once) |
+ +---------------+-----------------------------------------------------------+
+ | PostgreSQL | - ParseComplete |
+ | | - BindComplete |
+ | |<| | - NoData |
+ | | - CommandComplete ``INSERT 0 1`` |
+ | Python | - ParseComplete |
+ | | - BindComplete |
+ | | - RowDescription ``data`` |
+ | | - DataRow ``hello`` |
+ | | - CommandComplete ``SELECT 1`` |
+ | | - ReadyForQuery (sent only once) |
+ +---------------+-----------------------------------------------------------+
+
+.. |<| unicode:: U+25C0
+.. |>| unicode:: U+25B6
+
+.. __: https://www.postgresql.org/docs/current/protocol-flow.html
+
+
+.. _pipeline-usage:
+
+Pipeline mode usage
+-------------------
+
+Psycopg supports the pipeline mode via the `Connection.pipeline()` method. The
+method is a context manager: entering the ``with`` block yields a `Pipeline`
+object. At the end of block, the connection resumes the normal operation mode.
+
+Within the pipeline block, you can use normally one or more cursors to execute
+several operations, using `Connection.execute()`, `Cursor.execute()` and
+`~Cursor.executemany()`.
+
+.. code:: python
+
+ >>> with conn.pipeline():
+ ... conn.execute("INSERT INTO mytable VALUES (%s)", ["hello"])
+ ... with conn.cursor() as cur:
+ ... cur.execute("INSERT INTO othertable VALUES (%s)", ["world"])
+ ... cur.executemany(
+ ... "INSERT INTO elsewhere VALUES (%s)",
+ ... [("one",), ("two",), ("four",)])
+
+Unlike in normal mode, Psycopg will not wait for the server to receive the
+result of each query; the client will receive results in batches when the
+server flushes it output buffer.
+
+When a flush (or a sync) is performed, all pending results are sent back to
+the cursors which executed them. If a cursor had run more than one query, it
+will receive more than one result; results after the first will be available,
+in their execution order, using `~Cursor.nextset()`:
+
+.. code:: python
+
+ >>> with conn.pipeline():
+ ... with conn.cursor() as cur:
+ ... cur.execute("INSERT INTO mytable (data) VALUES (%s) RETURNING *", ["hello"])
+ ... cur.execute("INSERT INTO mytable (data) VALUES (%s) RETURNING *", ["world"])
+ ... while True:
+ ... print(cur.fetchall())
+ ... if not cur.nextset():
+ ... break
+
+ [(1, 'hello')]
+ [(2, 'world')]
+
+If any statement encounters an error, the server aborts the current
+transaction and will not execute any subsequent command in the queue until the
+next :ref:`synchronization point <pipeline-sync>`; a `~errors.PipelineAborted`
+exception is raised for each such command. Query processing resumes after the
+synchronization point.
+
+.. warning::
+
+ Certain features are not available in pipeline mode, including:
+
+ - COPY is not supported in pipeline mode by PostgreSQL.
+ - `Cursor.stream()` doesn't make sense in pipeline mode (its job is the
+ opposite of batching!)
+ - `ServerCursor` are currently not implemented in pipeline mode.
+
+.. note::
+
+ Starting from Psycopg 3.1, `~Cursor.executemany()` makes use internally of
+ the pipeline mode; as a consequence there is no need to handle a pipeline
+ block just to call `!executemany()` once.
+
+
+.. _pipeline-sync:
+
+Synchronization points
+----------------------
+
+Flushing query results to the client can happen either when a synchronization
+point is established by Psycopg:
+
+- using the `Pipeline.sync()` method;
+- on `Connection.commit()` or `~Connection.rollback()`;
+- at the end of a `!Pipeline` block;
+- possibly when opening a nested `!Pipeline` block;
+- using a fetch method such as `Cursor.fetchone()` (which only flushes the
+ query but doesn't issue a Sync and doesn't reset a pipeline state error).
+
+The server might perform a flush on its own initiative, for instance when the
+output buffer is full.
+
+Note that, even in :ref:`autocommit <autocommit>`, the server wraps the
+statements sent in pipeline mode in an implicit transaction, which will be
+only committed when the Sync is received. As such, a failure in a group of
+statements will probably invalidate the effect of statements executed after
+the previous Sync, and will propagate to the following Sync.
+
+For example, in the following block:
+
+.. code:: python
+
+ >>> with psycopg.connect(autocommit=True) as conn:
+ ... with conn.pipeline() as p, conn.cursor() as cur:
+ ... try:
+ ... cur.execute("INSERT INTO mytable (data) VALUES (%s)", ["one"])
+ ... cur.execute("INSERT INTO no_such_table (data) VALUES (%s)", ["two"])
+ ... conn.execute("INSERT INTO mytable (data) VALUES (%s)", ["three"])
+ ... p.sync()
+ ... except psycopg.errors.UndefinedTable:
+ ... pass
+ ... cur.execute("INSERT INTO mytable (data) VALUES (%s)", ["four"])
+
+there will be an error in the block, ``relation "no_such_table" does not
+exist`` caused by the insert ``two``, but probably raised by the `!sync()`
+call. At at the end of the block, the table will contain:
+
+.. code:: text
+
+ =# SELECT * FROM mytable;
+ +----+------+
+ | id | data |
+ +----+------+
+ | 2 | four |
+ +----+------+
+ (1 row)
+
+because:
+
+- the value 1 of the sequence is consumed by the statement ``one``, but
+ the record discarded because of the error in the same implicit transaction;
+- the statement ``three`` is not executed because the pipeline is aborted (so
+ it doesn't consume a sequence item);
+- the statement ``four`` is executed with
+ success after the Sync has terminated the failed transaction.
+
+.. warning::
+
+ The exact Python statement where an exception caused by a server error is
+ raised is somewhat arbitrary: it depends on when the server flushes its
+ buffered result.
+
+ If you want to make sure that a group of statements is applied atomically
+ by the server, do make use of transaction methods such as
+ `~Connection.commit()` or `~Connection.transaction()`: these methods will
+ also sync the pipeline and raise an exception if there was any error in
+ the commands executed so far.
+
+
+The fine prints
+---------------
+
+.. warning::
+
+ The Pipeline mode is an experimental feature.
+
+ Its behaviour, especially around error conditions and concurrency, hasn't
+ been explored as much as the normal request-response messages pattern, and
+ its async nature makes it inherently more complex.
+
+ As we gain more experience and feedback (which is welcome), we might find
+ bugs and shortcomings forcing us to change the current interface or
+ behaviour.
+
+The pipeline mode is available on any currently supported PostgreSQL version,
+but, in order to make use of it, the client must use a libpq from PostgreSQL
+14 or higher. You can use `Pipeline.is_supported()` to make sure your client
+has the right library.
diff --git a/docs/advanced/pool.rst b/docs/advanced/pool.rst
new file mode 100644
index 0000000..adea0a7
--- /dev/null
+++ b/docs/advanced/pool.rst
@@ -0,0 +1,332 @@
+.. currentmodule:: psycopg_pool
+
+.. _connection-pools:
+
+Connection pools
+================
+
+A `connection pool`__ is an object managing a set of connections and allowing
+their use in functions needing one. Because the time to establish a new
+connection can be relatively long, keeping connections open can reduce latency.
+
+.. __: https://en.wikipedia.org/wiki/Connection_pool
+
+This page explains a few basic concepts of Psycopg connection pool's
+behaviour. Please refer to the `ConnectionPool` object API for details about
+the pool operations.
+
+.. note:: The connection pool objects are distributed in a package separate
+ from the main `psycopg` package: use ``pip install "psycopg[pool]"`` or ``pip
+ install psycopg_pool`` to make the `psycopg_pool` package available. See
+ :ref:`pool-installation`.
+
+
+Pool life cycle
+---------------
+
+A simple way to use the pool is to create a single instance of it, as a
+global object, and to use this object in the rest of the program, allowing
+other functions, modules, threads to use it::
+
+ # module db.py in your program
+ from psycopg_pool import ConnectionPool
+
+ pool = ConnectionPool(conninfo, **kwargs)
+ # the pool starts connecting immediately.
+
+ # in another module
+ from .db import pool
+
+ def my_function():
+ with pool.connection() as conn:
+ conn.execute(...)
+
+Ideally you may want to call `~ConnectionPool.close()` when the use of the
+pool is finished. Failing to call `!close()` at the end of the program is not
+terribly bad: probably it will just result in some warnings printed on stderr.
+However, if you think that it's sloppy, you could use the `atexit` module to
+have `!close()` called at the end of the program.
+
+If you want to avoid starting to connect to the database at import time, and
+want to wait for the application to be ready, you can create the pool using
+`!open=False`, and call the `~ConnectionPool.open()` and
+`~ConnectionPool.close()` methods when the conditions are right. Certain
+frameworks provide callbacks triggered when the program is started and stopped
+(for instance `FastAPI startup/shutdown events`__): they are perfect to
+initiate and terminate the pool operations::
+
+ pool = ConnectionPool(conninfo, open=False, **kwargs)
+
+ @app.on_event("startup")
+ def open_pool():
+ pool.open()
+
+ @app.on_event("shutdown")
+ def close_pool():
+ pool.close()
+
+.. __: https://fastapi.tiangolo.com/advanced/events/#events-startup-shutdown
+
+Creating a single pool as a global variable is not the mandatory use: your
+program can create more than one pool, which might be useful to connect to
+more than one database, or to provide different types of connections, for
+instance to provide separate read/write and read-only connections. The pool
+also acts as a context manager and is open and closed, if necessary, on
+entering and exiting the context block::
+
+ from psycopg_pool import ConnectionPool
+
+ with ConnectionPool(conninfo, **kwargs) as pool:
+ run_app(pool)
+
+ # the pool is now closed
+
+When the pool is open, the pool's background workers start creating the
+requested `!min_size` connections, while the constructor (or the `!open()`
+method) returns immediately. This allows the program some leeway to start
+before the target database is up and running. However, if your application is
+misconfigured, or the network is down, it means that the program will be able
+to start, but the threads requesting a connection will fail with a
+`PoolTimeout` only after the timeout on `~ConnectionPool.connection()` is
+expired. If this behaviour is not desirable (and you prefer your program to
+crash hard and fast, if the surrounding conditions are not right, because
+something else will respawn it) you should call the `~ConnectionPool.wait()`
+method after creating the pool, or call `!open(wait=True)`: these methods will
+block until the pool is full, or will raise a `PoolTimeout` exception if the
+pool isn't ready within the allocated time.
+
+
+Connections life cycle
+----------------------
+
+The pool background workers create connections according to the parameters
+`!conninfo`, `!kwargs`, and `!connection_class` passed to `ConnectionPool`
+constructor, invoking something like :samp:`{connection_class}({conninfo},
+**{kwargs})`. Once a connection is created it is also passed to the
+`!configure()` callback, if provided, after which it is put in the pool (or
+passed to a client requesting it, if someone is already knocking at the door).
+
+If a connection expires (it passes `!max_lifetime`), or is returned to the pool
+in broken state, or is found closed by `~ConnectionPool.check()`), then the
+pool will dispose of it and will start a new connection attempt in the
+background.
+
+
+Using connections from the pool
+-------------------------------
+
+The pool can be used to request connections from multiple threads or
+concurrent tasks - it is hardly useful otherwise! If more connections than the
+ones available in the pool are requested, the requesting threads are queued
+and are served a connection as soon as one is available, either because
+another client has finished using it or because the pool is allowed to grow
+(when `!max_size` > `!min_size`) and a new connection is ready.
+
+The main way to use the pool is to obtain a connection using the
+`~ConnectionPool.connection()` context, which returns a `~psycopg.Connection`
+or subclass::
+
+ with my_pool.connection() as conn:
+ conn.execute("what you want")
+
+The `!connection()` context behaves like the `~psycopg.Connection` object
+context: at the end of the block, if there is a transaction open, it will be
+committed, or rolled back if the context is exited with as exception.
+
+At the end of the block the connection is returned to the pool and shouldn't
+be used anymore by the code which obtained it. If a `!reset()` function is
+specified in the pool constructor, it is called on the connection before
+returning it to the pool. Note that the `!reset()` function is called in a
+worker thread, so that the thread which used the connection can keep its
+execution without being slowed down by it.
+
+
+Pool connection and sizing
+--------------------------
+
+A pool can have a fixed size (specifying no `!max_size` or `!max_size` =
+`!min_size`) or a dynamic size (when `!max_size` > `!min_size`). In both
+cases, as soon as the pool is created, it will try to acquire `!min_size`
+connections in the background.
+
+If an attempt to create a connection fails, a new attempt will be made soon
+after, using an exponential backoff to increase the time between attempts,
+until a maximum of `!reconnect_timeout` is reached. When that happens, the pool
+will call the `!reconnect_failed()` function, if provided to the pool, and just
+start a new connection attempt. You can use this function either to send
+alerts or to interrupt the program and allow the rest of your infrastructure
+to restart it.
+
+If more than `!min_size` connections are requested concurrently, new ones are
+created, up to `!max_size`. Note that the connections are always created by the
+background workers, not by the thread asking for the connection: if a client
+requests a new connection, and a previous client terminates its job before the
+new connection is ready, the waiting client will be served the existing
+connection. This is especially useful in scenarios where the time to establish
+a connection dominates the time for which the connection is used (see `this
+analysis`__, for instance).
+
+.. __: https://github.com/brettwooldridge/HikariCP/blob/dev/documents/
+ Welcome-To-The-Jungle.md
+
+If a pool grows above `!min_size`, but its usage decreases afterwards, a number
+of connections are eventually closed: one every time a connection is unused
+after the `!max_idle` time specified in the pool constructor.
+
+
+What's the right size for the pool?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Big question. Who knows. However, probably not as large as you imagine. Please
+take a look at `this analysis`__ for some ideas.
+
+.. __: https://github.com/brettwooldridge/HikariCP/wiki/About-Pool-Sizing
+
+Something useful you can do is probably to use the
+`~ConnectionPool.get_stats()` method and monitor the behaviour of your program
+to tune the configuration parameters. The size of the pool can also be changed
+at runtime using the `~ConnectionPool.resize()` method.
+
+
+.. _null-pool:
+
+Null connection pools
+---------------------
+
+.. versionadded:: 3.1
+
+Sometimes you may want leave the choice of using or not using a connection
+pool as a configuration parameter of your application. For instance, you might
+want to use a pool if you are deploying a "large instance" of your application
+and can dedicate it a handful of connections; conversely you might not want to
+use it if you deploy the application in several instances, behind a load
+balancer, and/or using an external connection pool process such as PgBouncer.
+
+Switching between using or not using a pool requires some code change, because
+the `ConnectionPool` API is different from the normal `~psycopg.connect()`
+function and because the pool can perform additional connection configuration
+(in the `!configure` parameter) that, if the pool is removed, should be
+performed in some different code path of your application.
+
+The `!psycopg_pool` 3.1 package introduces the `NullConnectionPool` class.
+This class has the same interface, and largely the same behaviour, of the
+`!ConnectionPool`, but doesn't create any connection beforehand. When a
+connection is returned, unless there are other clients already waiting, it
+is closed immediately and not kept in the pool state.
+
+A null pool is not only a configuration convenience, but can also be used to
+regulate the access to the server by a client program. If `!max_size` is set to
+a value greater than 0, the pool will make sure that no more than `!max_size`
+connections are created at any given time. If more clients ask for further
+connections, they will be queued and served a connection as soon as a previous
+client has finished using it, like for the basic pool. Other mechanisms to
+throttle client requests (such as `!timeout` or `!max_waiting`) are respected
+too.
+
+.. note::
+
+ Queued clients will be handed an already established connection, as soon
+ as a previous client has finished using it (and after the pool has
+ returned it to idle state and called `!reset()` on it, if necessary).
+
+Because normally (i.e. unless queued) every client will be served a new
+connection, the time to obtain the connection is paid by the waiting client;
+background workers are not normally involved in obtaining new connections.
+
+
+Connection quality
+------------------
+
+The state of the connection is verified when a connection is returned to the
+pool: if a connection is broken during its usage it will be discarded on
+return and a new connection will be created.
+
+.. warning::
+
+ The health of the connection is not checked when the pool gives it to a
+ client.
+
+Why not? Because doing so would require an extra network roundtrip: we want to
+save you from its latency. Before getting too angry about it, just think that
+the connection can be lost any moment while your program is using it. As your
+program should already be able to cope with a loss of a connection during its
+process, it should be able to tolerate to be served a broken connection:
+unpleasant but not the end of the world.
+
+.. warning::
+
+ The health of the connection is not checked when the connection is in the
+ pool.
+
+Does the pool keep a watchful eye on the quality of the connections inside it?
+No, it doesn't. Why not? Because you will do it for us! Your program is only
+a big ruse to make sure the connections are still alive...
+
+Not (entirely) trolling: if you are using a connection pool, we assume that
+you are using and returning connections at a good pace. If the pool had to
+check for the quality of a broken connection before your program notices it,
+it should be polling each connection even faster than your program uses them.
+Your database server wouldn't be amused...
+
+Can you do something better than that? Of course you can, there is always a
+better way than polling. You can use the same recipe of :ref:`disconnections`,
+reserving a connection and using a thread to monitor for any activity
+happening on it. If any activity is detected, you can call the pool
+`~ConnectionPool.check()` method, which will run a quick check on each
+connection in the pool, removing the ones found in broken state, and using the
+background workers to replace them with fresh ones.
+
+If you set up a similar check in your program, in case the database connection
+is temporarily lost, we cannot do anything for the threads which had taken
+already a connection from the pool, but no other thread should be served a
+broken connection, because `!check()` would empty the pool and refill it with
+working connections, as soon as they are available.
+
+Faster than you can say poll. Or pool.
+
+
+.. _pool-stats:
+
+Pool stats
+----------
+
+The pool can return information about its usage using the methods
+`~ConnectionPool.get_stats()` or `~ConnectionPool.pop_stats()`. Both methods
+return the same values, but the latter reset the counters after its use. The
+values can be sent to a monitoring system such as Graphite_ or Prometheus_.
+
+.. _Graphite: https://graphiteapp.org/
+.. _Prometheus: https://prometheus.io/
+
+The following values should be provided, but please don't consider them as a
+rigid interface: it is possible that they might change in the future. Keys
+whose value is 0 may not be returned.
+
+
+======================= =====================================================
+Metric Meaning
+======================= =====================================================
+ ``pool_min`` Current value for `~ConnectionPool.min_size`
+ ``pool_max`` Current value for `~ConnectionPool.max_size`
+ ``pool_size`` Number of connections currently managed by the pool
+ (in the pool, given to clients, being prepared)
+ ``pool_available`` Number of connections currently idle in the pool
+ ``requests_waiting`` Number of requests currently waiting in a queue to
+ receive a connection
+ ``usage_ms`` Total usage time of the connections outside the pool
+ ``requests_num`` Number of connections requested to the pool
+ ``requests_queued`` Number of requests queued because a connection wasn't
+ immediately available in the pool
+ ``requests_wait_ms`` Total time in the queue for the clients waiting
+ ``requests_errors`` Number of connection requests resulting in an error
+ (timeouts, queue full...)
+ ``returns_bad`` Number of connections returned to the pool in a bad
+ state
+ ``connections_num`` Number of connection attempts made by the pool to the
+ server
+ ``connections_ms`` Total time spent to establish connections with the
+ server
+ ``connections_errors`` Number of failed connection attempts
+ ``connections_lost`` Number of connections lost identified by
+ `~ConnectionPool.check()`
+======================= =====================================================
diff --git a/docs/advanced/prepare.rst b/docs/advanced/prepare.rst
new file mode 100644
index 0000000..e41bcae
--- /dev/null
+++ b/docs/advanced/prepare.rst
@@ -0,0 +1,57 @@
+.. currentmodule:: psycopg
+
+.. index::
+ single: Prepared statements
+
+.. _prepared-statements:
+
+Prepared statements
+===================
+
+Psycopg uses an automatic system to manage *prepared statements*. When a
+query is prepared, its parsing and planning is stored in the server session,
+so that further executions of the same query on the same connection (even with
+different parameters) are optimised.
+
+A query is prepared automatically after it is executed more than
+`~Connection.prepare_threshold` times on a connection. `!psycopg` will make
+sure that no more than `~Connection.prepared_max` statements are planned: if
+further queries are executed, the least recently used ones are deallocated and
+the associated resources freed.
+
+Statement preparation can be controlled in several ways:
+
+- You can decide to prepare a query immediately by passing `!prepare=True` to
+ `Connection.execute()` or `Cursor.execute()`. The query is prepared, if it
+ wasn't already, and executed as prepared from its first use.
+
+- Conversely, passing `!prepare=False` to `!execute()` will avoid to prepare
+ the query, regardless of the number of times it is executed. The default for
+ the parameter is `!None`, meaning that the query is prepared if the
+ conditions described above are met.
+
+- You can disable the use of prepared statements on a connection by setting
+ its `~Connection.prepare_threshold` attribute to `!None`.
+
+.. versionchanged:: 3.1
+ You can set `!prepare_threshold` as a `~Connection.connect()` keyword
+ parameter too.
+
+.. seealso::
+
+ The `PREPARE`__ PostgreSQL documentation contains plenty of details about
+ prepared statements in PostgreSQL.
+
+ Note however that Psycopg doesn't use SQL statements such as
+ :sql:`PREPARE` and :sql:`EXECUTE`, but protocol level commands such as the
+ ones exposed by :pq:`PQsendPrepare`, :pq:`PQsendQueryPrepared`.
+
+ .. __: https://www.postgresql.org/docs/current/sql-prepare.html
+
+.. warning::
+
+ Using external connection poolers, such as PgBouncer, is not compatible
+ with prepared statements, because the same client connection may change
+ the server session it refers to. If such middleware is used you should
+ disable prepared statements, by setting the `Connection.prepare_threshold`
+ attribute to `!None`.
diff --git a/docs/advanced/rows.rst b/docs/advanced/rows.rst
new file mode 100644
index 0000000..c23efe5
--- /dev/null
+++ b/docs/advanced/rows.rst
@@ -0,0 +1,116 @@
+.. currentmodule:: psycopg
+
+.. index:: row factories
+
+.. _row-factories:
+
+Row factories
+=============
+
+Cursor's `fetch*` methods, by default, return the records received from the
+database as tuples. This can be changed to better suit the needs of the
+programmer by using custom *row factories*.
+
+The module `psycopg.rows` exposes several row factories ready to be used. For
+instance, if you want to return your records as dictionaries, you can use
+`~psycopg.rows.dict_row`::
+
+ >>> from psycopg.rows import dict_row
+
+ >>> conn = psycopg.connect(DSN, row_factory=dict_row)
+
+ >>> conn.execute("select 'John Doe' as name, 33 as age").fetchone()
+ {'name': 'John Doe', 'age': 33}
+
+The `!row_factory` parameter is supported by the `~Connection.connect()`
+method and the `~Connection.cursor()` method. Later usage of `!row_factory`
+overrides a previous one. It is also possible to change the
+`Connection.row_factory` or `Cursor.row_factory` attributes to change what
+they return::
+
+ >>> cur = conn.cursor(row_factory=dict_row)
+ >>> cur.execute("select 'John Doe' as name, 33 as age").fetchone()
+ {'name': 'John Doe', 'age': 33}
+
+ >>> from psycopg.rows import namedtuple_row
+ >>> cur.row_factory = namedtuple_row
+ >>> cur.execute("select 'John Doe' as name, 33 as age").fetchone()
+ Row(name='John Doe', age=33)
+
+If you want to return objects of your choice you can use a row factory
+*generator*, for instance `~psycopg.rows.class_row` or
+`~psycopg.rows.args_row`, or you can :ref:`write your own row factory
+<row-factory-create>`::
+
+ >>> from dataclasses import dataclass
+
+ >>> @dataclass
+ ... class Person:
+ ... name: str
+ ... age: int
+ ... weight: Optional[int] = None
+
+ >>> from psycopg.rows import class_row
+ >>> cur = conn.cursor(row_factory=class_row(Person))
+ >>> cur.execute("select 'John Doe' as name, 33 as age").fetchone()
+ Person(name='John Doe', age=33, weight=None)
+
+
+.. index::
+ single: Row Maker
+ single: Row Factory
+
+.. _row-factory-create:
+
+Creating new row factories
+--------------------------
+
+A *row factory* is a callable that accepts a `Cursor` object and returns
+another callable, a *row maker*, which takes raw data (as a sequence of
+values) and returns the desired object.
+
+The role of the row factory is to inspect a query result (it is called after a
+query is executed and properties such as `~Cursor.description` and
+`~Cursor.pgresult` are available on the cursor) and to prepare a callable
+which is efficient to call repeatedly (because, for instance, the names of the
+columns are extracted, sanitised, and stored in local variables).
+
+Formally, these objects are represented by the `~psycopg.rows.RowFactory` and
+`~psycopg.rows.RowMaker` protocols.
+
+`~RowFactory` objects can be implemented as a class, for instance:
+
+.. code:: python
+
+ from typing import Any, Sequence
+ from psycopg import Cursor
+
+ class DictRowFactory:
+ def __init__(self, cursor: Cursor[Any]):
+ self.fields = [c.name for c in cursor.description]
+
+ def __call__(self, values: Sequence[Any]) -> dict[str, Any]:
+ return dict(zip(self.fields, values))
+
+or as a plain function:
+
+.. code:: python
+
+ def dict_row_factory(cursor: Cursor[Any]) -> RowMaker[dict[str, Any]]:
+ fields = [c.name for c in cursor.description]
+
+ def make_row(values: Sequence[Any]) -> dict[str, Any]:
+ return dict(zip(fields, values))
+
+ return make_row
+
+These can then be used by specifying a `row_factory` argument in
+`Connection.connect()`, `Connection.cursor()`, or by setting the
+`Connection.row_factory` attribute.
+
+.. code:: python
+
+ conn = psycopg.connect(row_factory=DictRowFactory)
+ cur = conn.execute("SELECT first_name, last_name, age FROM persons")
+ person = cur.fetchone()
+ print(f"{person['first_name']} {person['last_name']}")
diff --git a/docs/advanced/typing.rst b/docs/advanced/typing.rst
new file mode 100644
index 0000000..71b4e41
--- /dev/null
+++ b/docs/advanced/typing.rst
@@ -0,0 +1,180 @@
+.. currentmodule:: psycopg
+
+.. _static-typing:
+
+Static Typing
+=============
+
+Psycopg source code is annotated according to :pep:`0484` type hints and is
+checked using the current version of Mypy_ in ``--strict`` mode.
+
+If your application is checked using Mypy too you can make use of Psycopg
+types to validate the correct use of Psycopg objects and of the data returned
+by the database.
+
+.. _Mypy: http://mypy-lang.org/
+
+
+Generic types
+-------------
+
+Psycopg `Connection` and `Cursor` objects are `~typing.Generic` objects and
+support a `!Row` parameter which is the type of the records returned.
+
+By default methods such as `Cursor.fetchall()` return normal tuples of unknown
+size and content. As such, the `connect()` function returns an object of type
+`!psycopg.Connection[Tuple[Any, ...]]` and `Connection.cursor()` returns an
+object of type `!psycopg.Cursor[Tuple[Any, ...]]`. If you are writing generic
+plumbing code it might be practical to use annotations such as
+`!Connection[Any]` and `!Cursor[Any]`.
+
+.. code:: python
+
+ conn = psycopg.connect() # type is psycopg.Connection[Tuple[Any, ...]]
+
+ cur = conn.cursor() # type is psycopg.Cursor[Tuple[Any, ...]]
+
+ rec = cur.fetchone() # type is Optional[Tuple[Any, ...]]
+
+ recs = cur.fetchall() # type is List[Tuple[Any, ...]]
+
+
+.. _row-factory-static:
+
+Type of rows returned
+---------------------
+
+If you want to use connections and cursors returning your data as different
+types, for instance as dictionaries, you can use the `!row_factory` argument
+of the `~Connection.connect()` and the `~Connection.cursor()` method, which
+will control what type of record is returned by the fetch methods of the
+cursors and annotate the returned objects accordingly. See
+:ref:`row-factories` for more details.
+
+.. code:: python
+
+ dconn = psycopg.connect(row_factory=dict_row)
+ # dconn type is psycopg.Connection[Dict[str, Any]]
+
+ dcur = conn.cursor(row_factory=dict_row)
+ dcur = dconn.cursor()
+ # dcur type is psycopg.Cursor[Dict[str, Any]] in both cases
+
+ drec = dcur.fetchone()
+ # drec type is Optional[Dict[str, Any]]
+
+
+.. _example-pydantic:
+
+Example: returning records as Pydantic models
+---------------------------------------------
+
+Using Pydantic_ it is possible to enforce static typing at runtime. Using a
+Pydantic model factory the code can be checked statically using Mypy and
+querying the database will raise an exception if the rows returned is not
+compatible with the model.
+
+.. _Pydantic: https://pydantic-docs.helpmanual.io/
+
+The following example can be checked with ``mypy --strict`` without reporting
+any issue. Pydantic will also raise a runtime error in case the
+`!Person` is used with a query that returns incompatible data.
+
+.. code:: python
+
+ from datetime import date
+ from typing import Optional
+
+ import psycopg
+ from psycopg.rows import class_row
+ from pydantic import BaseModel
+
+ class Person(BaseModel):
+ id: int
+ first_name: str
+ last_name: str
+ dob: Optional[date]
+
+ def fetch_person(id: int) -> Person:
+ with psycopg.connect() as conn:
+ with conn.cursor(row_factory=class_row(Person)) as cur:
+ cur.execute(
+ """
+ SELECT id, first_name, last_name, dob
+ FROM (VALUES
+ (1, 'John', 'Doe', '2000-01-01'::date),
+ (2, 'Jane', 'White', NULL)
+ ) AS data (id, first_name, last_name, dob)
+ WHERE id = %(id)s;
+ """,
+ {"id": id},
+ )
+ obj = cur.fetchone()
+
+ # reveal_type(obj) would return 'Optional[Person]' here
+
+ if not obj:
+ raise KeyError(f"person {id} not found")
+
+ # reveal_type(obj) would return 'Person' here
+
+ return obj
+
+ for id in [1, 2]:
+ p = fetch_person(id)
+ if p.dob:
+ print(f"{p.first_name} was born in {p.dob.year}")
+ else:
+ print(f"Who knows when {p.first_name} was born")
+
+
+.. _literal-string:
+
+Checking literal strings in queries
+-----------------------------------
+
+The `~Cursor.execute()` method and similar should only receive a literal
+string as input, according to :pep:`675`. This means that the query should
+come from a literal string in your code, not from an arbitrary string
+expression.
+
+For instance, passing an argument to the query should be done via the second
+argument to `!execute()`, not by string composition:
+
+.. code:: python
+
+ def get_record(conn: psycopg.Connection[Any], id: int) -> Any:
+ cur = conn.execute("SELECT * FROM my_table WHERE id = %s" % id) # BAD!
+ return cur.fetchone()
+
+ # the function should be implemented as:
+
+ def get_record(conn: psycopg.Connection[Any], id: int) -> Any:
+ cur = conn.execute("select * FROM my_table WHERE id = %s", (id,))
+ return cur.fetchone()
+
+If you are composing a query dynamically you should use the `sql.SQL` object
+and similar to escape safely table and field names. The parameter of the
+`!SQL()` object should be a literal string:
+
+.. code:: python
+
+ def count_records(conn: psycopg.Connection[Any], table: str) -> int:
+ query = "SELECT count(*) FROM %s" % table # BAD!
+ return conn.execute(query).fetchone()[0]
+
+ # the function should be implemented as:
+
+ def count_records(conn: psycopg.Connection[Any], table: str) -> int:
+ query = sql.SQL("SELECT count(*) FROM {}").format(sql.Identifier(table))
+ return conn.execute(query).fetchone()[0]
+
+At the time of writing, no Python static analyzer implements this check (`mypy
+doesn't implement it`__, Pyre_ does, but `doesn't work with psycopg yet`__).
+Once the type checkers support will be complete, the above bad statements
+should be reported as errors.
+
+.. __: https://github.com/python/mypy/issues/12554
+.. __: https://github.com/facebook/pyre-check/issues/636
+
+.. _Pyre: https://pyre-check.org/
diff --git a/docs/api/abc.rst b/docs/api/abc.rst
new file mode 100644
index 0000000..9514e9b
--- /dev/null
+++ b/docs/api/abc.rst
@@ -0,0 +1,75 @@
+`!abc` -- Psycopg abstract classes
+==================================
+
+The module exposes Psycopg definitions which can be used for static type
+checking.
+
+.. module:: psycopg.abc
+
+.. autoclass:: Dumper(cls, context=None)
+
+ :param cls: The type that will be managed by this dumper.
+ :type cls: type
+ :param context: The context where the transformation is performed. If not
+ specified the conversion might be inaccurate, for instance it will not
+ be possible to know the connection encoding or the server date format.
+ :type context: `AdaptContext` or None
+
+ A partial implementation of this protocol (implementing everything except
+ `dump()`) is available as `psycopg.adapt.Dumper`.
+
+ .. autoattribute:: format
+
+ .. automethod:: dump
+
+ The format returned by dump shouldn't contain quotes or escaped
+ values.
+
+ .. automethod:: quote
+
+ .. tip::
+
+ This method will be used by `~psycopg.sql.Literal` to convert a
+ value client-side.
+
+ This method only makes sense for text dumpers; the result of calling
+ it on a binary dumper is undefined. It might scratch your car, or burn
+ your cake. Don't tell me I didn't warn you.
+
+ .. autoattribute:: oid
+
+ If the OID is not specified, PostgreSQL will try to infer the type
+ from the context, but this may fail in some contexts and may require a
+ cast (e.g. specifying :samp:`%s::{type}` for its placeholder).
+
+ You can use the `psycopg.adapters`\ ``.``\
+ `~psycopg.adapt.AdaptersMap.types` registry to find the OID of builtin
+ types, and you can use `~psycopg.types.TypeInfo` to extend the
+ registry to custom types.
+
+ .. automethod:: get_key
+ .. automethod:: upgrade
+
+
+.. autoclass:: Loader(oid, context=None)
+
+ :param oid: The type that will be managed by this dumper.
+ :type oid: int
+ :param context: The context where the transformation is performed. If not
+ specified the conversion might be inaccurate, for instance it will not
+ be possible to know the connection encoding or the server date format.
+ :type context: `AdaptContext` or None
+
+ A partial implementation of this protocol (implementing everything except
+ `load()`) is available as `psycopg.adapt.Loader`.
+
+ .. autoattribute:: format
+
+ .. automethod:: load
+
+
+.. autoclass:: AdaptContext
+ :members:
+
+ .. seealso:: :ref:`adaptation` for an explanation about how contexts are
+ connected.
diff --git a/docs/api/adapt.rst b/docs/api/adapt.rst
new file mode 100644
index 0000000..e47816c
--- /dev/null
+++ b/docs/api/adapt.rst
@@ -0,0 +1,91 @@
+`adapt` -- Types adaptation
+===========================
+
+.. module:: psycopg.adapt
+
+The `!psycopg.adapt` module exposes a set of objects useful for the
+configuration of *data adaptation*, which is the conversion of Python objects
+to PostgreSQL data types and back.
+
+These objects are useful if you need to configure data adaptation, i.e.
+if you need to change the default way that Psycopg converts between types or
+if you want to adapt custom data types and objects. You don't need this object
+in the normal use of Psycopg.
+
+See :ref:`adaptation` for an overview of the Psycopg adaptation system.
+
+.. _abstract base class: https://docs.python.org/glossary.html#term-abstract-base-class
+
+
+Dumpers and loaders
+-------------------
+
+.. autoclass:: Dumper(cls, context=None)
+
+ This is an `abstract base class`_, partially implementing the
+ `~psycopg.abc.Dumper` protocol. Subclasses *must* at least implement the
+ `.dump()` method and optionally override other members.
+
+ .. automethod:: dump
+
+ .. attribute:: format
+ :type: psycopg.pq.Format
+ :value: TEXT
+
+ Class attribute. Set it to `~psycopg.pq.Format.BINARY` if the class
+ `dump()` methods converts the object to binary format.
+
+ .. automethod:: quote
+
+ .. automethod:: get_key
+
+ .. automethod:: upgrade
+
+
+.. autoclass:: Loader(oid, context=None)
+
+ This is an `abstract base class`_, partially implementing the
+ `~psycopg.abc.Loader` protocol. Subclasses *must* at least implement the
+ `.load()` method and optionally override other members.
+
+ .. automethod:: load
+
+ .. attribute:: format
+ :type: psycopg.pq.Format
+ :value: TEXT
+
+ Class attribute. Set it to `~psycopg.pq.Format.BINARY` if the class
+ `load()` methods converts the object from binary format.
+
+
+Other objects used in adaptations
+---------------------------------
+
+.. autoclass:: PyFormat
+ :members:
+
+
+.. autoclass:: AdaptersMap
+
+ .. seealso:: :ref:`adaptation` for an explanation about how contexts are
+ connected.
+
+ .. automethod:: register_dumper
+ .. automethod:: register_loader
+
+ .. attribute:: types
+
+ The object where to look up for types information (such as the mapping
+ between type names and oids in the specified context).
+
+ :type: `~psycopg.types.TypesRegistry`
+
+ .. automethod:: get_dumper
+ .. automethod:: get_dumper_by_oid
+ .. automethod:: get_loader
+
+
+.. autoclass:: Transformer(context=None)
+
+ :param context: The context where the transformer should operate.
+ :type context: `~psycopg.abc.AdaptContext`
diff --git a/docs/api/connections.rst b/docs/api/connections.rst
new file mode 100644
index 0000000..db25382
--- /dev/null
+++ b/docs/api/connections.rst
@@ -0,0 +1,489 @@
+.. currentmodule:: psycopg
+
+Connection classes
+==================
+
+The `Connection` and `AsyncConnection` classes are the main wrappers for a
+PostgreSQL database session. You can imagine them similar to a :program:`psql`
+session.
+
+One of the differences compared to :program:`psql` is that a `Connection`
+usually handles a transaction automatically: other sessions will not be able
+to see the changes until you have committed them, more or less explicitly.
+Take a look to :ref:`transactions` for the details.
+
+
+The `!Connection` class
+-----------------------
+
+.. autoclass:: Connection()
+
+ This class implements a `DBAPI-compliant interface`__. It is what you want
+ to use if you write a "classic", blocking program (eventually using
+ threads or Eventlet/gevent for concurrency). If your program uses `asyncio`
+ you might want to use `AsyncConnection` instead.
+
+ .. __: https://www.python.org/dev/peps/pep-0249/#connection-objects
+
+ Connections behave as context managers: on block exit, the current
+ transaction will be committed (or rolled back, in case of exception) and
+ the connection will be closed.
+
+ .. automethod:: connect
+
+ :param conninfo: The `connection string`__ (a ``postgresql://`` url or
+ a list of ``key=value`` pairs) to specify where and how to connect.
+ :param kwargs: Further parameters specifying the connection string.
+ They override the ones specified in `!conninfo`.
+ :param autocommit: If `!True` don't start transactions automatically.
+ See :ref:`transactions` for details.
+ :param row_factory: The row factory specifying what type of records
+ to create fetching data (default: `~psycopg.rows.tuple_row()`). See
+ :ref:`row-factories` for details.
+ :param cursor_factory: Initial value for the `cursor_factory` attribute
+ of the connection (new in Psycopg 3.1).
+ :param prepare_threshold: Initial value for the `prepare_threshold`
+ attribute of the connection (new in Psycopg 3.1).
+
+ More specialized use:
+
+ :param context: A context to copy the initial adapters configuration
+ from. It might be an `~psycopg.adapt.AdaptersMap` with customized
+ loaders and dumpers, used as a template to create several connections.
+ See :ref:`adaptation` for further details.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-CONNSTRING
+
+ This method is also aliased as `psycopg.connect()`.
+
+ .. seealso::
+
+ - the list of `the accepted connection parameters`__
+ - the `environment variables`__ affecting connection
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
+ .. __: https://www.postgresql.org/docs/current/libpq-envars.html
+
+ .. versionchanged:: 3.1
+ added `!prepare_threshold` and `!cursor_factory` parameters.
+
+ .. automethod:: close
+
+ .. note::
+
+ You can use::
+
+ with psycopg.connect() as conn:
+ ...
+
+ to close the connection automatically when the block is exited.
+ See :ref:`with-connection`.
+
+ .. autoattribute:: closed
+ .. autoattribute:: broken
+
+ .. method:: cursor(*, binary: bool = False, \
+ row_factory: Optional[RowFactory] = None) -> Cursor
+ .. method:: cursor(name: str, *, binary: bool = False, \
+ row_factory: Optional[RowFactory] = None, \
+ scrollable: Optional[bool] = None, withhold: bool = False) -> ServerCursor
+ :noindex:
+
+ Return a new cursor to send commands and queries to the connection.
+
+ :param name: If not specified create a client-side cursor, if
+ specified create a server-side cursor. See
+ :ref:`cursor-types` for details.
+ :param binary: If `!True` return binary values from the database. All
+ the types returned by the query must have a binary
+ loader. See :ref:`binary-data` for details.
+ :param row_factory: If specified override the `row_factory` set on the
+ connection. See :ref:`row-factories` for details.
+ :param scrollable: Specify the `~ServerCursor.scrollable` property of
+ the server-side cursor created.
+ :param withhold: Specify the `~ServerCursor.withhold` property of
+ the server-side cursor created.
+ :return: A cursor of the class specified by `cursor_factory` (or
+ `server_cursor_factory` if `!name` is specified).
+
+ .. note::
+
+ You can use::
+
+ with conn.cursor() as cur:
+ ...
+
+ to close the cursor automatically when the block is exited.
+
+ .. autoattribute:: cursor_factory
+
+ The type, or factory function, returned by `cursor()` and `execute()`.
+
+ Default is `psycopg.Cursor`.
+
+ .. autoattribute:: server_cursor_factory
+
+ The type, or factory function, returned by `cursor()` when a name is
+ specified.
+
+ Default is `psycopg.ServerCursor`.
+
+ .. autoattribute:: row_factory
+
+ The row factory defining the type of rows returned by
+ `~Cursor.fetchone()` and the other cursor fetch methods.
+
+ The default is `~psycopg.rows.tuple_row`, which means that the fetch
+ methods will return simple tuples.
+
+ .. seealso:: See :ref:`row-factories` for details about defining the
+ objects returned by cursors.
+
+ .. automethod:: execute
+
+ :param query: The query to execute.
+ :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed`
+ :param params: The parameters to pass to the query, if any.
+ :type params: Sequence or Mapping
+ :param prepare: Force (`!True`) or disallow (`!False`) preparation of
+ the query. By default (`!None`) prepare automatically. See
+ :ref:`prepared-statements`.
+ :param binary: If `!True` the cursor will return binary values from the
+ database. All the types returned by the query must have a binary
+ loader. See :ref:`binary-data` for details.
+
+ The method simply creates a `Cursor` instance, `~Cursor.execute()` the
+ query requested, and returns it.
+
+ See :ref:`query-parameters` for all the details about executing
+ queries.
+
+ .. automethod:: pipeline
+
+ The method is a context manager: you should call it using::
+
+ with conn.pipeline() as p:
+ ...
+
+ At the end of the block, a synchronization point is established and
+ the connection returns in normal mode.
+
+ You can call the method recursively from within a pipeline block.
+ Innermost blocks will establish a synchronization point on exit, but
+ pipeline mode will be kept until the outermost block exits.
+
+ See :ref:`pipeline-mode` for details.
+
+ .. versionadded:: 3.1
+
+
+ .. rubric:: Transaction management methods
+
+ For details see :ref:`transactions`.
+
+ .. automethod:: commit
+ .. automethod:: rollback
+ .. automethod:: transaction
+
+ .. note::
+
+ The method must be called with a syntax such as::
+
+ with conn.transaction():
+ ...
+
+ with conn.transaction() as tx:
+ ...
+
+ The latter is useful if you need to interact with the
+ `Transaction` object. See :ref:`transaction-context` for details.
+
+ Inside a transaction block it will not be possible to call `commit()`
+ or `rollback()`.
+
+ .. autoattribute:: autocommit
+
+ The property is writable for sync connections, read-only for async
+ ones: you should call `!await` `~AsyncConnection.set_autocommit`
+ :samp:`({value})` instead.
+
+ The following three properties control the characteristics of new
+ transactions. See :ref:`transaction-characteristics` for details.
+
+ .. autoattribute:: isolation_level
+
+ `!None` means use the default set in the default_transaction_isolation__
+ configuration parameter of the server.
+
+ .. __: https://www.postgresql.org/docs/current/runtime-config-client.html
+ #GUC-DEFAULT-TRANSACTION-ISOLATION
+
+ .. autoattribute:: read_only
+
+ `!None` means use the default set in the default_transaction_read_only__
+ configuration parameter of the server.
+
+ .. __: https://www.postgresql.org/docs/current/runtime-config-client.html
+ #GUC-DEFAULT-TRANSACTION-READ-ONLY
+
+ .. autoattribute:: deferrable
+
+ `!None` means use the default set in the default_transaction_deferrable__
+ configuration parameter of the server.
+
+ .. __: https://www.postgresql.org/docs/current/runtime-config-client.html
+ #GUC-DEFAULT-TRANSACTION-DEFERRABLE
+
+
+ .. rubric:: Checking and configuring the connection state
+
+ .. attribute:: pgconn
+ :type: psycopg.pq.PGconn
+
+ The `~pq.PGconn` libpq connection wrapper underlying the `!Connection`.
+
+ It can be used to send low level commands to PostgreSQL and access
+ features not currently wrapped by Psycopg.
+
+ .. autoattribute:: info
+
+ .. autoattribute:: prepare_threshold
+
+ See :ref:`prepared-statements` for details.
+
+
+ .. autoattribute:: prepared_max
+
+ If more queries need to be prepared, old ones are deallocated__.
+
+ .. __: https://www.postgresql.org/docs/current/sql-deallocate.html
+
+
+ .. rubric:: Methods you can use to do something cool
+
+ .. automethod:: cancel
+
+ .. automethod:: notifies
+
+ Notifies are received after using :sql:`LISTEN` in a connection, when
+ any sessions in the database generates a :sql:`NOTIFY` on one of the
+ listened channels.
+
+ .. automethod:: add_notify_handler
+
+ See :ref:`async-notify` for details.
+
+ .. automethod:: remove_notify_handler
+
+ .. automethod:: add_notice_handler
+
+ See :ref:`async-messages` for details.
+
+ .. automethod:: remove_notice_handler
+
+ .. automethod:: fileno
+
+
+ .. _tpc-methods:
+
+ .. rubric:: Two-Phase Commit support methods
+
+ .. versionadded:: 3.1
+
+ .. seealso:: :ref:`two-phase-commit` for an introductory explanation of
+ these methods.
+
+ .. automethod:: xid
+
+ .. automethod:: tpc_begin
+
+ :param xid: The id of the transaction
+ :type xid: Xid or str
+
+ This method should be called outside of a transaction (i.e. nothing
+ may have executed since the last `commit()` or `rollback()` and
+ `~ConnectionInfo.transaction_status` is `~pq.TransactionStatus.IDLE`).
+
+ Furthermore, it is an error to call `!commit()` or `!rollback()`
+ within the TPC transaction: in this case a `ProgrammingError`
+ is raised.
+
+ The `!xid` may be either an object returned by the `xid()` method or a
+ plain string: the latter allows to create a transaction using the
+ provided string as PostgreSQL transaction id. See also
+ `tpc_recover()`.
+
+
+ .. automethod:: tpc_prepare
+
+ A `ProgrammingError` is raised if this method is used outside of a TPC
+ transaction.
+
+ After calling `!tpc_prepare()`, no statements can be executed until
+ `tpc_commit()` or `tpc_rollback()` will be
+ called.
+
+ .. seealso:: The |PREPARE TRANSACTION|_ PostgreSQL command.
+
+ .. |PREPARE TRANSACTION| replace:: :sql:`PREPARE TRANSACTION`
+ .. _PREPARE TRANSACTION: https://www.postgresql.org/docs/current/static/sql-prepare-transaction.html
+
+
+ .. automethod:: tpc_commit
+
+ :param xid: The id of the transaction
+ :type xid: Xid or str
+
+ When called with no arguments, `!tpc_commit()` commits a TPC
+ transaction previously prepared with `tpc_prepare()`.
+
+ If `!tpc_commit()` is called prior to `!tpc_prepare()`, a single phase
+ commit is performed. A transaction manager may choose to do this if
+ only a single resource is participating in the global transaction.
+
+ When called with a transaction ID `!xid`, the database commits the
+ given transaction. If an invalid transaction ID is provided, a
+ `ProgrammingError` will be raised. This form should be called outside
+ of a transaction, and is intended for use in recovery.
+
+ On return, the TPC transaction is ended.
+
+ .. seealso:: The |COMMIT PREPARED|_ PostgreSQL command.
+
+ .. |COMMIT PREPARED| replace:: :sql:`COMMIT PREPARED`
+ .. _COMMIT PREPARED: https://www.postgresql.org/docs/current/static/sql-commit-prepared.html
+
+
+ .. automethod:: tpc_rollback
+
+ :param xid: The id of the transaction
+ :type xid: Xid or str
+
+ When called with no arguments, `!tpc_rollback()` rolls back a TPC
+ transaction. It may be called before or after `tpc_prepare()`.
+
+ When called with a transaction ID `!xid`, it rolls back the given
+ transaction. If an invalid transaction ID is provided, a
+ `ProgrammingError` is raised. This form should be called outside of a
+ transaction, and is intended for use in recovery.
+
+ On return, the TPC transaction is ended.
+
+ .. seealso:: The |ROLLBACK PREPARED|_ PostgreSQL command.
+
+ .. |ROLLBACK PREPARED| replace:: :sql:`ROLLBACK PREPARED`
+ .. _ROLLBACK PREPARED: https://www.postgresql.org/docs/current/static/sql-rollback-prepared.html
+
+
+ .. automethod:: tpc_recover
+
+ Returns a list of `Xid` representing pending transactions, suitable
+ for use with `tpc_commit()` or `tpc_rollback()`.
+
+ If a transaction was not initiated by Psycopg, the returned Xids will
+ have attributes `~Xid.format_id` and `~Xid.bqual` set to `!None` and
+ the `~Xid.gtrid` set to the PostgreSQL transaction ID: such Xids are
+ still usable for recovery. Psycopg uses the same algorithm of the
+ `PostgreSQL JDBC driver`__ to encode a XA triple in a string, so
+ transactions initiated by a program using such driver should be
+ unpacked correctly.
+
+ .. __: https://jdbc.postgresql.org/
+
+ Xids returned by `!tpc_recover()` also have extra attributes
+ `~Xid.prepared`, `~Xid.owner`, `~Xid.database` populated with the
+ values read from the server.
+
+ .. seealso:: the |pg_prepared_xacts|_ system view.
+
+ .. |pg_prepared_xacts| replace:: `pg_prepared_xacts`
+ .. _pg_prepared_xacts: https://www.postgresql.org/docs/current/static/view-pg-prepared-xacts.html
+
+
+The `!AsyncConnection` class
+----------------------------
+
+.. autoclass:: AsyncConnection()
+
+ This class implements a DBAPI-inspired interface, with all the blocking
+ methods implemented as coroutines. Unless specified otherwise,
+ non-blocking methods are shared with the `Connection` class.
+
+ The following methods have the same behaviour of the matching `!Connection`
+ methods, but should be called using the `await` keyword.
+
+ .. automethod:: connect
+
+ .. versionchanged:: 3.1
+
+ Automatically resolve domain names asynchronously. In previous
+ versions, name resolution blocks, unless the `!hostaddr`
+ parameter is specified, or the `~psycopg._dns.resolve_hostaddr_async()`
+ function is used.
+
+ .. automethod:: close
+
+ .. note:: You can use ``async with`` to close the connection
+ automatically when the block is exited, but be careful about
+ the async quirkness: see :ref:`async-with` for details.
+
+ .. method:: cursor(*, binary: bool = False, \
+ row_factory: Optional[RowFactory] = None) -> AsyncCursor
+ .. method:: cursor(name: str, *, binary: bool = False, \
+ row_factory: Optional[RowFactory] = None, \
+ scrollable: Optional[bool] = None, withhold: bool = False) -> AsyncServerCursor
+ :noindex:
+
+ .. note::
+
+ You can use::
+
+ async with conn.cursor() as cur:
+ ...
+
+ to close the cursor automatically when the block is exited.
+
+ .. autoattribute:: cursor_factory
+
+ Default is `psycopg.AsyncCursor`.
+
+ .. autoattribute:: server_cursor_factory
+
+ Default is `psycopg.AsyncServerCursor`.
+
+ .. autoattribute:: row_factory
+
+ .. automethod:: execute
+
+ .. automethod:: pipeline
+
+ .. note::
+
+ It must be called as::
+
+ async with conn.pipeline() as p:
+ ...
+
+ .. automethod:: commit
+ .. automethod:: rollback
+
+ .. automethod:: transaction
+
+ .. note::
+
+ It must be called as::
+
+ async with conn.transaction() as tx:
+ ...
+
+ .. automethod:: notifies
+ .. automethod:: set_autocommit
+ .. automethod:: set_isolation_level
+ .. automethod:: set_read_only
+ .. automethod:: set_deferrable
+
+ .. automethod:: tpc_prepare
+ .. automethod:: tpc_commit
+ .. automethod:: tpc_rollback
+ .. automethod:: tpc_recover
diff --git a/docs/api/conninfo.rst b/docs/api/conninfo.rst
new file mode 100644
index 0000000..9e5b01d
--- /dev/null
+++ b/docs/api/conninfo.rst
@@ -0,0 +1,24 @@
+.. _psycopg.conninfo:
+
+`conninfo` -- manipulate connection strings
+===========================================
+
+This module contains a few utility functions to manipulate database
+connection strings.
+
+.. module:: psycopg.conninfo
+
+.. autofunction:: conninfo_to_dict
+
+ .. code:: python
+
+ >>> conninfo_to_dict("postgres://jeff@example.com/db", user="piro")
+ {'user': 'piro', 'dbname': 'db', 'host': 'example.com'}
+
+
+.. autofunction:: make_conninfo
+
+ .. code:: python
+
+ >>> make_conninfo("dbname=db user=jeff", user="piro", port=5432)
+ 'dbname=db user=piro port=5432'
diff --git a/docs/api/copy.rst b/docs/api/copy.rst
new file mode 100644
index 0000000..81a96e2
--- /dev/null
+++ b/docs/api/copy.rst
@@ -0,0 +1,117 @@
+.. currentmodule:: psycopg
+
+COPY-related objects
+====================
+
+The main objects (`Copy`, `AsyncCopy`) present the main interface to exchange
+data during a COPY operations. These objects are normally obtained by the
+methods `Cursor.copy()` and `AsyncCursor.copy()`; however, they can be also
+created directly, for instance to write to a destination which is not a
+database (e.g. using a `~psycopg.copy.FileWriter`).
+
+See :ref:`copy` for details.
+
+
+Main Copy objects
+-----------------
+
+.. autoclass:: Copy()
+
+ The object is normally returned by `!with` `Cursor.copy()`.
+
+ .. automethod:: write_row
+
+ The data in the tuple will be converted as configured on the cursor;
+ see :ref:`adaptation` for details.
+
+ .. automethod:: write
+ .. automethod:: read
+
+ Instead of using `!read()` you can iterate on the `!Copy` object to
+ read its data row by row, using ``for row in copy: ...``.
+
+ .. automethod:: rows
+
+ Equivalent of iterating on `read_row()` until it returns `!None`
+
+ .. automethod:: read_row
+ .. automethod:: set_types
+
+
+.. autoclass:: AsyncCopy()
+
+ The object is normally returned by ``async with`` `AsyncCursor.copy()`.
+ Its methods are similar to the ones of the `Copy` object but offering an
+ `asyncio` interface (`await`, `async for`, `async with`).
+
+ .. automethod:: write_row
+ .. automethod:: write
+ .. automethod:: read
+
+ Instead of using `!read()` you can iterate on the `!AsyncCopy` object
+ to read its data row by row, using ``async for row in copy: ...``.
+
+ .. automethod:: rows
+
+ Use it as `async for record in copy.rows():` ...
+
+ .. automethod:: read_row
+
+
+.. _copy-writers:
+
+Writer objects
+--------------
+
+.. currentmodule:: psycopg.copy
+
+.. versionadded:: 3.1
+
+Copy writers are helper objects to specify where to write COPY-formatted data.
+By default, data is written to the database (using the `LibpqWriter`). It is
+possible to write copy-data for offline use by using a `FileWriter`, or to
+customize further writing by implementing your own `Writer` or `AsyncWriter`
+subclass.
+
+Writers instances can be used passing them to the cursor
+`~psycopg.Cursor.copy()` method or to the `~psycopg.Copy` constructor, as the
+`!writer` argument.
+
+.. autoclass:: Writer
+
+ This is an abstract base class: subclasses are required to implement their
+ `write()` method.
+
+ .. automethod:: write
+ .. automethod:: finish
+
+
+.. autoclass:: LibpqWriter
+
+ This is the writer used by default if none is specified.
+
+
+.. autoclass:: FileWriter
+
+ This writer should be used without executing a :sql:`COPY` operation on
+ the database. For example, if `records` is a list of tuples containing
+ data to save in COPY format to a file (e.g. for later import), it can be
+ used as:
+
+ .. code:: python
+
+ with open("target-file.pgcopy", "wb") as f:
+ with Copy(cur, writer=FileWriter(f)) as copy:
+ for record in records
+ copy.write_row(record)
+
+
+.. autoclass:: AsyncWriter
+
+ This class methods have the same semantics of the ones of `Writer`, but
+ offer an async interface.
+
+ .. automethod:: write
+ .. automethod:: finish
+
+.. autoclass:: AsyncLibpqWriter
diff --git a/docs/api/crdb.rst b/docs/api/crdb.rst
new file mode 100644
index 0000000..de8344e
--- /dev/null
+++ b/docs/api/crdb.rst
@@ -0,0 +1,120 @@
+`crdb` -- CockroachDB support
+=============================
+
+.. module:: psycopg.crdb
+
+.. versionadded:: 3.1
+
+CockroachDB_ is a distributed database using the same fronted-backend protocol
+of PostgreSQL. As such, Psycopg can be used to write Python programs
+interacting with CockroachDB.
+
+.. _CockroachDB: https://www.cockroachlabs.com/
+
+Opening a connection to a CRDB database using `psycopg.connect()` provides a
+largely working object. However, using the `psycopg.crdb.connect()` function
+instead, Psycopg will create more specialised objects and provide a types
+mapping tweaked on the CockroachDB data model.
+
+
+.. _crdb-differences:
+
+Main differences from PostgreSQL
+--------------------------------
+
+CockroachDB behaviour is `different from PostgreSQL`__: please refer to the
+database documentation for details. These are some of the main differences
+affecting Psycopg behaviour:
+
+.. __: https://www.cockroachlabs.com/docs/stable/postgresql-compatibility.html
+
+- `~psycopg.Connection.cancel()` doesn't work before CockroachDB 22.1. On
+ older versions, you can use `CANCEL QUERY`_ instead (but from a different
+ connection).
+
+- :ref:`server-side-cursors` are well supported only from CockroachDB 22.1.3.
+
+- `~psycopg.ConnectionInfo.backend_pid` is only populated from CockroachDB
+ 22.1. Note however that you cannot use the PID to terminate the session; use
+ `SHOW session_id`_ to find the id of a session, which you may terminate with
+ `CANCEL SESSION`_ in lieu of PostgreSQL's :sql:`pg_terminate_backend()`.
+
+- Several data types are missing or slightly different from PostgreSQL (see
+ `adapters` for an overview of the differences).
+
+- The :ref:`two-phase commit protocol <two-phase-commit>` is not supported.
+
+- :sql:`LISTEN` and :sql:`NOTIFY` are not supported. However the `CHANGEFEED`_
+ command, in conjunction with `~psycopg.Cursor.stream()`, can provide push
+ notifications.
+
+.. _CANCEL QUERY: https://www.cockroachlabs.com/docs/stable/cancel-query.html
+.. _SHOW session_id: https://www.cockroachlabs.com/docs/stable/show-vars.html
+.. _CANCEL SESSION: https://www.cockroachlabs.com/docs/stable/cancel-session.html
+.. _CHANGEFEED: https://www.cockroachlabs.com/docs/stable/changefeed-for.html
+
+
+.. _crdb-objects:
+
+CockroachDB-specific objects
+----------------------------
+
+.. autofunction:: connect
+
+ This is an alias of the class method `CrdbConnection.connect`.
+
+ If you need an asynchronous connection use the `AsyncCrdbConnection.connect()`
+ method instead.
+
+
+.. autoclass:: CrdbConnection
+
+ `psycopg.Connection` subclass.
+
+ .. automethod:: is_crdb
+
+ :param conn: the connection to check
+ :type conn: `~psycopg.Connection`, `~psycopg.AsyncConnection`, `~psycopg.pq.PGconn`
+
+
+.. autoclass:: AsyncCrdbConnection
+
+ `psycopg.AsyncConnection` subclass.
+
+
+.. autoclass:: CrdbConnectionInfo
+
+ The object is returned by the `~psycopg.Connection.info` attribute of
+ `CrdbConnection` and `AsyncCrdbConnection`.
+
+ The object behaves like `!ConnectionInfo`, with the following differences:
+
+ .. autoattribute:: vendor
+
+ The `CockroachDB` string.
+
+ .. autoattribute:: server_version
+
+
+.. data:: adapters
+
+ The default adapters map establishing how Python and CockroachDB types are
+ converted into each other.
+
+ The map is used as a template when new connections are created, using
+ `psycopg.crdb.connect()` (similarly to the way `psycopg.adapters` is used
+ as template for new PostgreSQL connections).
+
+ This registry contains only the types and adapters supported by
+ CockroachDB. Several PostgreSQL types and adapters are missing or
+ different from PostgreSQL, among which:
+
+ - Composite types
+ - :sql:`range`, :sql:`multirange` types
+ - The :sql:`hstore` type
+ - Geometric types
+ - Nested arrays
+ - Arrays of :sql:`jsonb`
+ - The :sql:`cidr` data type
+ - The :sql:`json` type is an alias for :sql:`jsonb`
+ - The :sql:`int` type is an alias for :sql:`int8`, not `int4`.
diff --git a/docs/api/cursors.rst b/docs/api/cursors.rst
new file mode 100644
index 0000000..9c5b478
--- /dev/null
+++ b/docs/api/cursors.rst
@@ -0,0 +1,517 @@
+.. currentmodule:: psycopg
+
+Cursor classes
+==============
+
+The `Cursor` and `AsyncCursor` classes are the main objects to send commands
+to a PostgreSQL database session. They are normally created by the
+connection's `~Connection.cursor()` method.
+
+Using the `!name` parameter on `!cursor()` will create a `ServerCursor` or
+`AsyncServerCursor`, which can be used to retrieve partial results from a
+database.
+
+A `Connection` can create several cursors, but only one at time can perform
+operations, so they are not the best way to achieve parallelism (you may want
+to operate with several connections instead). All the cursors on the same
+connection have a view of the same session, so they can see each other's
+uncommitted data.
+
+
+The `!Cursor` class
+-------------------
+
+.. autoclass:: Cursor
+
+ This class implements a `DBAPI-compliant interface`__. It is what the
+ classic `Connection.cursor()` method returns. `AsyncConnection.cursor()`
+ will create instead `AsyncCursor` objects, which have the same set of
+ method but expose an `asyncio` interface and require `!async` and
+ `!await` keywords to operate.
+
+ .. __: dbapi-cursor_
+ .. _dbapi-cursor: https://www.python.org/dev/peps/pep-0249/#cursor-objects
+
+
+ Cursors behave as context managers: on block exit they are closed and
+ further operation will not be possible. Closing a cursor will not
+ terminate a transaction or a session though.
+
+ .. attribute:: connection
+ :type: Connection
+
+ The connection this cursor is using.
+
+ .. automethod:: close
+
+ .. note::
+
+ You can use::
+
+ with conn.cursor() as cur:
+ ...
+
+ to close the cursor automatically when the block is exited. See
+ :ref:`usage`.
+
+ .. autoattribute:: closed
+
+ .. rubric:: Methods to send commands
+
+ .. automethod:: execute
+
+ :param query: The query to execute.
+ :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed`
+ :param params: The parameters to pass to the query, if any.
+ :type params: Sequence or Mapping
+ :param prepare: Force (`!True`) or disallow (`!False`) preparation of
+ the query. By default (`!None`) prepare automatically. See
+ :ref:`prepared-statements`.
+ :param binary: Specify whether the server should return data in binary
+ format (`!True`) or in text format (`!False`). By default
+ (`!None`) return data as requested by the cursor's `~Cursor.format`.
+
+ Return the cursor itself, so that it will be possible to chain a fetch
+ operation after the call.
+
+ See :ref:`query-parameters` for all the details about executing
+ queries.
+
+ .. versionchanged:: 3.1
+
+ The `query` argument must be a `~typing.StringLiteral`. If you
+ need to compose a query dynamically, please use `sql.SQL` and
+ related objects.
+
+ See :pep:`675` for details.
+
+ .. automethod:: executemany
+
+ :param query: The query to execute
+ :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed`
+ :param params_seq: The parameters to pass to the query
+ :type params_seq: Sequence of Sequences or Mappings
+ :param returning: If `!True`, fetch the results of the queries executed
+ :type returning: `!bool`
+
+ This is more efficient than performing separate queries, but in case of
+ several :sql:`INSERT` (and with some SQL creativity for massive
+ :sql:`UPDATE` too) you may consider using `copy()`.
+
+ If the queries return data you want to read (e.g. when executing an
+ :sql:`INSERT ... RETURNING` or a :sql:`SELECT` with a side-effect),
+ you can specify `!returning=True`; the results will be available in
+ the cursor's state and can be read using `fetchone()` and similar
+ methods. Each input parameter will produce a separate result set: use
+ `nextset()` to read the results of the queries after the first one.
+
+ See :ref:`query-parameters` for all the details about executing
+ queries.
+
+ .. versionchanged:: 3.1
+
+ - Added `!returning` parameter to receive query results.
+ - Performance optimised by making use of the pipeline mode, when
+ using libpq 14 or newer.
+
+ .. automethod:: copy
+
+ :param statement: The copy operation to execute
+ :type statement: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed`
+ :param params: The parameters to pass to the statement, if any.
+ :type params: Sequence or Mapping
+
+ .. note::
+
+ The method must be called with::
+
+ with cursor.copy() as copy:
+ ...
+
+ See :ref:`copy` for information about :sql:`COPY`.
+
+ .. versionchanged:: 3.1
+ Added parameters support.
+
+ .. automethod:: stream
+
+ This command is similar to execute + iter; however it supports endless
+ data streams. The feature is not available in PostgreSQL, but some
+ implementations exist: Materialize `TAIL`__ and CockroachDB
+ `CHANGEFEED`__ for instance.
+
+ The feature, and the API supporting it, are still experimental.
+ Beware... 👀
+
+ .. __: https://materialize.com/docs/sql/tail/#main
+ .. __: https://www.cockroachlabs.com/docs/stable/changefeed-for.html
+
+ The parameters are the same of `execute()`.
+
+ .. warning::
+
+ Failing to consume the iterator entirely will result in a
+ connection left in `~psycopg.ConnectionInfo.transaction_status`
+ `~pq.TransactionStatus.ACTIVE` state: this connection will refuse
+ to receive further commands (with a message such as *another
+ command is already in progress*).
+
+ If there is a chance that the generator is not consumed entirely,
+ in order to restore the connection to a working state you can call
+ `~generator.close` on the generator object returned by `!stream()`. The
+ `contextlib.closing` function might be particularly useful to make
+ sure that `!close()` is called:
+
+ .. code::
+
+ with closing(cur.stream("select generate_series(1, 10000)")) as gen:
+ for rec in gen:
+ something(rec) # might fail
+
+ Without calling `!close()`, in case of error, the connection will
+ be `!ACTIVE` and unusable. If `!close()` is called, the connection
+ might be `!INTRANS` or `!INERROR`, depending on whether the server
+ managed to send the entire resultset to the client. An autocommit
+ connection will be `!IDLE` instead.
+
+
+ .. attribute:: format
+
+ The format of the data returned by the queries. It can be selected
+ initially e.g. specifying `Connection.cursor`\ `!(binary=True)` and
+ changed during the cursor's lifetime. It is also possible to override
+ the value for single queries, e.g. specifying `execute`\
+ `!(binary=True)`.
+
+ :type: `pq.Format`
+ :default: `~pq.Format.TEXT`
+
+ .. seealso:: :ref:`binary-data`
+
+
+ .. rubric:: Methods to retrieve results
+
+ Fetch methods are only available if the last operation produced results,
+ e.g. a :sql:`SELECT` or a command with :sql:`RETURNING`. They will raise
+ an exception if used with operations that don't return result, such as an
+ :sql:`INSERT` with no :sql:`RETURNING` or an :sql:`ALTER TABLE`.
+
+ .. note::
+
+ Cursors are iterable objects, so just using the::
+
+ for record in cursor:
+ ...
+
+ syntax will iterate on the records in the current recordset.
+
+ .. autoattribute:: row_factory
+
+ The property affects the objects returned by the `fetchone()`,
+ `fetchmany()`, `fetchall()` methods. The default
+ (`~psycopg.rows.tuple_row`) returns a tuple for each record fetched.
+
+ See :ref:`row-factories` for details.
+
+ .. automethod:: fetchone
+ .. automethod:: fetchmany
+ .. automethod:: fetchall
+ .. automethod:: nextset
+ .. automethod:: scroll
+
+ .. attribute:: pgresult
+ :type: Optional[psycopg.pq.PGresult]
+
+ The result returned by the last query and currently exposed by the
+ cursor, if available, else `!None`.
+
+ It can be used to obtain low level info about the last query result
+ and to access to features not currently wrapped by Psycopg.
+
+
+ .. rubric:: Information about the data
+
+ .. autoattribute:: description
+
+ .. autoattribute:: statusmessage
+
+ This is the status tag you typically see in :program:`psql` after
+ a successful command, such as ``CREATE TABLE`` or ``UPDATE 42``.
+
+ .. autoattribute:: rowcount
+ .. autoattribute:: rownumber
+
+ .. attribute:: _query
+
+ An helper object used to convert queries and parameters before sending
+ them to PostgreSQL.
+
+ .. note::
+ This attribute is exposed because it might be helpful to debug
+ problems when the communication between Python and PostgreSQL
+ doesn't work as expected. For this reason, the attribute is
+ available when a query fails too.
+
+ .. warning::
+ You shouldn't consider it part of the public interface of the
+ object: it might change without warnings.
+
+ Except this warning, I guess.
+
+ If you would like to build reliable features using this object,
+ please get in touch so we can try and design an useful interface
+ for it.
+
+ Among the properties currently exposed by this object:
+
+ - `!query` (`!bytes`): the query effectively sent to PostgreSQL. It
+ will have Python placeholders (``%s``\-style) replaced with
+ PostgreSQL ones (``$1``, ``$2``\-style).
+
+ - `!params` (sequence of `!bytes`): the parameters passed to
+ PostgreSQL, adapted to the database format.
+
+ - `!types` (sequence of `!int`): the OID of the parameters passed to
+ PostgreSQL.
+
+ - `!formats` (sequence of `pq.Format`): whether the parameter format
+ is text or binary.
+
+
+The `!ClientCursor` class
+-------------------------
+
+.. seealso:: See :ref:`client-side-binding-cursors` for details.
+
+.. autoclass:: ClientCursor
+
+ This `Cursor` subclass has exactly the same interface of its parent class,
+ but, instead of sending query and parameters separately to the server, it
+ merges them on the client and sends them as a non-parametric query on the
+ server. This allows, for instance, to execute parametrized data definition
+ statements and other :ref:`problematic queries <server-side-binding>`.
+
+ .. versionadded:: 3.1
+
+ .. automethod:: mogrify
+
+ :param query: The query to execute.
+ :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed`
+ :param params: The parameters to pass to the query, if any.
+ :type params: Sequence or Mapping
+
+
+The `!ServerCursor` class
+--------------------------
+
+.. seealso:: See :ref:`server-side-cursors` for details.
+
+.. autoclass:: ServerCursor
+
+ This class also implements a `DBAPI-compliant interface`__. It is created
+ by `Connection.cursor()` specifying the `!name` parameter. Using this
+ object results in the creation of an equivalent PostgreSQL cursor in the
+ server. DBAPI-extension methods (such as `~Cursor.copy()` or
+ `~Cursor.stream()`) are not implemented on this object: use a normal
+ `Cursor` instead.
+
+ .. __: dbapi-cursor_
+
+ Most attribute and methods behave exactly like in `Cursor`, here are
+ documented the differences:
+
+ .. autoattribute:: name
+ .. autoattribute:: scrollable
+
+ .. seealso:: The PostgreSQL DECLARE_ statement documentation
+ for the description of :sql:`[NO] SCROLL`.
+
+ .. autoattribute:: withhold
+
+ .. seealso:: The PostgreSQL DECLARE_ statement documentation
+ for the description of :sql:`{WITH|WITHOUT} HOLD`.
+
+ .. _DECLARE: https://www.postgresql.org/docs/current/sql-declare.html
+
+
+ .. automethod:: close
+
+ .. warning:: Closing a server-side cursor is more important than
+ closing a client-side one because it also releases the resources
+ on the server, which otherwise might remain allocated until the
+ end of the session (memory, locks). Using the pattern::
+
+ with conn.cursor():
+ ...
+
+ is especially useful so that the cursor is closed at the end of
+ the block.
+
+ .. automethod:: execute
+
+ :param query: The query to execute.
+ :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed`
+ :param params: The parameters to pass to the query, if any.
+ :type params: Sequence or Mapping
+ :param binary: Specify whether the server should return data in binary
+ format (`!True`) or in text format (`!False`). By default
+ (`!None`) return data as requested by the cursor's `~Cursor.format`.
+
+ Create a server cursor with given `!name` and the `!query` in argument.
+
+ If using :sql:`DECLARE` is not appropriate (for instance because the
+ cursor is returned by calling a stored procedure) you can avoid to use
+ `!execute()`, crete the cursor in other ways, and use directly the
+ `!fetch*()` methods instead. See :ref:`cursor-steal` for an example.
+
+ Using `!execute()` more than once will close the previous cursor and
+ open a new one with the same name.
+
+ .. automethod:: executemany
+ .. automethod:: fetchone
+ .. automethod:: fetchmany
+ .. automethod:: fetchall
+
+ These methods use the FETCH_ SQL statement to retrieve some of the
+ records from the cursor's current position.
+
+ .. _FETCH: https://www.postgresql.org/docs/current/sql-fetch.html
+
+ .. note::
+
+ You can also iterate on the cursor to read its result one at
+ time with::
+
+ for record in cur:
+ ...
+
+ In this case, the records are not fetched one at time from the
+ server but they are retrieved in batches of `itersize` to reduce
+ the number of server roundtrips.
+
+ .. autoattribute:: itersize
+
+ Number of records to fetch at time when iterating on the cursor. The
+ default is 100.
+
+ .. automethod:: scroll
+
+ This method uses the MOVE_ SQL statement to move the current position
+ in the server-side cursor, which will affect following `!fetch*()`
+ operations. If you need to scroll backwards you should probably
+ call `~Connection.cursor()` using `scrollable=True`.
+
+ Note that PostgreSQL doesn't provide a reliable way to report when a
+ cursor moves out of bound, so the method might not raise `!IndexError`
+ when it happens, but it might rather stop at the cursor boundary.
+
+ .. _MOVE: https://www.postgresql.org/docs/current/sql-fetch.html
+
+
+The `!AsyncCursor` class
+------------------------
+
+.. autoclass:: AsyncCursor
+
+ This class implements a DBAPI-inspired interface, with all the blocking
+ methods implemented as coroutines. Unless specified otherwise,
+ non-blocking methods are shared with the `Cursor` class.
+
+ The following methods have the same behaviour of the matching `!Cursor`
+ methods, but should be called using the `await` keyword.
+
+ .. attribute:: connection
+ :type: AsyncConnection
+
+ .. automethod:: close
+
+ .. note::
+
+ You can use::
+
+ async with conn.cursor():
+ ...
+
+ to close the cursor automatically when the block is exited.
+
+ .. automethod:: execute
+ .. automethod:: executemany
+ .. automethod:: copy
+
+ .. note::
+
+ The method must be called with::
+
+ async with cursor.copy() as copy:
+ ...
+
+ .. automethod:: stream
+
+ .. note::
+
+ The method must be called with::
+
+ async for record in cursor.stream(query):
+ ...
+
+ .. automethod:: fetchone
+ .. automethod:: fetchmany
+ .. automethod:: fetchall
+ .. automethod:: scroll
+
+ .. note::
+
+ You can also use::
+
+ async for record in cursor:
+ ...
+
+ to iterate on the async cursor results.
+
+
+The `!AsyncClientCursor` class
+------------------------------
+
+.. autoclass:: AsyncClientCursor
+
+ This class is the `!async` equivalent of the `ClientCursor`. The
+ difference are the same shown in `AsyncCursor`.
+
+ .. versionadded:: 3.1
+
+
+
+The `!AsyncServerCursor` class
+------------------------------
+
+.. autoclass:: AsyncServerCursor
+
+ This class implements a DBAPI-inspired interface as the `AsyncCursor`
+ does, but wraps a server-side cursor like the `ServerCursor` class. It is
+ created by `AsyncConnection.cursor()` specifying the `!name` parameter.
+
+ The following are the methods exposing a different (async) interface from
+ the `ServerCursor` counterpart, but sharing the same semantics.
+
+ .. automethod:: close
+
+ .. note::
+ You can close the cursor automatically using::
+
+ async with conn.cursor("name") as cursor:
+ ...
+
+ .. automethod:: execute
+ .. automethod:: executemany
+ .. automethod:: fetchone
+ .. automethod:: fetchmany
+ .. automethod:: fetchall
+
+ .. note::
+
+ You can also iterate on the cursor using::
+
+ async for record in cur:
+ ...
+
+ .. automethod:: scroll
diff --git a/docs/api/dns.rst b/docs/api/dns.rst
new file mode 100644
index 0000000..186bde3
--- /dev/null
+++ b/docs/api/dns.rst
@@ -0,0 +1,145 @@
+`_dns` -- DNS resolution utilities
+==================================
+
+.. module:: psycopg._dns
+
+This module contains a few experimental utilities to interact with the DNS
+server before performing a connection.
+
+.. warning::
+ This module is experimental and its interface could change in the future,
+ without warning or respect for the version scheme. It is provided here to
+ allow experimentation before making it more stable.
+
+.. warning::
+ This module depends on the `dnspython`_ package. The package is currently
+ not installed automatically as a Psycopg dependency and must be installed
+ manually:
+
+ .. code:: sh
+
+ $ pip install "dnspython >= 2.1"
+
+ .. _dnspython: https://dnspython.readthedocs.io/
+
+
+.. function:: resolve_srv(params)
+
+ Apply SRV DNS lookup as defined in :RFC:`2782`.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`.
+ :type params: `!dict`
+ :return: An updated list of connection parameters.
+
+ For every host defined in the ``params["host"]`` list (comma-separated),
+ perform SRV lookup if the host is in the form ``_Service._Proto.Target``.
+ If lookup is successful, return a params dict with hosts and ports replaced
+ with the looked-up entries.
+
+ Raise `~psycopg.OperationalError` if no lookup is successful and no host
+ (looked up or unchanged) could be returned.
+
+ In addition to the rules defined by RFC 2782 about the host name pattern,
+ perform SRV lookup also if the the port is the string ``SRV`` (case
+ insensitive).
+
+ .. warning::
+ This is an experimental functionality.
+
+ .. note::
+ One possible way to use this function automatically is to subclass
+ `~psycopg.Connection`, extending the
+ `~psycopg.Connection._get_connection_params()` method::
+
+ import psycopg._dns # not imported automatically
+
+ class SrvCognizantConnection(psycopg.Connection):
+ @classmethod
+ def _get_connection_params(cls, conninfo, **kwargs):
+ params = super()._get_connection_params(conninfo, **kwargs)
+ params = psycopg._dns.resolve_srv(params)
+ return params
+
+ # The name will be resolved to db1.example.com
+ cnn = SrvCognizantConnection.connect("host=_postgres._tcp.db.psycopg.org")
+
+
+.. function:: resolve_srv_async(params)
+ :async:
+
+ Async equivalent of `resolve_srv()`.
+
+
+.. automethod:: psycopg.Connection._get_connection_params
+
+ .. warning::
+ This is an experimental method.
+
+ This method is a subclass hook allowing to manipulate the connection
+ parameters before performing the connection. Make sure to call the
+ `!super()` implementation before further manipulation of the arguments::
+
+ @classmethod
+ def _get_connection_params(cls, conninfo, **kwargs):
+ params = super()._get_connection_params(conninfo, **kwargs)
+ # do something with the params
+ return params
+
+
+.. automethod:: psycopg.AsyncConnection._get_connection_params
+
+ .. warning::
+ This is an experimental method.
+
+
+.. function:: resolve_hostaddr_async(params)
+ :async:
+
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ .. deprecated:: 3.1
+ The use of this function is not necessary anymore, because
+ `psycopg.AsyncConnection.connect()` performs non-blocking name
+ resolution automatically.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`.
+ :type params: `!dict`
+
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses dynamically.
+
+ The function may change the input ``host``, ``hostname``, ``port`` to allow
+ connecting without further DNS lookups, eventually removing hosts that are
+ not resolved, keeping the lists of hosts and ports consistent.
+
+ Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
+ host resolve, inconsistent lists length).
+
+ See `the PostgreSQL docs`__ for explanation of how these params are used,
+ and how they support multiple entries.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-PARAMKEYWORDS
+
+ .. warning::
+ Before psycopg 3.1, this function doesn't handle the ``/etc/hosts`` file.
+
+ .. note::
+ Starting from psycopg 3.1, a similar operation is performed
+ automatically by `!AsyncConnection._get_connection_params()`, so this
+ function is unneeded.
+
+ In psycopg 3.0, one possible way to use this function automatically is
+ to subclass `~psycopg.AsyncConnection`, extending the
+ `~psycopg.AsyncConnection._get_connection_params()` method::
+
+ import psycopg._dns # not imported automatically
+
+ class AsyncDnsConnection(psycopg.AsyncConnection):
+ @classmethod
+ async def _get_connection_params(cls, conninfo, **kwargs):
+ params = await super()._get_connection_params(conninfo, **kwargs)
+ params = await psycopg._dns.resolve_hostaddr_async(params)
+ return params
diff --git a/docs/api/errors.rst b/docs/api/errors.rst
new file mode 100644
index 0000000..2fca7c6
--- /dev/null
+++ b/docs/api/errors.rst
@@ -0,0 +1,540 @@
+`errors` -- Package exceptions
+==============================
+
+.. module:: psycopg.errors
+
+.. index::
+ single: Error; Class
+
+This module exposes objects to represent and examine database errors.
+
+
+.. currentmodule:: psycopg
+
+.. index::
+ single: Exceptions; DB-API
+
+.. _dbapi-exceptions:
+
+DB-API exceptions
+-----------------
+
+In compliance with the DB-API, all the exceptions raised by Psycopg
+derive from the following classes:
+
+.. parsed-literal::
+
+ `!Exception`
+ \|__ `Warning`
+ \|__ `Error`
+ \|__ `InterfaceError`
+ \|__ `DatabaseError`
+ \|__ `DataError`
+ \|__ `OperationalError`
+ \|__ `IntegrityError`
+ \|__ `InternalError`
+ \|__ `ProgrammingError`
+ \|__ `NotSupportedError`
+
+These classes are exposed both by this module and the root `psycopg` module.
+
+.. autoexception:: Error()
+
+ .. autoattribute:: diag
+ .. autoattribute:: sqlstate
+
+ The code of the error, if received from the server.
+
+ This attribute is also available as class attribute on the
+ :ref:`sqlstate-exceptions` classes.
+
+ .. autoattribute:: pgconn
+
+ Most likely it will be in `~psycopg.pq.ConnStatus.BAD` state;
+ however it might be useful to verify precisely what went wrong, for
+ instance checking the `~psycopg.pq.PGconn.needs_password` and
+ `~psycopg.pq.PGconn.used_password` attributes.
+
+ .. versionadded:: 3.1
+
+ .. autoattribute:: pgresult
+
+ .. versionadded:: 3.1
+
+
+.. autoexception:: Warning()
+.. autoexception:: InterfaceError()
+.. autoexception:: DatabaseError()
+.. autoexception:: DataError()
+.. autoexception:: OperationalError()
+.. autoexception:: IntegrityError()
+.. autoexception:: InternalError()
+.. autoexception:: ProgrammingError()
+.. autoexception:: NotSupportedError()
+
+
+Other Psycopg errors
+^^^^^^^^^^^^^^^^^^^^
+
+.. currentmodule:: psycopg.errors
+
+
+In addition to the standard DB-API errors, Psycopg defines a few more specific
+ones.
+
+.. autoexception:: ConnectionTimeout()
+.. autoexception:: PipelineAborted()
+
+
+
+.. index::
+ single: Exceptions; PostgreSQL
+
+Error diagnostics
+-----------------
+
+.. autoclass:: Diagnostic()
+
+ The object is available as the `~psycopg.Error`.\ `~psycopg.Error.diag`
+ attribute and is passed to the callback functions registered with
+ `~psycopg.Connection.add_notice_handler()`.
+
+ All the information available from the :pq:`PQresultErrorField()` function
+ are exposed as attributes by the object. For instance the `!severity`
+ attribute returns the `!PG_DIAG_SEVERITY` code. Please refer to the
+ PostgreSQL documentation for the meaning of all the attributes.
+
+ The attributes available are:
+
+ .. attribute::
+ column_name
+ constraint_name
+ context
+ datatype_name
+ internal_position
+ internal_query
+ message_detail
+ message_hint
+ message_primary
+ schema_name
+ severity
+ severity_nonlocalized
+ source_file
+ source_function
+ source_line
+ sqlstate
+ statement_position
+ table_name
+
+ A string with the error field if available; `!None` if not available.
+ The attribute value is available only for errors sent by the server:
+ not all the fields are available for all the errors and for all the
+ server versions.
+
+
+.. _sqlstate-exceptions:
+
+SQLSTATE exceptions
+-------------------
+
+Errors coming from a database server (as opposite as ones generated
+client-side, such as connection failed) usually have a 5-letters error code
+called SQLSTATE (available in the `~Diagnostic.sqlstate` attribute of the
+error's `~psycopg.Error.diag` attribute).
+
+Psycopg exposes a different class for each SQLSTATE value, allowing to
+write idiomatic error handling code according to specific conditions happening
+in the database:
+
+.. code-block:: python
+
+ try:
+ cur.execute("LOCK TABLE mytable IN ACCESS EXCLUSIVE MODE NOWAIT")
+ except psycopg.errors.LockNotAvailable:
+ locked = True
+
+The exception names are generated from the PostgreSQL source code and includes
+classes for every error defined by PostgreSQL in versions between 9.6 and 15.
+Every class in the module is named after what referred as "condition name" `in
+the documentation`__, converted to CamelCase: e.g. the error 22012,
+``division_by_zero`` is exposed by this module as the class `!DivisionByZero`.
+There is a handful of... exceptions to this rule, required for disambiguate
+name clashes: please refer to the :ref:`table below <exceptions-list>` for all
+the classes defined.
+
+.. __: https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
+
+Every exception class is a subclass of one of the :ref:`standard DB-API
+exception <dbapi-exceptions>`, thus exposing the `~psycopg.Error` interface.
+
+.. versionchanged:: 3.1.4
+ Added exceptions introduced in PostgreSQL 15.
+
+.. autofunction:: lookup
+
+ Example: if you have code using constant names or sql codes you can use
+ them to look up the exception class.
+
+ .. code-block:: python
+
+ try:
+ cur.execute("LOCK TABLE mytable IN ACCESS EXCLUSIVE MODE NOWAIT")
+ except psycopg.errors.lookup("UNDEFINED_TABLE"):
+ missing = True
+ except psycopg.errors.lookup("55P03"):
+ locked = True
+
+
+.. _exceptions-list:
+
+List of known exceptions
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+The following are all the SQLSTATE-related error classed defined by this
+module, together with the base DBAPI exception they derive from.
+
+.. autogenerated: start
+
+========= ================================================== ====================
+SQLSTATE Exception Base exception
+========= ================================================== ====================
+**Class 02** - No Data (this is also a warning class per the SQL standard)
+---------------------------------------------------------------------------------
+``02000`` `!NoData` `!DatabaseError`
+``02001`` `!NoAdditionalDynamicResultSetsReturned` `!DatabaseError`
+**Class 03** - SQL Statement Not Yet Complete
+---------------------------------------------------------------------------------
+``03000`` `!SqlStatementNotYetComplete` `!DatabaseError`
+**Class 08** - Connection Exception
+---------------------------------------------------------------------------------
+``08000`` `!ConnectionException` `!OperationalError`
+``08001`` `!SqlclientUnableToEstablishSqlconnection` `!OperationalError`
+``08003`` `!ConnectionDoesNotExist` `!OperationalError`
+``08004`` `!SqlserverRejectedEstablishmentOfSqlconnection` `!OperationalError`
+``08006`` `!ConnectionFailure` `!OperationalError`
+``08007`` `!TransactionResolutionUnknown` `!OperationalError`
+``08P01`` `!ProtocolViolation` `!OperationalError`
+**Class 09** - Triggered Action Exception
+---------------------------------------------------------------------------------
+``09000`` `!TriggeredActionException` `!DatabaseError`
+**Class 0A** - Feature Not Supported
+---------------------------------------------------------------------------------
+``0A000`` `!FeatureNotSupported` `!NotSupportedError`
+**Class 0B** - Invalid Transaction Initiation
+---------------------------------------------------------------------------------
+``0B000`` `!InvalidTransactionInitiation` `!DatabaseError`
+**Class 0F** - Locator Exception
+---------------------------------------------------------------------------------
+``0F000`` `!LocatorException` `!DatabaseError`
+``0F001`` `!InvalidLocatorSpecification` `!DatabaseError`
+**Class 0L** - Invalid Grantor
+---------------------------------------------------------------------------------
+``0L000`` `!InvalidGrantor` `!DatabaseError`
+``0LP01`` `!InvalidGrantOperation` `!DatabaseError`
+**Class 0P** - Invalid Role Specification
+---------------------------------------------------------------------------------
+``0P000`` `!InvalidRoleSpecification` `!DatabaseError`
+**Class 0Z** - Diagnostics Exception
+---------------------------------------------------------------------------------
+``0Z000`` `!DiagnosticsException` `!DatabaseError`
+``0Z002`` `!StackedDiagnosticsAccessedWithoutActiveHandler` `!DatabaseError`
+**Class 20** - Case Not Found
+---------------------------------------------------------------------------------
+``20000`` `!CaseNotFound` `!ProgrammingError`
+**Class 21** - Cardinality Violation
+---------------------------------------------------------------------------------
+``21000`` `!CardinalityViolation` `!ProgrammingError`
+**Class 22** - Data Exception
+---------------------------------------------------------------------------------
+``22000`` `!DataException` `!DataError`
+``22001`` `!StringDataRightTruncation` `!DataError`
+``22002`` `!NullValueNoIndicatorParameter` `!DataError`
+``22003`` `!NumericValueOutOfRange` `!DataError`
+``22004`` `!NullValueNotAllowed` `!DataError`
+``22005`` `!ErrorInAssignment` `!DataError`
+``22007`` `!InvalidDatetimeFormat` `!DataError`
+``22008`` `!DatetimeFieldOverflow` `!DataError`
+``22009`` `!InvalidTimeZoneDisplacementValue` `!DataError`
+``2200B`` `!EscapeCharacterConflict` `!DataError`
+``2200C`` `!InvalidUseOfEscapeCharacter` `!DataError`
+``2200D`` `!InvalidEscapeOctet` `!DataError`
+``2200F`` `!ZeroLengthCharacterString` `!DataError`
+``2200G`` `!MostSpecificTypeMismatch` `!DataError`
+``2200H`` `!SequenceGeneratorLimitExceeded` `!DataError`
+``2200L`` `!NotAnXmlDocument` `!DataError`
+``2200M`` `!InvalidXmlDocument` `!DataError`
+``2200N`` `!InvalidXmlContent` `!DataError`
+``2200S`` `!InvalidXmlComment` `!DataError`
+``2200T`` `!InvalidXmlProcessingInstruction` `!DataError`
+``22010`` `!InvalidIndicatorParameterValue` `!DataError`
+``22011`` `!SubstringError` `!DataError`
+``22012`` `!DivisionByZero` `!DataError`
+``22013`` `!InvalidPrecedingOrFollowingSize` `!DataError`
+``22014`` `!InvalidArgumentForNtileFunction` `!DataError`
+``22015`` `!IntervalFieldOverflow` `!DataError`
+``22016`` `!InvalidArgumentForNthValueFunction` `!DataError`
+``22018`` `!InvalidCharacterValueForCast` `!DataError`
+``22019`` `!InvalidEscapeCharacter` `!DataError`
+``2201B`` `!InvalidRegularExpression` `!DataError`
+``2201E`` `!InvalidArgumentForLogarithm` `!DataError`
+``2201F`` `!InvalidArgumentForPowerFunction` `!DataError`
+``2201G`` `!InvalidArgumentForWidthBucketFunction` `!DataError`
+``2201W`` `!InvalidRowCountInLimitClause` `!DataError`
+``2201X`` `!InvalidRowCountInResultOffsetClause` `!DataError`
+``22021`` `!CharacterNotInRepertoire` `!DataError`
+``22022`` `!IndicatorOverflow` `!DataError`
+``22023`` `!InvalidParameterValue` `!DataError`
+``22024`` `!UnterminatedCString` `!DataError`
+``22025`` `!InvalidEscapeSequence` `!DataError`
+``22026`` `!StringDataLengthMismatch` `!DataError`
+``22027`` `!TrimError` `!DataError`
+``2202E`` `!ArraySubscriptError` `!DataError`
+``2202G`` `!InvalidTablesampleRepeat` `!DataError`
+``2202H`` `!InvalidTablesampleArgument` `!DataError`
+``22030`` `!DuplicateJsonObjectKeyValue` `!DataError`
+``22031`` `!InvalidArgumentForSqlJsonDatetimeFunction` `!DataError`
+``22032`` `!InvalidJsonText` `!DataError`
+``22033`` `!InvalidSqlJsonSubscript` `!DataError`
+``22034`` `!MoreThanOneSqlJsonItem` `!DataError`
+``22035`` `!NoSqlJsonItem` `!DataError`
+``22036`` `!NonNumericSqlJsonItem` `!DataError`
+``22037`` `!NonUniqueKeysInAJsonObject` `!DataError`
+``22038`` `!SingletonSqlJsonItemRequired` `!DataError`
+``22039`` `!SqlJsonArrayNotFound` `!DataError`
+``2203A`` `!SqlJsonMemberNotFound` `!DataError`
+``2203B`` `!SqlJsonNumberNotFound` `!DataError`
+``2203C`` `!SqlJsonObjectNotFound` `!DataError`
+``2203D`` `!TooManyJsonArrayElements` `!DataError`
+``2203E`` `!TooManyJsonObjectMembers` `!DataError`
+``2203F`` `!SqlJsonScalarRequired` `!DataError`
+``2203G`` `!SqlJsonItemCannotBeCastToTargetType` `!DataError`
+``22P01`` `!FloatingPointException` `!DataError`
+``22P02`` `!InvalidTextRepresentation` `!DataError`
+``22P03`` `!InvalidBinaryRepresentation` `!DataError`
+``22P04`` `!BadCopyFileFormat` `!DataError`
+``22P05`` `!UntranslatableCharacter` `!DataError`
+``22P06`` `!NonstandardUseOfEscapeCharacter` `!DataError`
+**Class 23** - Integrity Constraint Violation
+---------------------------------------------------------------------------------
+``23000`` `!IntegrityConstraintViolation` `!IntegrityError`
+``23001`` `!RestrictViolation` `!IntegrityError`
+``23502`` `!NotNullViolation` `!IntegrityError`
+``23503`` `!ForeignKeyViolation` `!IntegrityError`
+``23505`` `!UniqueViolation` `!IntegrityError`
+``23514`` `!CheckViolation` `!IntegrityError`
+``23P01`` `!ExclusionViolation` `!IntegrityError`
+**Class 24** - Invalid Cursor State
+---------------------------------------------------------------------------------
+``24000`` `!InvalidCursorState` `!InternalError`
+**Class 25** - Invalid Transaction State
+---------------------------------------------------------------------------------
+``25000`` `!InvalidTransactionState` `!InternalError`
+``25001`` `!ActiveSqlTransaction` `!InternalError`
+``25002`` `!BranchTransactionAlreadyActive` `!InternalError`
+``25003`` `!InappropriateAccessModeForBranchTransaction` `!InternalError`
+``25004`` `!InappropriateIsolationLevelForBranchTransaction` `!InternalError`
+``25005`` `!NoActiveSqlTransactionForBranchTransaction` `!InternalError`
+``25006`` `!ReadOnlySqlTransaction` `!InternalError`
+``25007`` `!SchemaAndDataStatementMixingNotSupported` `!InternalError`
+``25008`` `!HeldCursorRequiresSameIsolationLevel` `!InternalError`
+``25P01`` `!NoActiveSqlTransaction` `!InternalError`
+``25P02`` `!InFailedSqlTransaction` `!InternalError`
+``25P03`` `!IdleInTransactionSessionTimeout` `!InternalError`
+**Class 26** - Invalid SQL Statement Name
+---------------------------------------------------------------------------------
+``26000`` `!InvalidSqlStatementName` `!ProgrammingError`
+**Class 27** - Triggered Data Change Violation
+---------------------------------------------------------------------------------
+``27000`` `!TriggeredDataChangeViolation` `!OperationalError`
+**Class 28** - Invalid Authorization Specification
+---------------------------------------------------------------------------------
+``28000`` `!InvalidAuthorizationSpecification` `!OperationalError`
+``28P01`` `!InvalidPassword` `!OperationalError`
+**Class 2B** - Dependent Privilege Descriptors Still Exist
+---------------------------------------------------------------------------------
+``2B000`` `!DependentPrivilegeDescriptorsStillExist` `!InternalError`
+``2BP01`` `!DependentObjectsStillExist` `!InternalError`
+**Class 2D** - Invalid Transaction Termination
+---------------------------------------------------------------------------------
+``2D000`` `!InvalidTransactionTermination` `!InternalError`
+**Class 2F** - SQL Routine Exception
+---------------------------------------------------------------------------------
+``2F000`` `!SqlRoutineException` `!OperationalError`
+``2F002`` `!ModifyingSqlDataNotPermitted` `!OperationalError`
+``2F003`` `!ProhibitedSqlStatementAttempted` `!OperationalError`
+``2F004`` `!ReadingSqlDataNotPermitted` `!OperationalError`
+``2F005`` `!FunctionExecutedNoReturnStatement` `!OperationalError`
+**Class 34** - Invalid Cursor Name
+---------------------------------------------------------------------------------
+``34000`` `!InvalidCursorName` `!ProgrammingError`
+**Class 38** - External Routine Exception
+---------------------------------------------------------------------------------
+``38000`` `!ExternalRoutineException` `!OperationalError`
+``38001`` `!ContainingSqlNotPermitted` `!OperationalError`
+``38002`` `!ModifyingSqlDataNotPermittedExt` `!OperationalError`
+``38003`` `!ProhibitedSqlStatementAttemptedExt` `!OperationalError`
+``38004`` `!ReadingSqlDataNotPermittedExt` `!OperationalError`
+**Class 39** - External Routine Invocation Exception
+---------------------------------------------------------------------------------
+``39000`` `!ExternalRoutineInvocationException` `!OperationalError`
+``39001`` `!InvalidSqlstateReturned` `!OperationalError`
+``39004`` `!NullValueNotAllowedExt` `!OperationalError`
+``39P01`` `!TriggerProtocolViolated` `!OperationalError`
+``39P02`` `!SrfProtocolViolated` `!OperationalError`
+``39P03`` `!EventTriggerProtocolViolated` `!OperationalError`
+**Class 3B** - Savepoint Exception
+---------------------------------------------------------------------------------
+``3B000`` `!SavepointException` `!OperationalError`
+``3B001`` `!InvalidSavepointSpecification` `!OperationalError`
+**Class 3D** - Invalid Catalog Name
+---------------------------------------------------------------------------------
+``3D000`` `!InvalidCatalogName` `!ProgrammingError`
+**Class 3F** - Invalid Schema Name
+---------------------------------------------------------------------------------
+``3F000`` `!InvalidSchemaName` `!ProgrammingError`
+**Class 40** - Transaction Rollback
+---------------------------------------------------------------------------------
+``40000`` `!TransactionRollback` `!OperationalError`
+``40001`` `!SerializationFailure` `!OperationalError`
+``40002`` `!TransactionIntegrityConstraintViolation` `!OperationalError`
+``40003`` `!StatementCompletionUnknown` `!OperationalError`
+``40P01`` `!DeadlockDetected` `!OperationalError`
+**Class 42** - Syntax Error or Access Rule Violation
+---------------------------------------------------------------------------------
+``42000`` `!SyntaxErrorOrAccessRuleViolation` `!ProgrammingError`
+``42501`` `!InsufficientPrivilege` `!ProgrammingError`
+``42601`` `!SyntaxError` `!ProgrammingError`
+``42602`` `!InvalidName` `!ProgrammingError`
+``42611`` `!InvalidColumnDefinition` `!ProgrammingError`
+``42622`` `!NameTooLong` `!ProgrammingError`
+``42701`` `!DuplicateColumn` `!ProgrammingError`
+``42702`` `!AmbiguousColumn` `!ProgrammingError`
+``42703`` `!UndefinedColumn` `!ProgrammingError`
+``42704`` `!UndefinedObject` `!ProgrammingError`
+``42710`` `!DuplicateObject` `!ProgrammingError`
+``42712`` `!DuplicateAlias` `!ProgrammingError`
+``42723`` `!DuplicateFunction` `!ProgrammingError`
+``42725`` `!AmbiguousFunction` `!ProgrammingError`
+``42803`` `!GroupingError` `!ProgrammingError`
+``42804`` `!DatatypeMismatch` `!ProgrammingError`
+``42809`` `!WrongObjectType` `!ProgrammingError`
+``42830`` `!InvalidForeignKey` `!ProgrammingError`
+``42846`` `!CannotCoerce` `!ProgrammingError`
+``42883`` `!UndefinedFunction` `!ProgrammingError`
+``428C9`` `!GeneratedAlways` `!ProgrammingError`
+``42939`` `!ReservedName` `!ProgrammingError`
+``42P01`` `!UndefinedTable` `!ProgrammingError`
+``42P02`` `!UndefinedParameter` `!ProgrammingError`
+``42P03`` `!DuplicateCursor` `!ProgrammingError`
+``42P04`` `!DuplicateDatabase` `!ProgrammingError`
+``42P05`` `!DuplicatePreparedStatement` `!ProgrammingError`
+``42P06`` `!DuplicateSchema` `!ProgrammingError`
+``42P07`` `!DuplicateTable` `!ProgrammingError`
+``42P08`` `!AmbiguousParameter` `!ProgrammingError`
+``42P09`` `!AmbiguousAlias` `!ProgrammingError`
+``42P10`` `!InvalidColumnReference` `!ProgrammingError`
+``42P11`` `!InvalidCursorDefinition` `!ProgrammingError`
+``42P12`` `!InvalidDatabaseDefinition` `!ProgrammingError`
+``42P13`` `!InvalidFunctionDefinition` `!ProgrammingError`
+``42P14`` `!InvalidPreparedStatementDefinition` `!ProgrammingError`
+``42P15`` `!InvalidSchemaDefinition` `!ProgrammingError`
+``42P16`` `!InvalidTableDefinition` `!ProgrammingError`
+``42P17`` `!InvalidObjectDefinition` `!ProgrammingError`
+``42P18`` `!IndeterminateDatatype` `!ProgrammingError`
+``42P19`` `!InvalidRecursion` `!ProgrammingError`
+``42P20`` `!WindowingError` `!ProgrammingError`
+``42P21`` `!CollationMismatch` `!ProgrammingError`
+``42P22`` `!IndeterminateCollation` `!ProgrammingError`
+**Class 44** - WITH CHECK OPTION Violation
+---------------------------------------------------------------------------------
+``44000`` `!WithCheckOptionViolation` `!ProgrammingError`
+**Class 53** - Insufficient Resources
+---------------------------------------------------------------------------------
+``53000`` `!InsufficientResources` `!OperationalError`
+``53100`` `!DiskFull` `!OperationalError`
+``53200`` `!OutOfMemory` `!OperationalError`
+``53300`` `!TooManyConnections` `!OperationalError`
+``53400`` `!ConfigurationLimitExceeded` `!OperationalError`
+**Class 54** - Program Limit Exceeded
+---------------------------------------------------------------------------------
+``54000`` `!ProgramLimitExceeded` `!OperationalError`
+``54001`` `!StatementTooComplex` `!OperationalError`
+``54011`` `!TooManyColumns` `!OperationalError`
+``54023`` `!TooManyArguments` `!OperationalError`
+**Class 55** - Object Not In Prerequisite State
+---------------------------------------------------------------------------------
+``55000`` `!ObjectNotInPrerequisiteState` `!OperationalError`
+``55006`` `!ObjectInUse` `!OperationalError`
+``55P02`` `!CantChangeRuntimeParam` `!OperationalError`
+``55P03`` `!LockNotAvailable` `!OperationalError`
+``55P04`` `!UnsafeNewEnumValueUsage` `!OperationalError`
+**Class 57** - Operator Intervention
+---------------------------------------------------------------------------------
+``57000`` `!OperatorIntervention` `!OperationalError`
+``57014`` `!QueryCanceled` `!OperationalError`
+``57P01`` `!AdminShutdown` `!OperationalError`
+``57P02`` `!CrashShutdown` `!OperationalError`
+``57P03`` `!CannotConnectNow` `!OperationalError`
+``57P04`` `!DatabaseDropped` `!OperationalError`
+``57P05`` `!IdleSessionTimeout` `!OperationalError`
+**Class 58** - System Error (errors external to PostgreSQL itself)
+---------------------------------------------------------------------------------
+``58000`` `!SystemError` `!OperationalError`
+``58030`` `!IoError` `!OperationalError`
+``58P01`` `!UndefinedFile` `!OperationalError`
+``58P02`` `!DuplicateFile` `!OperationalError`
+**Class 72** - Snapshot Failure
+---------------------------------------------------------------------------------
+``72000`` `!SnapshotTooOld` `!DatabaseError`
+**Class F0** - Configuration File Error
+---------------------------------------------------------------------------------
+``F0000`` `!ConfigFileError` `!OperationalError`
+``F0001`` `!LockFileExists` `!OperationalError`
+**Class HV** - Foreign Data Wrapper Error (SQL/MED)
+---------------------------------------------------------------------------------
+``HV000`` `!FdwError` `!OperationalError`
+``HV001`` `!FdwOutOfMemory` `!OperationalError`
+``HV002`` `!FdwDynamicParameterValueNeeded` `!OperationalError`
+``HV004`` `!FdwInvalidDataType` `!OperationalError`
+``HV005`` `!FdwColumnNameNotFound` `!OperationalError`
+``HV006`` `!FdwInvalidDataTypeDescriptors` `!OperationalError`
+``HV007`` `!FdwInvalidColumnName` `!OperationalError`
+``HV008`` `!FdwInvalidColumnNumber` `!OperationalError`
+``HV009`` `!FdwInvalidUseOfNullPointer` `!OperationalError`
+``HV00A`` `!FdwInvalidStringFormat` `!OperationalError`
+``HV00B`` `!FdwInvalidHandle` `!OperationalError`
+``HV00C`` `!FdwInvalidOptionIndex` `!OperationalError`
+``HV00D`` `!FdwInvalidOptionName` `!OperationalError`
+``HV00J`` `!FdwOptionNameNotFound` `!OperationalError`
+``HV00K`` `!FdwReplyHandle` `!OperationalError`
+``HV00L`` `!FdwUnableToCreateExecution` `!OperationalError`
+``HV00M`` `!FdwUnableToCreateReply` `!OperationalError`
+``HV00N`` `!FdwUnableToEstablishConnection` `!OperationalError`
+``HV00P`` `!FdwNoSchemas` `!OperationalError`
+``HV00Q`` `!FdwSchemaNotFound` `!OperationalError`
+``HV00R`` `!FdwTableNotFound` `!OperationalError`
+``HV010`` `!FdwFunctionSequenceError` `!OperationalError`
+``HV014`` `!FdwTooManyHandles` `!OperationalError`
+``HV021`` `!FdwInconsistentDescriptorInformation` `!OperationalError`
+``HV024`` `!FdwInvalidAttributeValue` `!OperationalError`
+``HV090`` `!FdwInvalidStringLengthOrBufferLength` `!OperationalError`
+``HV091`` `!FdwInvalidDescriptorFieldIdentifier` `!OperationalError`
+**Class P0** - PL/pgSQL Error
+---------------------------------------------------------------------------------
+``P0000`` `!PlpgsqlError` `!ProgrammingError`
+``P0001`` `!RaiseException` `!ProgrammingError`
+``P0002`` `!NoDataFound` `!ProgrammingError`
+``P0003`` `!TooManyRows` `!ProgrammingError`
+``P0004`` `!AssertFailure` `!ProgrammingError`
+**Class XX** - Internal Error
+---------------------------------------------------------------------------------
+``XX000`` `!InternalError_` `!InternalError`
+``XX001`` `!DataCorrupted` `!InternalError`
+``XX002`` `!IndexCorrupted` `!InternalError`
+========= ================================================== ====================
+
+.. autogenerated: end
+
+.. versionadded:: 3.1.4
+ Exception `!SqlJsonItemCannotBeCastToTargetType`, introduced in PostgreSQL
+ 15.
diff --git a/docs/api/index.rst b/docs/api/index.rst
new file mode 100644
index 0000000..b99550d
--- /dev/null
+++ b/docs/api/index.rst
@@ -0,0 +1,29 @@
+Psycopg 3 API
+=============
+
+.. _api:
+
+This sections is a reference for all the public objects exposed by the
+`psycopg` module. For a more conceptual description you can take a look at
+:ref:`basic` and :ref:`advanced`.
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+ module
+ connections
+ cursors
+ copy
+ objects
+ sql
+ rows
+ errors
+ pool
+ conninfo
+ adapt
+ types
+ abc
+ pq
+ crdb
+ dns
diff --git a/docs/api/module.rst b/docs/api/module.rst
new file mode 100644
index 0000000..3c3d3c4
--- /dev/null
+++ b/docs/api/module.rst
@@ -0,0 +1,59 @@
+The `!psycopg` module
+=====================
+
+Psycopg implements the `Python Database DB API 2.0 specification`__. As such
+it also exposes the `module-level objects`__ required by the specifications.
+
+.. __: https://www.python.org/dev/peps/pep-0249/
+.. __: https://www.python.org/dev/peps/pep-0249/#module-interface
+
+.. module:: psycopg
+
+.. autofunction:: connect
+
+ This is an alias of the class method `Connection.connect`: see its
+ documentation for details.
+
+ If you need an asynchronous connection use `AsyncConnection.connect`
+ instead.
+
+
+.. rubric:: Exceptions
+
+The standard `DBAPI exceptions`__ are exposed both by the `!psycopg` module
+and by the `psycopg.errors` module. The latter also exposes more specific
+exceptions, mapping to the database error states (see
+:ref:`sqlstate-exceptions`).
+
+.. __: https://www.python.org/dev/peps/pep-0249/#exceptions
+
+.. parsed-literal::
+
+ `!Exception`
+ \|__ `Warning`
+ \|__ `Error`
+ \|__ `InterfaceError`
+ \|__ `DatabaseError`
+ \|__ `DataError`
+ \|__ `OperationalError`
+ \|__ `IntegrityError`
+ \|__ `InternalError`
+ \|__ `ProgrammingError`
+ \|__ `NotSupportedError`
+
+
+.. data:: adapters
+
+ The default adapters map establishing how Python and PostgreSQL types are
+ converted into each other.
+
+ This map is used as a template when new connections are created, using
+ `psycopg.connect()`. Its `~psycopg.adapt.AdaptersMap.types` attribute is a
+ `~psycopg.types.TypesRegistry` containing information about every
+ PostgreSQL builtin type, useful for adaptation customisation (see
+ :ref:`adaptation`)::
+
+ >>> psycopg.adapters.types["int4"]
+ <TypeInfo: int4 (oid: 23, array oid: 1007)>
+
+ :type: `~psycopg.adapt.AdaptersMap`
diff --git a/docs/api/objects.rst b/docs/api/objects.rst
new file mode 100644
index 0000000..f085ed9
--- /dev/null
+++ b/docs/api/objects.rst
@@ -0,0 +1,256 @@
+.. currentmodule:: psycopg
+
+Other top-level objects
+=======================
+
+Connection information
+----------------------
+
+.. autoclass:: ConnectionInfo()
+
+ The object is usually returned by `Connection.info`.
+
+ .. autoattribute:: dsn
+
+ .. note:: The `get_parameters()` method returns the same information
+ as a dict.
+
+ .. autoattribute:: status
+
+ The status can be one of a number of values. However, only two of
+ these are seen outside of an asynchronous connection procedure:
+ `~pq.ConnStatus.OK` and `~pq.ConnStatus.BAD`. A good connection to the
+ database has the status `!OK`. Ordinarily, an `!OK` status will remain
+ so until `Connection.close()`, but a communications failure might
+ result in the status changing to `!BAD` prematurely.
+
+ .. autoattribute:: transaction_status
+
+ The status can be `~pq.TransactionStatus.IDLE` (currently idle),
+ `~pq.TransactionStatus.ACTIVE` (a command is in progress),
+ `~pq.TransactionStatus.INTRANS` (idle, in a valid transaction block),
+ or `~pq.TransactionStatus.INERROR` (idle, in a failed transaction
+ block). `~pq.TransactionStatus.UNKNOWN` is reported if the connection
+ is bad. `!ACTIVE` is reported only when a query has been sent to the
+ server and not yet completed.
+
+ .. autoattribute:: pipeline_status
+
+ .. autoattribute:: backend_pid
+ .. autoattribute:: vendor
+
+ Normally it is `PostgreSQL`; it may be different if connected to
+ a different database.
+
+ .. versionadded:: 3.1
+
+ .. autoattribute:: server_version
+
+ The number is formed by converting the major, minor, and revision
+ numbers into two-decimal-digit numbers and appending them together.
+ Starting from PostgreSQL 10 the minor version was dropped, so the
+ second group of digits is always 00. For example, version 9.3.5 is
+ returned as 90305, version 10.2 as 100002.
+
+ .. autoattribute:: error_message
+
+ .. automethod:: get_parameters
+
+ .. note:: The `dsn` attribute returns the same information in the form
+ as a string.
+
+ .. autoattribute:: timezone
+
+ .. code:: pycon
+
+ >>> conn.info.timezone
+ zoneinfo.ZoneInfo(key='Europe/Rome')
+
+ .. autoattribute:: host
+
+ This can be a host name, an IP address, or a directory path if the
+ connection is via Unix socket. (The path case can be distinguished
+ because it will always be an absolute path, beginning with ``/``.)
+
+ .. autoattribute:: hostaddr
+
+ Only available if the libpq used is at least from PostgreSQL 12.
+ Raise `~psycopg.NotSupportedError` otherwise.
+
+ .. autoattribute:: port
+ .. autoattribute:: dbname
+ .. autoattribute:: user
+ .. autoattribute:: password
+ .. autoattribute:: options
+ .. automethod:: parameter_status
+
+ Example of parameters are ``server_version``,
+ ``standard_conforming_strings``... See :pq:`PQparameterStatus()` for
+ all the available parameters.
+
+ .. autoattribute:: encoding
+
+ The value returned is always normalized to the Python codec
+ `~codecs.CodecInfo.name`::
+
+ conn.execute("SET client_encoding TO LATIN9")
+ conn.info.encoding
+ 'iso8859-15'
+
+ A few PostgreSQL encodings are not available in Python and cannot be
+ selected (currently ``EUC_TW``, ``MULE_INTERNAL``). The PostgreSQL
+ ``SQL_ASCII`` encoding has the special meaning of "no encoding": see
+ :ref:`adapt-string` for details.
+
+ .. seealso::
+
+ The `PostgreSQL supported encodings`__.
+
+ .. __: https://www.postgresql.org/docs/current/multibyte.html
+
+
+The description `Column` object
+-------------------------------
+
+.. autoclass:: Column()
+
+ An object describing a column of data from a database result, `as described
+ by the DBAPI`__, so it can also be unpacked as a 7-items tuple.
+
+ The object is returned by `Cursor.description`.
+
+ .. __: https://www.python.org/dev/peps/pep-0249/#description
+
+ .. autoattribute:: name
+ .. autoattribute:: type_code
+ .. autoattribute:: display_size
+ .. autoattribute:: internal_size
+ .. autoattribute:: precision
+ .. autoattribute:: scale
+
+
+Notifications
+-------------
+
+.. autoclass:: Notify()
+
+ The object is usually returned by `Connection.notifies()`.
+
+ .. attribute:: channel
+ :type: str
+
+ The name of the channel on which the notification was received.
+
+ .. attribute:: payload
+ :type: str
+
+ The message attached to the notification.
+
+ .. attribute:: pid
+ :type: int
+
+ The PID of the backend process which sent the notification.
+
+
+Pipeline-related objects
+------------------------
+
+See :ref:`pipeline-mode` for details.
+
+.. autoclass:: Pipeline
+
+ This objects is returned by `Connection.pipeline()`.
+
+ .. automethod:: sync
+ .. automethod:: is_supported
+
+
+.. autoclass:: AsyncPipeline
+
+ This objects is returned by `AsyncConnection.pipeline()`.
+
+ .. automethod:: sync
+
+
+Transaction-related objects
+---------------------------
+
+See :ref:`transactions` for details about these objects.
+
+.. autoclass:: IsolationLevel
+ :members:
+
+ The value is usually used with the `Connection.isolation_level` property.
+
+ Check the PostgreSQL documentation for a description of the effects of the
+ different `levels of transaction isolation`__.
+
+ .. __: https://www.postgresql.org/docs/current/transaction-iso.html
+
+
+.. autoclass:: Transaction()
+
+ .. autoattribute:: savepoint_name
+ .. autoattribute:: connection
+
+
+.. autoclass:: AsyncTransaction()
+
+ .. autoattribute:: connection
+
+
+.. autoexception:: Rollback
+
+ It can be used as:
+
+ - ``raise Rollback``: roll back the operation that happened in the current
+ transaction block and continue the program after the block.
+
+ - ``raise Rollback()``: same effect as above
+
+ - :samp:`raise Rollback({tx})`: roll back any operation that happened in
+ the `Transaction` `!tx` (returned by a statement such as :samp:`with
+ conn.transaction() as {tx}:` and all the blocks nested within. The
+ program will continue after the `!tx` block.
+
+
+Two-Phase Commit related objects
+--------------------------------
+
+.. autoclass:: Xid()
+
+ See :ref:`two-phase-commit` for details.
+
+ .. autoattribute:: format_id
+
+ Format Identifier of the two-phase transaction.
+
+ .. autoattribute:: gtrid
+
+ Global Transaction Identifier of the two-phase transaction.
+
+ If the Xid doesn't follow the XA standard, it will be the PostgreSQL
+ ID of the transaction (in which case `format_id` and `bqual` will be
+ `!None`).
+
+ .. autoattribute:: bqual
+
+ Branch Qualifier of the two-phase transaction.
+
+ .. autoattribute:: prepared
+
+ Timestamp at which the transaction was prepared for commit.
+
+ Only available on transactions recovered by `~Connection.tpc_recover()`.
+
+ .. autoattribute:: owner
+
+ Named of the user that executed the transaction.
+
+ Only available on recovered transactions.
+
+ .. autoattribute:: database
+
+ Named of the database in which the transaction was executed.
+
+ Only available on recovered transactions.
diff --git a/docs/api/pool.rst b/docs/api/pool.rst
new file mode 100644
index 0000000..76ccc74
--- /dev/null
+++ b/docs/api/pool.rst
@@ -0,0 +1,331 @@
+`!psycopg_pool` -- Connection pool implementations
+==================================================
+
+.. index::
+ double: Connection; Pool
+
+.. module:: psycopg_pool
+
+A connection pool is an object used to create and maintain a limited amount of
+PostgreSQL connections, reducing the time requested by the program to obtain a
+working connection and allowing an arbitrary large number of concurrent
+threads or tasks to use a controlled amount of resources on the server. See
+:ref:`connection-pools` for more details and usage pattern.
+
+This package exposes a few connection pool classes:
+
+- `ConnectionPool` is a synchronous connection pool yielding
+ `~psycopg.Connection` objects and can be used by multithread applications.
+
+- `AsyncConnectionPool` has an interface similar to `!ConnectionPool`, but
+ with `asyncio` functions replacing blocking functions, and yields
+ `~psycopg.AsyncConnection` instances.
+
+- `NullConnectionPool` is a `!ConnectionPool` subclass exposing the same
+ interface of its parent, but not keeping any unused connection in its state.
+ See :ref:`null-pool` for details about related use cases.
+
+- `AsyncNullConnectionPool` has the same behaviour of the
+ `!NullConnectionPool`, but with the same async interface of the
+ `!AsyncConnectionPool`.
+
+.. note:: The `!psycopg_pool` package is distributed separately from the main
+ `psycopg` package: use ``pip install "psycopg[pool]"``, or ``pip install
+ psycopg_pool``, to make it available. See :ref:`pool-installation`.
+
+ The version numbers indicated in this page refer to the `!psycopg_pool`
+ package, not to `psycopg`.
+
+
+The `!ConnectionPool` class
+---------------------------
+
+.. autoclass:: ConnectionPool
+
+ This class implements a connection pool serving `~psycopg.Connection`
+ instances (or subclasses). The constructor has *alot* of arguments, but
+ only `!conninfo` and `!min_size` are the fundamental ones, all the other
+ arguments have meaningful defaults and can probably be tweaked later, if
+ required.
+
+ :param conninfo: The connection string. See
+ `~psycopg.Connection.connect()` for details.
+ :type conninfo: `!str`
+
+ :param min_size: The minimum number of connection the pool will hold. The
+ pool will actively try to create new connections if some
+ are lost (closed, broken) and will try to never go below
+ `!min_size`.
+ :type min_size: `!int`, default: 4
+
+ :param max_size: The maximum number of connections the pool will hold. If
+ `!None`, or equal to `!min_size`, the pool will not grow or
+ shrink. If larger than `!min_size`, the pool can grow if
+ more than `!min_size` connections are requested at the same
+ time and will shrink back after the extra connections have
+ been unused for more than `!max_idle` seconds.
+ :type max_size: `!int`, default: `!None`
+
+ :param kwargs: Extra arguments to pass to `!connect()`. Note that this is
+ *one dict argument* of the pool constructor, which is
+ expanded as `connect()` keyword parameters.
+
+ :type kwargs: `!dict`
+
+ :param connection_class: The class of the connections to serve. It should
+ be a `!Connection` subclass.
+ :type connection_class: `!type`, default: `~psycopg.Connection`
+
+ :param open: If `!True`, open the pool, creating the required connections,
+ on init. If `!False`, open the pool when `!open()` is called or
+ when the pool context is entered. See the `open()` method
+ documentation for more details.
+ :type open: `!bool`, default: `!True`
+
+ :param configure: A callback to configure a connection after creation.
+ Useful, for instance, to configure its adapters. If the
+ connection is used to run internal queries (to inspect the
+ database) make sure to close an eventual transaction
+ before leaving the function.
+ :type configure: `Callable[[Connection], None]`
+
+ :param reset: A callback to reset a function after it has been returned to
+ the pool. The connection is guaranteed to be passed to the
+ `!reset()` function in "idle" state (no transaction). When
+ leaving the `!reset()` function the connection must be left in
+ *idle* state, otherwise it is discarded.
+ :type reset: `Callable[[Connection], None]`
+
+ :param name: An optional name to give to the pool, useful, for instance, to
+ identify it in the logs if more than one pool is used. if not
+ specified pick a sequential name such as ``pool-1``,
+ ``pool-2``, etc.
+ :type name: `!str`
+
+ :param timeout: The default maximum time in seconds that a client can wait
+ to receive a connection from the pool (using `connection()`
+ or `getconn()`). Note that these methods allow to override
+ the `!timeout` default.
+ :type timeout: `!float`, default: 30 seconds
+
+ :param max_waiting: Maximum number of requests that can be queued to the
+ pool, after which new requests will fail, raising
+ `TooManyRequests`. 0 means no queue limit.
+ :type max_waiting: `!int`, default: 0
+
+ :param max_lifetime: The maximum lifetime of a connection in the pool, in
+ seconds. Connections used for longer get closed and
+ replaced by a new one. The amount is reduced by a
+ random 10% to avoid mass eviction.
+ :type max_lifetime: `!float`, default: 1 hour
+
+ :param max_idle: Maximum time, in seconds, that a connection can stay unused
+ in the pool before being closed, and the pool shrunk. This
+ only happens to connections more than `!min_size`, if
+ `!max_size` allowed the pool to grow.
+ :type max_idle: `!float`, default: 10 minutes
+
+ :param reconnect_timeout: Maximum time, in seconds, the pool will try to
+ create a connection. If a connection attempt
+ fails, the pool will try to reconnect a few
+ times, using an exponential backoff and some
+ random factor to avoid mass attempts. If repeated
+ attempts fail, after `!reconnect_timeout` second
+ the connection attempt is aborted and the
+ `!reconnect_failed()` callback invoked.
+ :type reconnect_timeout: `!float`, default: 5 minutes
+
+ :param reconnect_failed: Callback invoked if an attempt to create a new
+ connection fails for more than `!reconnect_timeout`
+ seconds. The user may decide, for instance, to
+ terminate the program (executing `sys.exit()`).
+ By default don't do anything: restart a new
+ connection attempt (if the number of connection
+ fell below `!min_size`).
+ :type reconnect_failed: ``Callable[[ConnectionPool], None]``
+
+ :param num_workers: Number of background worker threads used to maintain the
+ pool state. Background workers are used for example to
+ create new connections and to clean up connections when
+ they are returned to the pool.
+ :type num_workers: `!int`, default: 3
+
+ .. versionchanged:: 3.1
+
+ added `!open` parameter to init method.
+
+ .. note:: In a future version, the default value for the `!open` parameter
+ might be changed to `!False`. If you rely on this behaviour (e.g. if
+ you don't use the pool as a context manager) you might want to specify
+ this parameter explicitly.
+
+ .. automethod:: connection
+
+ .. code:: python
+
+ with my_pool.connection() as conn:
+ conn.execute(...)
+
+ # the connection is now back in the pool
+
+ .. automethod:: open
+
+ .. versionadded:: 3.1
+
+
+ .. automethod:: close
+
+ .. note::
+
+ The pool can be also used as a context manager, in which case it will
+ be opened (if necessary) on entering the block and closed on exiting it:
+
+ .. code:: python
+
+ with ConnectionPool(...) as pool:
+ # code using the pool
+
+ .. automethod:: wait
+
+ .. attribute:: name
+ :type: str
+
+ The name of the pool set on creation, or automatically generated if not
+ set.
+
+ .. autoattribute:: min_size
+ .. autoattribute:: max_size
+
+ The current minimum and maximum size of the pool. Use `resize()` to
+ change them at runtime.
+
+ .. automethod:: resize
+ .. automethod:: check
+ .. automethod:: get_stats
+ .. automethod:: pop_stats
+
+ See :ref:`pool-stats` for the metrics returned.
+
+ .. rubric:: Functionalities you may not need
+
+ .. automethod:: getconn
+ .. automethod:: putconn
+
+
+Pool exceptions
+---------------
+
+.. autoclass:: PoolTimeout()
+
+ Subclass of `~psycopg.OperationalError`
+
+.. autoclass:: PoolClosed()
+
+ Subclass of `~psycopg.OperationalError`
+
+.. autoclass:: TooManyRequests()
+
+ Subclass of `~psycopg.OperationalError`
+
+
+The `!AsyncConnectionPool` class
+--------------------------------
+
+`!AsyncConnectionPool` has a very similar interface to the `ConnectionPool`
+class but its blocking methods are implemented as `!async` coroutines. It
+returns instances of `~psycopg.AsyncConnection`, or of its subclass if
+specified so in the `!connection_class` parameter.
+
+Only the functions with different signature from `!ConnectionPool` are
+listed here.
+
+.. autoclass:: AsyncConnectionPool
+
+ :param connection_class: The class of the connections to serve. It should
+ be an `!AsyncConnection` subclass.
+ :type connection_class: `!type`, default: `~psycopg.AsyncConnection`
+
+ :param configure: A callback to configure a connection after creation.
+ :type configure: `async Callable[[AsyncConnection], None]`
+
+ :param reset: A callback to reset a function after it has been returned to
+ the pool.
+ :type reset: `async Callable[[AsyncConnection], None]`
+
+ .. automethod:: connection
+
+ .. code:: python
+
+ async with my_pool.connection() as conn:
+ await conn.execute(...)
+
+ # the connection is now back in the pool
+
+ .. automethod:: open
+ .. automethod:: close
+
+ .. note::
+
+ The pool can be also used as an async context manager, in which case it
+ will be opened (if necessary) on entering the block and closed on
+ exiting it:
+
+ .. code:: python
+
+ async with AsyncConnectionPool(...) as pool:
+ # code using the pool
+
+ All the other constructor parameters are the same of `!ConnectionPool`.
+
+ .. automethod:: wait
+ .. automethod:: resize
+ .. automethod:: check
+ .. automethod:: getconn
+ .. automethod:: putconn
+
+
+Null connection pools
+---------------------
+
+.. versionadded:: 3.1
+
+The `NullConnectionPool` is a `ConnectionPool` subclass which doesn't create
+connections preemptively and doesn't keep unused connections in its state. See
+:ref:`null-pool` for further details.
+
+The interface of the object is entirely compatible with its parent class. Its
+behaviour is similar, with the following differences:
+
+.. autoclass:: NullConnectionPool
+
+ All the other constructor parameters are the same as in `ConnectionPool`.
+
+ :param min_size: Always 0, cannot be changed.
+ :type min_size: `!int`, default: 0
+
+ :param max_size: If None or 0, create a new connection at every request,
+ without a maximum. If greater than 0, don't create more
+ than `!max_size` connections and queue the waiting clients.
+ :type max_size: `!int`, default: None
+
+ :param reset: It is only called when there are waiting clients in the
+ queue, before giving them a connection already open. If no
+ client is waiting, the connection is closed and discarded
+ without a fuss.
+ :type reset: `Callable[[Connection], None]`
+
+ :param max_idle: Ignored, as null pools don't leave idle connections
+ sitting around.
+
+ .. automethod:: wait
+ .. automethod:: resize
+ .. automethod:: check
+
+
+The `AsyncNullConnectionPool` is, similarly, an `AsyncConnectionPool` subclass
+with the same behaviour of the `NullConnectionPool`.
+
+.. autoclass:: AsyncNullConnectionPool
+
+ The interface is the same of its parent class `AsyncConnectionPool`. The
+ behaviour is different in the same way described for `NullConnectionPool`.
diff --git a/docs/api/pq.rst b/docs/api/pq.rst
new file mode 100644
index 0000000..3d9c033
--- /dev/null
+++ b/docs/api/pq.rst
@@ -0,0 +1,218 @@
+.. _psycopg.pq:
+
+`pq` -- libpq wrapper module
+============================
+
+.. index::
+ single: libpq
+
+.. module:: psycopg.pq
+
+Psycopg is built around the libpq_, the PostgreSQL client library, which
+performs most of the network communications and returns query results in C
+structures.
+
+.. _libpq: https://www.postgresql.org/docs/current/libpq.html
+
+The low-level functions of the library are exposed by the objects in the
+`!psycopg.pq` module.
+
+
+.. _pq-impl:
+
+``pq`` module implementations
+-----------------------------
+
+There are actually several implementations of the module, all offering the
+same interface. Current implementations are:
+
+- ``python``: a pure-python implementation, implemented using the `ctypes`
+ module. It is less performing than the others, but it doesn't need a C
+ compiler to install. It requires the libpq installed in the system.
+
+- ``c``: a C implementation of the libpq wrapper (more precisely, implemented
+ in Cython_). It is much better performing than the ``python``
+ implementation, however it requires development packages installed on the
+ client machine. It can be installed using the ``c`` extra, i.e. running
+ ``pip install "psycopg[c]"``.
+
+- ``binary``: a pre-compiled C implementation, bundled with all the required
+ libraries. It is the easiest option to deal with, fast to install and it
+ should require no development tool or client library, however it may be not
+ available for every platform. You can install it using the ``binary`` extra,
+ i.e. running ``pip install "psycopg[binary]"``.
+
+.. _Cython: https://cython.org/
+
+The implementation currently used is available in the `~psycopg.pq.__impl__`
+module constant.
+
+At import time, Psycopg 3 will try to use the best implementation available
+and will fail if none is usable. You can force the use of a specific
+implementation by exporting the env var :envvar:`PSYCOPG_IMPL`: importing the
+library will fail if the requested implementation is not available::
+
+ $ PSYCOPG_IMPL=c python -c "import psycopg"
+ Traceback (most recent call last):
+ ...
+ ImportError: couldn't import requested psycopg 'c' implementation: No module named 'psycopg_c'
+
+
+Module content
+--------------
+
+.. autodata:: __impl__
+
+ The choice of implementation is automatic but can be forced setting the
+ :envvar:`PSYCOPG_IMPL` env var.
+
+
+.. autofunction:: version
+
+ .. seealso:: the :pq:`PQlibVersion()` function
+
+
+.. autodata:: __build_version__
+
+.. autofunction:: error_message
+
+
+Objects wrapping libpq structures and functions
+-----------------------------------------------
+
+.. admonition:: TODO
+
+ finish documentation
+
+.. autoclass:: PGconn()
+
+ .. autoattribute:: pgconn_ptr
+ .. automethod:: get_cancel
+ .. autoattribute:: needs_password
+ .. autoattribute:: used_password
+
+ .. automethod:: encrypt_password
+
+ .. code:: python
+
+ >>> enc = conn.info.encoding
+ >>> encrypted = conn.pgconn.encrypt_password(password.encode(enc), rolename.encode(enc))
+ b'SCRAM-SHA-256$4096:...
+
+ .. automethod:: trace
+ .. automethod:: set_trace_flags
+ .. automethod:: untrace
+
+ .. code:: python
+
+ >>> conn.pgconn.trace(sys.stderr.fileno())
+ >>> conn.pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ >>> conn.execute("select now()")
+ F 13 Parse "" "BEGIN" 0
+ F 14 Bind "" "" 0 0 1 0
+ F 6 Describe P ""
+ F 9 Execute "" 0
+ F 4 Sync
+ B 4 ParseComplete
+ B 4 BindComplete
+ B 4 NoData
+ B 10 CommandComplete "BEGIN"
+ B 5 ReadyForQuery T
+ F 17 Query "select now()"
+ B 28 RowDescription 1 "now" NNNN 0 NNNN 8 -1 0
+ B 39 DataRow 1 29 '2022-09-14 14:12:16.648035+02'
+ B 13 CommandComplete "SELECT 1"
+ B 5 ReadyForQuery T
+ <psycopg.Cursor [TUPLES_OK] [INTRANS] (database=postgres) at 0x7f18a18ba040>
+ >>> conn.pgconn.untrace()
+
+
+.. autoclass:: PGresult()
+
+ .. autoattribute:: pgresult_ptr
+
+
+.. autoclass:: Conninfo
+.. autoclass:: Escaping
+
+.. autoclass:: PGcancel()
+ :members:
+
+
+Enumerations
+------------
+
+.. autoclass:: ConnStatus
+ :members:
+
+ There are other values in this enum, but only `OK` and `BAD` are seen
+ after a connection has been established. Other statuses might only be seen
+ during the connection phase and are considered internal.
+
+ .. seealso:: :pq:`PQstatus()` returns this value.
+
+
+.. autoclass:: PollingStatus
+ :members:
+
+ .. seealso:: :pq:`PQconnectPoll` for a description of these states.
+
+
+.. autoclass:: TransactionStatus
+ :members:
+
+ .. seealso:: :pq:`PQtransactionStatus` for a description of these states.
+
+
+.. autoclass:: ExecStatus
+ :members:
+
+ .. seealso:: :pq:`PQresultStatus` for a description of these states.
+
+
+.. autoclass:: PipelineStatus
+ :members:
+
+ .. seealso:: :pq:`PQpipelineStatus` for a description of these states.
+
+
+.. autoclass:: Format
+ :members:
+
+
+.. autoclass:: DiagnosticField
+
+ Available attributes:
+
+ .. attribute::
+ SEVERITY
+ SEVERITY_NONLOCALIZED
+ SQLSTATE
+ MESSAGE_PRIMARY
+ MESSAGE_DETAIL
+ MESSAGE_HINT
+ STATEMENT_POSITION
+ INTERNAL_POSITION
+ INTERNAL_QUERY
+ CONTEXT
+ SCHEMA_NAME
+ TABLE_NAME
+ COLUMN_NAME
+ DATATYPE_NAME
+ CONSTRAINT_NAME
+ SOURCE_FILE
+ SOURCE_LINE
+ SOURCE_FUNCTION
+
+ .. seealso:: :pq:`PQresultErrorField` for a description of these values.
+
+
+.. autoclass:: Ping
+ :members:
+
+ .. seealso:: :pq:`PQpingParams` for a description of these values.
+
+.. autoclass:: Trace
+ :members:
+
+ .. seealso:: :pq:`PQsetTraceFlags` for a description of these values.
diff --git a/docs/api/rows.rst b/docs/api/rows.rst
new file mode 100644
index 0000000..204f1ea
--- /dev/null
+++ b/docs/api/rows.rst
@@ -0,0 +1,74 @@
+.. _psycopg.rows:
+
+`rows` -- row factory implementations
+=====================================
+
+.. module:: psycopg.rows
+
+The module exposes a few generic `~psycopg.RowFactory` implementation, which
+can be used to retrieve data from the database in more complex structures than
+the basic tuples.
+
+Check out :ref:`row-factories` for information about how to use these objects.
+
+.. autofunction:: tuple_row
+.. autofunction:: dict_row
+.. autofunction:: namedtuple_row
+.. autofunction:: class_row
+
+ This is not a row factory, but rather a factory of row factories.
+ Specifying `!row_factory=class_row(MyClass)` will create connections and
+ cursors returning `!MyClass` objects on fetch.
+
+ Example::
+
+ from dataclasses import dataclass
+ import psycopg
+ from psycopg.rows import class_row
+
+ @dataclass
+ class Person:
+ first_name: str
+ last_name: str
+ age: int = None
+
+ conn = psycopg.connect()
+ cur = conn.cursor(row_factory=class_row(Person))
+
+ cur.execute("select 'John' as first_name, 'Smith' as last_name").fetchone()
+ # Person(first_name='John', last_name='Smith', age=None)
+
+.. autofunction:: args_row
+.. autofunction:: kwargs_row
+
+
+Formal rows protocols
+---------------------
+
+These objects can be used to describe your own rows adapter for static typing
+checks, such as mypy_.
+
+.. _mypy: https://mypy.readthedocs.io/
+
+
+.. autoclass:: psycopg.rows.RowMaker()
+
+ .. method:: __call__(values: Sequence[Any]) -> Row
+
+ Convert a sequence of values from the database to a finished object.
+
+
+.. autoclass:: psycopg.rows.RowFactory()
+
+ .. method:: __call__(cursor: Cursor[Row]) -> RowMaker[Row]
+
+ Inspect the result on a cursor and return a `RowMaker` to convert rows.
+
+.. autoclass:: psycopg.rows.AsyncRowFactory()
+
+.. autoclass:: psycopg.rows.BaseRowFactory()
+
+Note that it's easy to implement an object implementing both `!RowFactory` and
+`!AsyncRowFactory`: usually, everything you need to implement a row factory is
+to access the cursor's `~psycopg.Cursor.description`, which is provided by
+both the cursor flavours.
diff --git a/docs/api/sql.rst b/docs/api/sql.rst
new file mode 100644
index 0000000..6959fee
--- /dev/null
+++ b/docs/api/sql.rst
@@ -0,0 +1,151 @@
+`sql` -- SQL string composition
+===============================
+
+.. index::
+ double: Binding; Client-Side
+
+.. module:: psycopg.sql
+
+The module contains objects and functions useful to generate SQL dynamically,
+in a convenient and safe way. SQL identifiers (e.g. names of tables and
+fields) cannot be passed to the `~psycopg.Cursor.execute()` method like query
+arguments::
+
+ # This will not work
+ table_name = 'my_table'
+ cur.execute("INSERT INTO %s VALUES (%s, %s)", [table_name, 10, 20])
+
+The SQL query should be composed before the arguments are merged, for
+instance::
+
+ # This works, but it is not optimal
+ table_name = 'my_table'
+ cur.execute(
+ "INSERT INTO %s VALUES (%%s, %%s)" % table_name,
+ [10, 20])
+
+This sort of works, but it is an accident waiting to happen: the table name
+may be an invalid SQL literal and need quoting; even more serious is the
+security problem in case the table name comes from an untrusted source. The
+name should be escaped using `~psycopg.pq.Escaping.escape_identifier()`::
+
+ from psycopg.pq import Escaping
+
+ # This works, but it is not optimal
+ table_name = 'my_table'
+ cur.execute(
+ "INSERT INTO %s VALUES (%%s, %%s)" % Escaping.escape_identifier(table_name),
+ [10, 20])
+
+This is now safe, but it somewhat ad-hoc. In case, for some reason, it is
+necessary to include a value in the query string (as opposite as in a value)
+the merging rule is still different. It is also still relatively dangerous: if
+`!escape_identifier()` is forgotten somewhere, the program will usually work,
+but will eventually crash in the presence of a table or field name with
+containing characters to escape, or will present a potentially exploitable
+weakness.
+
+The objects exposed by the `!psycopg.sql` module allow generating SQL
+statements on the fly, separating clearly the variable parts of the statement
+from the query parameters::
+
+ from psycopg import sql
+
+ cur.execute(
+ sql.SQL("INSERT INTO {} VALUES (%s, %s)")
+ .format(sql.Identifier('my_table')),
+ [10, 20])
+
+
+Module usage
+------------
+
+Usually you should express the template of your query as an `SQL` instance
+with ``{}``\-style placeholders and use `~SQL.format()` to merge the variable
+parts into them, all of which must be `Composable` subclasses. You can still
+have ``%s``\-style placeholders in your query and pass values to
+`~psycopg.Cursor.execute()`: such value placeholders will be untouched by
+`!format()`::
+
+ query = sql.SQL("SELECT {field} FROM {table} WHERE {pkey} = %s").format(
+ field=sql.Identifier('my_name'),
+ table=sql.Identifier('some_table'),
+ pkey=sql.Identifier('id'))
+
+The resulting object is meant to be passed directly to cursor methods such as
+`~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
+`~psycopg.Cursor.copy()`, but can also be used to compose a query as a Python
+string, using the `~Composable.as_string()` method::
+
+ cur.execute(query, (42,))
+ full_query = query.as_string(cur)
+
+If part of your query is a variable sequence of arguments, such as a
+comma-separated list of field names, you can use the `SQL.join()` method to
+pass them to the query::
+
+ query = sql.SQL("SELECT {fields} FROM {table}").format(
+ fields=sql.SQL(',').join([
+ sql.Identifier('field1'),
+ sql.Identifier('field2'),
+ sql.Identifier('field3'),
+ ]),
+ table=sql.Identifier('some_table'))
+
+
+`!sql` objects
+--------------
+
+The `!sql` objects are in the following inheritance hierarchy:
+
+| `Composable`: the base class exposing the common interface
+| ``|__`` `SQL`: a literal snippet of an SQL query
+| ``|__`` `Identifier`: a PostgreSQL identifier or dot-separated sequence of identifiers
+| ``|__`` `Literal`: a value hardcoded into a query
+| ``|__`` `Placeholder`: a `%s`\ -style placeholder whose value will be added later e.g. by `~psycopg.Cursor.execute()`
+| ``|__`` `Composed`: a sequence of `!Composable` instances.
+
+
+.. autoclass:: Composable()
+
+ .. automethod:: as_bytes
+ .. automethod:: as_string
+
+
+.. autoclass:: SQL
+
+ .. versionchanged:: 3.1
+
+ The input object should be a `~typing.LiteralString`. See :pep:`675`
+ for details.
+
+ .. automethod:: format
+
+ .. automethod:: join
+
+
+.. autoclass:: Identifier
+
+.. autoclass:: Literal
+
+ .. versionchanged:: 3.1
+ Add a type cast to the representation if useful in ambiguous context
+ (e.g. ``'2000-01-01'::date``)
+
+.. autoclass:: Placeholder
+
+.. autoclass:: Composed
+
+ .. automethod:: join
+
+
+Utility functions
+-----------------
+
+.. autofunction:: quote
+
+.. data::
+ NULL
+ DEFAULT
+
+ `sql.SQL` objects often useful in queries.
diff --git a/docs/api/types.rst b/docs/api/types.rst
new file mode 100644
index 0000000..f04659e
--- /dev/null
+++ b/docs/api/types.rst
@@ -0,0 +1,168 @@
+.. currentmodule:: psycopg.types
+
+.. _psycopg.types:
+
+`!types` -- Types information and adapters
+==========================================
+
+.. module:: psycopg.types
+
+The `!psycopg.types` package exposes:
+
+- objects to describe PostgreSQL types, such as `TypeInfo`, `TypesRegistry`,
+ to help or :ref:`customise the types conversion <adaptation>`;
+
+- concrete implementations of `~psycopg.abc.Loader` and `~psycopg.abc.Dumper`
+ protocols to :ref:`handle builtin data types <types-adaptation>`;
+
+- helper objects to represent PostgreSQL data types which :ref:`don't have a
+ straightforward Python representation <extra-adaptation>`, such as
+ `~range.Range`.
+
+
+Types information
+-----------------
+
+The `TypeInfo` object describes simple information about a PostgreSQL data
+type, such as its name, oid and array oid. `!TypeInfo` subclasses may hold more
+information, for instance the components of a composite type.
+
+You can use `TypeInfo.fetch()` to query information from a database catalog,
+which is then used by helper functions, such as
+`~psycopg.types.hstore.register_hstore()`, to register adapters on types whose
+OID is not known upfront or to create more specialised adapters.
+
+The `!TypeInfo` object doesn't instruct Psycopg to convert a PostgreSQL type
+into a Python type: this is the role of a `~psycopg.abc.Loader`. However it
+can extend the behaviour of other adapters: if you create a loader for
+`!MyType`, using the `TypeInfo` information, Psycopg will be able to manage
+seamlessly arrays of `!MyType` or ranges and composite types using `!MyType`
+as a subtype.
+
+.. seealso:: :ref:`adaptation` describes how to convert from Python objects to
+ PostgreSQL types and back.
+
+.. code:: python
+
+ from psycopg.adapt import Loader
+ from psycopg.types import TypeInfo
+
+ t = TypeInfo.fetch(conn, "mytype")
+ t.register(conn)
+
+ for record in conn.execute("SELECT mytypearray FROM mytable"):
+ # records will return lists of "mytype" as string
+
+ class MyTypeLoader(Loader):
+ def load(self, data):
+ # parse the data and return a MyType instance
+
+ conn.adapters.register_loader("mytype", MyTypeLoader)
+
+ for record in conn.execute("SELECT mytypearray FROM mytable"):
+ # records will return lists of MyType instances
+
+
+.. autoclass:: TypeInfo
+
+ .. method:: fetch(conn, name)
+ :classmethod:
+
+ .. method:: fetch(aconn, name)
+ :classmethod:
+ :async:
+ :noindex:
+
+ Query a system catalog to read information about a type.
+
+ :param conn: the connection to query
+ :type conn: ~psycopg.Connection or ~psycopg.AsyncConnection
+ :param name: the name of the type to query. It can include a schema
+ name.
+ :type name: `!str` or `~psycopg.sql.Identifier`
+ :return: a `!TypeInfo` object (or subclass) populated with the type
+ information, `!None` if not found.
+
+ If the connection is async, `!fetch()` will behave as a coroutine and
+ the caller will need to `!await` on it to get the result::
+
+ t = await TypeInfo.fetch(aconn, "mytype")
+
+ .. automethod:: register
+
+ :param context: the context where the type is registered, for instance
+ a `~psycopg.Connection` or `~psycopg.Cursor`. `!None` registers
+ the `!TypeInfo` globally.
+ :type context: Optional[~psycopg.abc.AdaptContext]
+
+ Registering the `TypeInfo` in a context allows the adapters of that
+ context to look up type information: for instance it allows to
+ recognise automatically arrays of that type and load them from the
+ database as a list of the base type.
+
+
+In order to get information about dynamic PostgreSQL types, Psycopg offers a
+few `!TypeInfo` subclasses, whose `!fetch()` method can extract more complete
+information about the type, such as `~psycopg.types.composite.CompositeInfo`,
+`~psycopg.types.range.RangeInfo`, `~psycopg.types.multirange.MultirangeInfo`,
+`~psycopg.types.enum.EnumInfo`.
+
+`!TypeInfo` objects are collected in `TypesRegistry` instances, which help type
+information lookup. Every `~psycopg.adapt.AdaptersMap` exposes its type map on
+its `~psycopg.adapt.AdaptersMap.types` attribute.
+
+.. autoclass:: TypesRegistry
+
+ `!TypeRegistry` instances are typically exposed by
+ `~psycopg.adapt.AdaptersMap` objects in adapt contexts such as
+ `~psycopg.Connection` or `~psycopg.Cursor` (e.g. `!conn.adapters.types`).
+
+ The global registry, from which the others inherit from, is available as
+ `psycopg.adapters`\ `!.types`.
+
+ .. automethod:: __getitem__
+
+ .. code:: python
+
+ >>> import psycopg
+
+ >>> psycopg.adapters.types["text"]
+ <TypeInfo: text (oid: 25, array oid: 1009)>
+
+ >>> psycopg.adapters.types[23]
+ <TypeInfo: int4 (oid: 23, array oid: 1007)>
+
+ .. automethod:: get
+
+ .. automethod:: get_oid
+
+ .. code:: python
+
+ >>> psycopg.adapters.types.get_oid("text[]")
+ 1009
+
+ .. automethod:: get_by_subtype
+
+
+.. _json-adapters:
+
+JSON adapters
+-------------
+
+See :ref:`adapt-json` for details.
+
+.. currentmodule:: psycopg.types.json
+
+.. autoclass:: Json
+.. autoclass:: Jsonb
+
+Wrappers to signal to convert `!obj` to a json or jsonb PostgreSQL value.
+
+Any object supported by the underlying `!dumps()` function can be wrapped.
+
+If a `!dumps` function is passed to the wrapper, use it to dump the wrapped
+object. Otherwise use the function specified by `set_json_dumps()`.
+
+
+.. autofunction:: set_json_dumps
+.. autofunction:: set_json_loads
diff --git a/docs/basic/adapt.rst b/docs/basic/adapt.rst
new file mode 100644
index 0000000..1538327
--- /dev/null
+++ b/docs/basic/adapt.rst
@@ -0,0 +1,522 @@
+.. currentmodule:: psycopg
+
+.. index::
+ single: Adaptation
+ pair: Objects; Adaptation
+ single: Data types; Adaptation
+
+.. _types-adaptation:
+
+Adapting basic Python types
+===========================
+
+Many standard Python types are adapted into SQL and returned as Python
+objects when a query is executed.
+
+Converting the following data types between Python and PostgreSQL works
+out-of-the-box and doesn't require any configuration. In case you need to
+customise the conversion you should take a look at :ref:`adaptation`.
+
+
+.. index::
+ pair: Boolean; Adaptation
+
+.. _adapt-bool:
+
+Booleans adaptation
+-------------------
+
+Python `bool` values `!True` and `!False` are converted to the equivalent
+`PostgreSQL boolean type`__::
+
+ >>> cur.execute("SELECT %s, %s", (True, False))
+ # equivalent to "SELECT true, false"
+
+.. __: https://www.postgresql.org/docs/current/datatype-boolean.html
+
+
+.. index::
+ single: Adaptation; numbers
+ single: Integer; Adaptation
+ single: Float; Adaptation
+ single: Decimal; Adaptation
+
+.. _adapt-numbers:
+
+Numbers adaptation
+------------------
+
+.. seealso::
+
+ - `PostgreSQL numeric types
+ <https://www.postgresql.org/docs/current/static/datatype-numeric.html>`__
+
+- Python `int` values can be converted to PostgreSQL :sql:`smallint`,
+ :sql:`integer`, :sql:`bigint`, or :sql:`numeric`, according to their numeric
+ value. Psycopg will choose the smallest data type available, because
+ PostgreSQL can automatically cast a type up (e.g. passing a `smallint` where
+ PostgreSQL expect an `integer` is gladly accepted) but will not cast down
+ automatically (e.g. if a function has an :sql:`integer` argument, passing it
+ a :sql:`bigint` value will fail, even if the value is 1).
+
+- Python `float` values are converted to PostgreSQL :sql:`float8`.
+
+- Python `~decimal.Decimal` values are converted to PostgreSQL :sql:`numeric`.
+
+On the way back, smaller types (:sql:`int2`, :sql:`int4`, :sql:`float4`) are
+promoted to the larger Python counterpart.
+
+.. note::
+
+ Sometimes you may prefer to receive :sql:`numeric` data as `!float`
+ instead, for performance reason or ease of manipulation: you can configure
+ an adapter to :ref:`cast PostgreSQL numeric to Python float
+ <adapt-example-float>`. This of course may imply a loss of precision.
+
+
+.. index::
+ pair: Strings; Adaptation
+ single: Unicode; Adaptation
+ pair: Encoding; SQL_ASCII
+
+.. _adapt-string:
+
+Strings adaptation
+------------------
+
+.. seealso::
+
+ - `PostgreSQL character types
+ <https://www.postgresql.org/docs/current/datatype-character.html>`__
+
+Python `str` are converted to PostgreSQL string syntax, and PostgreSQL types
+such as :sql:`text` and :sql:`varchar` are converted back to Python `!str`:
+
+.. code:: python
+
+ conn = psycopg.connect()
+ conn.execute(
+ "INSERT INTO menu (id, entry) VALUES (%s, %s)",
+ (1, "Crème Brûlée at 4.99€"))
+ conn.execute("SELECT entry FROM menu WHERE id = 1").fetchone()[0]
+ 'Crème Brûlée at 4.99€'
+
+PostgreSQL databases `have an encoding`__, and `the session has an encoding`__
+too, exposed in the `!Connection.info.`\ `~ConnectionInfo.encoding`
+attribute. If your database and connection are in UTF-8 encoding you will
+likely have no problem, otherwise you will have to make sure that your
+application only deals with the non-ASCII chars that the database can handle;
+failing to do so may result in encoding/decoding errors:
+
+.. __: https://www.postgresql.org/docs/current/sql-createdatabase.html
+.. __: https://www.postgresql.org/docs/current/multibyte.html
+
+.. code:: python
+
+ # The encoding is set at connection time according to the db configuration
+ conn.info.encoding
+ 'utf-8'
+
+ # The Latin-9 encoding can manage some European accented letters
+ # and the Euro symbol
+ conn.execute("SET client_encoding TO LATIN9")
+ conn.execute("SELECT entry FROM menu WHERE id = 1").fetchone()[0]
+ 'Crème Brûlée at 4.99€'
+
+ # The Latin-1 encoding doesn't have a representation for the Euro symbol
+ conn.execute("SET client_encoding TO LATIN1")
+ conn.execute("SELECT entry FROM menu WHERE id = 1").fetchone()[0]
+ # Traceback (most recent call last)
+ # ...
+ # UntranslatableCharacter: character with byte sequence 0xe2 0x82 0xac
+ # in encoding "UTF8" has no equivalent in encoding "LATIN1"
+
+In rare cases you may have strings with unexpected encodings in the database.
+Using the ``SQL_ASCII`` client encoding will disable decoding of the data
+coming from the database, which will be returned as `bytes`:
+
+.. code:: python
+
+ conn.execute("SET client_encoding TO SQL_ASCII")
+ conn.execute("SELECT entry FROM menu WHERE id = 1").fetchone()[0]
+ b'Cr\xc3\xa8me Br\xc3\xbbl\xc3\xa9e at 4.99\xe2\x82\xac'
+
+Alternatively you can cast the unknown encoding data to :sql:`bytea` to
+retrieve it as bytes, leaving other strings unaltered: see :ref:`adapt-binary`
+
+Note that PostgreSQL text cannot contain the ``0x00`` byte. If you need to
+store Python strings that may contain binary zeros you should use a
+:sql:`bytea` field.
+
+
+.. index::
+ single: bytea; Adaptation
+ single: bytes; Adaptation
+ single: bytearray; Adaptation
+ single: memoryview; Adaptation
+ single: Binary string
+
+.. _adapt-binary:
+
+Binary adaptation
+-----------------
+
+Python types representing binary objects (`bytes`, `bytearray`, `memoryview`)
+are converted by default to :sql:`bytea` fields. By default data received is
+returned as `!bytes`.
+
+If you are storing large binary data in bytea fields (such as binary documents
+or images) you should probably use the binary format to pass and return
+values, otherwise binary data will undergo `ASCII escaping`__, taking some CPU
+time and more bandwidth. See :ref:`binary-data` for details.
+
+.. __: https://www.postgresql.org/docs/current/datatype-binary.html
+
+
+.. _adapt-date:
+
+Date/time types adaptation
+--------------------------
+
+.. seealso::
+
+ - `PostgreSQL date/time types
+ <https://www.postgresql.org/docs/current/datatype-datetime.html>`__
+
+- Python `~datetime.date` objects are converted to PostgreSQL :sql:`date`.
+- Python `~datetime.datetime` objects are converted to PostgreSQL
+ :sql:`timestamp` (if they don't have a `!tzinfo` set) or :sql:`timestamptz`
+ (if they do).
+- Python `~datetime.time` objects are converted to PostgreSQL :sql:`time`
+ (if they don't have a `!tzinfo` set) or :sql:`timetz` (if they do).
+- Python `~datetime.timedelta` objects are converted to PostgreSQL
+ :sql:`interval`.
+
+PostgreSQL :sql:`timestamptz` values are returned with a timezone set to the
+`connection TimeZone setting`__, which is available as a Python
+`~zoneinfo.ZoneInfo` object in the `!Connection.info`.\ `~ConnectionInfo.timezone`
+attribute::
+
+ >>> conn.info.timezone
+ zoneinfo.ZoneInfo(key='Europe/London')
+
+ >>> conn.execute("select '2048-07-08 12:00'::timestamptz").fetchone()[0]
+ datetime.datetime(2048, 7, 8, 12, 0, tzinfo=zoneinfo.ZoneInfo(key='Europe/London'))
+
+.. note::
+ PostgreSQL :sql:`timestamptz` doesn't store "a timestamp with a timezone
+ attached": it stores a timestamp always in UTC, which is converted, on
+ output, to the connection TimeZone setting::
+
+ >>> conn.execute("SET TIMEZONE to 'Europe/Rome'") # UTC+2 in summer
+
+ >>> conn.execute("SELECT '2042-07-01 12:00Z'::timestamptz").fetchone()[0] # UTC input
+ datetime.datetime(2042, 7, 1, 14, 0, tzinfo=zoneinfo.ZoneInfo(key='Europe/Rome'))
+
+ Check out the `PostgreSQL documentation about timezones`__ for all the
+ details.
+
+ .. __: https://www.postgresql.org/docs/current/datatype-datetime.html
+ #DATATYPE-TIMEZONES
+
+.. __: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-TIMEZONE
+
+
+.. _adapt-json:
+
+JSON adaptation
+---------------
+
+Psycopg can map between Python objects and PostgreSQL `json/jsonb
+types`__, allowing to customise the load and dump function used.
+
+.. __: https://www.postgresql.org/docs/current/datatype-json.html
+
+Because several Python objects could be considered JSON (dicts, lists,
+scalars, even date/time if using a dumps function customised to use them),
+Psycopg requires you to wrap the object to dump as JSON into a wrapper:
+either `psycopg.types.json.Json` or `~psycopg.types.json.Jsonb`.
+
+.. code:: python
+
+ from psycopg.types.json import Jsonb
+
+ thing = {"foo": ["bar", 42]}
+ conn.execute("INSERT INTO mytable VALUES (%s)", [Jsonb(thing)])
+
+By default Psycopg uses the standard library `json.dumps` and `json.loads`
+functions to serialize and de-serialize Python objects to JSON. If you want to
+customise how serialization happens, for instance changing serialization
+parameters or using a different JSON library, you can specify your own
+functions using the `psycopg.types.json.set_json_dumps()` and
+`~psycopg.types.json.set_json_loads()` functions, to apply either globally or
+to a specific context (connection or cursor).
+
+.. code:: python
+
+ from functools import partial
+ from psycopg.types.json import Jsonb, set_json_dumps, set_json_loads
+ import ujson
+
+ # Use a faster dump function
+ set_json_dumps(ujson.dumps)
+
+ # Return floating point values as Decimal, just in one connection
+ set_json_loads(partial(json.loads, parse_float=Decimal), conn)
+
+ conn.execute("SELECT %s", [Jsonb({"value": 123.45})]).fetchone()[0]
+ # {'value': Decimal('123.45')}
+
+If you need an even more specific dump customisation only for certain objects
+(including different configurations in the same query) you can specify a
+`!dumps` parameter in the
+`~psycopg.types.json.Json`/`~psycopg.types.json.Jsonb` wrapper, which will
+take precedence over what is specified by `!set_json_dumps()`.
+
+.. code:: python
+
+ from uuid import UUID, uuid4
+
+ class UUIDEncoder(json.JSONEncoder):
+ """A JSON encoder which can dump UUID."""
+ def default(self, obj):
+ if isinstance(obj, UUID):
+ return str(obj)
+ return json.JSONEncoder.default(self, obj)
+
+ uuid_dumps = partial(json.dumps, cls=UUIDEncoder)
+ obj = {"uuid": uuid4()}
+ cnn.execute("INSERT INTO objs VALUES %s", [Json(obj, dumps=uuid_dumps)])
+ # will insert: {'uuid': '0a40799d-3980-4c65-8315-2956b18ab0e1'}
+
+
+.. _adapt-list:
+
+Lists adaptation
+----------------
+
+Python `list` objects are adapted to `PostgreSQL arrays`__ and back. Only
+lists containing objects of the same type can be dumped to PostgreSQL (but the
+list may contain `!None` elements).
+
+.. __: https://www.postgresql.org/docs/current/arrays.html
+
+.. note::
+
+ If you have a list of values which you want to use with the :sql:`IN`
+ operator... don't. It won't work (neither with a list nor with a tuple)::
+
+ >>> conn.execute("SELECT * FROM mytable WHERE id IN %s", [[10,20,30]])
+ Traceback (most recent call last):
+ File "<stdin>", line 1, in <module>
+ psycopg.errors.SyntaxError: syntax error at or near "$1"
+ LINE 1: SELECT * FROM mytable WHERE id IN $1
+ ^
+
+ What you want to do instead is to use the `'= ANY()' expression`__ and pass
+ the values as a list (not a tuple).
+
+ >>> conn.execute("SELECT * FROM mytable WHERE id = ANY(%s)", [[10,20,30]])
+
+ This has also the advantage of working with an empty list, whereas ``IN
+ ()`` is not valid SQL.
+
+ .. __: https://www.postgresql.org/docs/current/functions-comparisons.html
+ #id-1.5.8.30.16
+
+
+.. _adapt-uuid:
+
+UUID adaptation
+---------------
+
+Python `uuid.UUID` objects are adapted to PostgreSQL `UUID type`__ and back::
+
+ >>> conn.execute("select gen_random_uuid()").fetchone()[0]
+ UUID('97f0dd62-3bd2-459e-89b8-a5e36ea3c16c')
+
+ >>> from uuid import uuid4
+ >>> conn.execute("select gen_random_uuid() = %s", [uuid4()]).fetchone()[0]
+ False # long shot
+
+.. __: https://www.postgresql.org/docs/current/datatype-uuid.html
+
+
+.. _adapt-network:
+
+Network data types adaptation
+-----------------------------
+
+Objects from the `ipaddress` module are converted to PostgreSQL `network
+address types`__:
+
+- `~ipaddress.IPv4Address`, `~ipaddress.IPv4Interface` objects are converted
+ to the PostgreSQL :sql:`inet` type. On the way back, :sql:`inet` values
+ indicating a single address are converted to `!IPv4Address`, otherwise they
+ are converted to `!IPv4Interface`
+
+- `~ipaddress.IPv4Network` objects are converted to the :sql:`cidr` type and
+ back.
+
+- `~ipaddress.IPv6Address`, `~ipaddress.IPv6Interface`,
+ `~ipaddress.IPv6Network` objects follow the same rules, with IPv6
+ :sql:`inet` and :sql:`cidr` values.
+
+.. __: https://www.postgresql.org/docs/current/datatype-net-types.html#DATATYPE-CIDR
+
+.. code:: python
+
+ >>> conn.execute("select '192.168.0.1'::inet, '192.168.0.1/24'::inet").fetchone()
+ (IPv4Address('192.168.0.1'), IPv4Interface('192.168.0.1/24'))
+
+ >>> conn.execute("select '::ffff:1.2.3.0/120'::cidr").fetchone()[0]
+ IPv6Network('::ffff:102:300/120')
+
+
+.. _adapt-enum:
+
+Enum adaptation
+---------------
+
+.. versionadded:: 3.1
+
+Psycopg can adapt Python `~enum.Enum` subclasses into PostgreSQL enum types
+(created with the |CREATE TYPE AS ENUM|_ command).
+
+.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)`
+.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html
+
+In order to set up a bidirectional enum mapping, you should get information
+about the PostgreSQL enum using the `~types.enum.EnumInfo` class and
+register it using `~types.enum.register_enum()`. The behaviour of unregistered
+and registered enums is different.
+
+- If the enum is not registered with `register_enum()`:
+
+ - Pure `!Enum` classes are dumped as normal strings, using their member
+ names as value. The unknown oid is used, so PostgreSQL should be able to
+ use this string in most contexts (such as an enum or a text field).
+
+ .. versionchanged:: 3.1
+ In previous version dumping pure enums is not supported and raise a
+ "cannot adapt" error.
+
+ - Mix-in enums are dumped according to their mix-in type (because a `class
+ MyIntEnum(int, Enum)` is more specifically an `!int` than an `!Enum`, so
+ it's dumped by default according to `!int` rules).
+
+ - PostgreSQL enums are loaded as Python strings. If you want to load arrays
+ of such enums you will have to find their OIDs using `types.TypeInfo.fetch()`
+ and register them using `~types.TypeInfo.register()`.
+
+- If the enum is registered (using `~types.enum.EnumInfo`\ `!.fetch()` and
+ `~types.enum.register_enum()`):
+
+ - Enums classes, both pure and mixed-in, are dumped by name.
+
+ - The registered PostgreSQL enum is loaded back as the registered Python
+ enum members.
+
+.. autoclass:: psycopg.types.enum.EnumInfo
+
+ `!EnumInfo` is a subclass of `~psycopg.types.TypeInfo`: refer to the
+ latter's documentation for generic usage, especially the
+ `~psycopg.types.TypeInfo.fetch()` method.
+
+ .. attribute:: labels
+
+ After `~psycopg.types.TypeInfo.fetch()`, it contains the labels defined
+ in the PostgreSQL enum type.
+
+ .. attribute:: enum
+
+ After `register_enum()` is called, it will contain the Python type
+ mapping to the registered enum.
+
+.. autofunction:: psycopg.types.enum.register_enum
+
+ After registering, fetching data of the registered enum will cast
+ PostgreSQL enum labels into corresponding Python enum members.
+
+ If no `!enum` is specified, a new `Enum` is created based on
+ PostgreSQL enum labels.
+
+Example::
+
+ >>> from enum import Enum, auto
+ >>> from psycopg.types.enum import EnumInfo, register_enum
+
+ >>> class UserRole(Enum):
+ ... ADMIN = auto()
+ ... EDITOR = auto()
+ ... GUEST = auto()
+
+ >>> conn.execute("CREATE TYPE user_role AS ENUM ('ADMIN', 'EDITOR', 'GUEST')")
+
+ >>> info = EnumInfo.fetch(conn, "user_role")
+ >>> register_enum(info, conn, UserRole)
+
+ >>> some_editor = info.enum.EDITOR
+ >>> some_editor
+ <UserRole.EDITOR: 2>
+
+ >>> conn.execute(
+ ... "SELECT pg_typeof(%(editor)s), %(editor)s",
+ ... {"editor": some_editor}
+ ... ).fetchone()
+ ('user_role', <UserRole.EDITOR: 2>)
+
+ >>> conn.execute(
+ ... "SELECT ARRAY[%s, %s]",
+ ... [UserRole.ADMIN, UserRole.GUEST]
+ ... ).fetchone()
+ [<UserRole.ADMIN: 1>, <UserRole.GUEST: 3>]
+
+If the Python and the PostgreSQL enum don't match 1:1 (for instance if members
+have a different name, or if more than one Python enum should map to the same
+PostgreSQL enum, or vice versa), you can specify the exceptions using the
+`!mapping` parameter.
+
+`!mapping` should be a dictionary with Python enum members as keys and the
+matching PostgreSQL enum labels as values, or a list of `(member, label)`
+pairs with the same meaning (useful when some members are repeated). Order
+matters: if an element on either side is specified more than once, the last
+pair in the sequence will take precedence::
+
+ # Legacy roles, defined in medieval times.
+ >>> conn.execute(
+ ... "CREATE TYPE abbey_role AS ENUM ('ABBOT', 'SCRIBE', 'MONK', 'GUEST')")
+
+ >>> info = EnumInfo.fetch(conn, "abbey_role")
+ >>> register_enum(info, conn, UserRole, mapping=[
+ ... (UserRole.ADMIN, "ABBOT"),
+ ... (UserRole.EDITOR, "SCRIBE"),
+ ... (UserRole.EDITOR, "MONK")])
+
+ >>> conn.execute("SELECT '{ABBOT,SCRIBE,MONK,GUEST}'::abbey_role[]").fetchone()[0]
+ [<UserRole.ADMIN: 1>,
+ <UserRole.EDITOR: 2>,
+ <UserRole.EDITOR: 2>,
+ <UserRole.GUEST: 3>]
+
+ >>> conn.execute("SELECT %s::text[]", [list(UserRole)]).fetchone()[0]
+ ['ABBOT', 'MONK', 'GUEST']
+
+A particularly useful case is when the PostgreSQL labels match the *values* of
+a `!str`\-based Enum. In this case it is possible to use something like ``{m:
+m.value for m in enum}`` as mapping::
+
+ >>> class LowercaseRole(str, Enum):
+ ... ADMIN = "admin"
+ ... EDITOR = "editor"
+ ... GUEST = "guest"
+
+ >>> conn.execute(
+ ... "CREATE TYPE lowercase_role AS ENUM ('admin', 'editor', 'guest')")
+
+ >>> info = EnumInfo.fetch(conn, "lowercase_role")
+ >>> register_enum(
+ ... info, conn, LowercaseRole, mapping={m: m.value for m in LowercaseRole})
+
+ >>> conn.execute("SELECT 'editor'::lowercase_role").fetchone()[0]
+ <LowercaseRole.EDITOR: 'editor'>
diff --git a/docs/basic/copy.rst b/docs/basic/copy.rst
new file mode 100644
index 0000000..2bb4498
--- /dev/null
+++ b/docs/basic/copy.rst
@@ -0,0 +1,212 @@
+.. currentmodule:: psycopg
+
+.. index::
+ pair: COPY; SQL command
+
+.. _copy:
+
+Using COPY TO and COPY FROM
+===========================
+
+Psycopg allows to operate with `PostgreSQL COPY protocol`__. :sql:`COPY` is
+one of the most efficient ways to load data into the database (and to modify
+it, with some SQL creativity).
+
+.. __: https://www.postgresql.org/docs/current/sql-copy.html
+
+Copy is supported using the `Cursor.copy()` method, passing it a query of the
+form :sql:`COPY ... FROM STDIN` or :sql:`COPY ... TO STDOUT`, and managing the
+resulting `Copy` object in a `!with` block:
+
+.. code:: python
+
+ with cursor.copy("COPY table_name (col1, col2) FROM STDIN") as copy:
+ # pass data to the 'copy' object using write()/write_row()
+
+You can compose a COPY statement dynamically by using objects from the
+`psycopg.sql` module:
+
+.. code:: python
+
+ with cursor.copy(
+ sql.SQL("COPY {} TO STDOUT").format(sql.Identifier("table_name"))
+ ) as copy:
+ # read data from the 'copy' object using read()/read_row()
+
+.. versionchanged:: 3.1
+
+ You can also pass parameters to `!copy()`, like in `~Cursor.execute()`:
+
+ .. code:: python
+
+ with cur.copy("COPY (SELECT * FROM table_name LIMIT %s) TO STDOUT", (3,)) as copy:
+ # expect no more than three records
+
+The connection is subject to the usual transaction behaviour, so, unless the
+connection is in autocommit, at the end of the COPY operation you will still
+have to commit the pending changes and you can still roll them back. See
+:ref:`transactions` for details.
+
+
+.. _copy-in-row:
+
+Writing data row-by-row
+-----------------------
+
+Using a copy operation you can load data into the database from any Python
+iterable (a list of tuples, or any iterable of sequences): the Python values
+are adapted as they would be in normal querying. To perform such operation use
+a :sql:`COPY ... FROM STDIN` with `Cursor.copy()` and use `~Copy.write_row()`
+on the resulting object in a `!with` block. On exiting the block the
+operation will be concluded:
+
+.. code:: python
+
+ records = [(10, 20, "hello"), (40, None, "world")]
+
+ with cursor.copy("COPY sample (col1, col2, col3) FROM STDIN") as copy:
+ for record in records:
+ copy.write_row(record)
+
+If an exception is raised inside the block, the operation is interrupted and
+the records inserted so far are discarded.
+
+In order to read or write from `!Copy` row-by-row you must not specify
+:sql:`COPY` options such as :sql:`FORMAT CSV`, :sql:`DELIMITER`, :sql:`NULL`:
+please leave these details alone, thank you :)
+
+
+.. _copy-out-row:
+
+Reading data row-by-row
+-----------------------
+
+You can also do the opposite, reading rows out of a :sql:`COPY ... TO STDOUT`
+operation, by iterating on `~Copy.rows()`. However this is not something you
+may want to do normally: usually the normal query process will be easier to
+use.
+
+PostgreSQL, currently, doesn't give complete type information on :sql:`COPY
+TO`, so the rows returned will have unparsed data, as strings or bytes,
+according to the format.
+
+.. code:: python
+
+ with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy:
+ for row in copy.rows():
+ print(row) # return unparsed data: ('10', '2046-12-24')
+
+You can improve the results by using `~Copy.set_types()` before reading, but
+you have to specify them yourself.
+
+.. code:: python
+
+ with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy:
+ copy.set_types(["int4", "date"])
+ for row in copy.rows():
+ print(row) # (10, datetime.date(2046, 12, 24))
+
+
+.. _copy-block:
+
+Copying block-by-block
+----------------------
+
+If data is already formatted in a way suitable for copy (for instance because
+it is coming from a file resulting from a previous `COPY TO` operation) it can
+be loaded into the database using `Copy.write()` instead.
+
+.. code:: python
+
+ with open("data", "r") as f:
+ with cursor.copy("COPY data FROM STDIN") as copy:
+ while data := f.read(BLOCK_SIZE):
+ copy.write(data)
+
+In this case you can use any :sql:`COPY` option and format, as long as the
+input data is compatible with what the operation in `!copy()` expects. Data
+can be passed as `!str`, if the copy is in :sql:`FORMAT TEXT`, or as `!bytes`,
+which works with both :sql:`FORMAT TEXT` and :sql:`FORMAT BINARY`.
+
+In order to produce data in :sql:`COPY` format you can use a :sql:`COPY ... TO
+STDOUT` statement and iterate over the resulting `Copy` object, which will
+produce a stream of `!bytes` objects:
+
+.. code:: python
+
+ with open("data.out", "wb") as f:
+ with cursor.copy("COPY table_name TO STDOUT") as copy:
+ for data in copy:
+ f.write(data)
+
+
+.. _copy-binary:
+
+Binary copy
+-----------
+
+Binary copy is supported by specifying :sql:`FORMAT BINARY` in the :sql:`COPY`
+statement. In order to import binary data using `~Copy.write_row()`, all the
+types passed to the database must have a binary dumper registered; this is not
+necessary if the data is copied :ref:`block-by-block <copy-block>` using
+`~Copy.write()`.
+
+.. warning::
+
+ PostgreSQL is particularly finicky when loading data in binary mode and
+ will apply **no cast rules**. This means, for example, that passing the
+ value 100 to an `integer` column **will fail**, because Psycopg will pass
+ it as a `smallint` value, and the server will reject it because its size
+ doesn't match what expected.
+
+ You can work around the problem using the `~Copy.set_types()` method of
+ the `!Copy` object and specifying carefully the types to load.
+
+.. seealso:: See :ref:`binary-data` for further info about binary querying.
+
+
+.. _copy-async:
+
+Asynchronous copy support
+-------------------------
+
+Asynchronous operations are supported using the same patterns as above, using
+the objects obtained by an `AsyncConnection`. For instance, if `!f` is an
+object supporting an asynchronous `!read()` method returning :sql:`COPY` data,
+a fully-async copy operation could be:
+
+.. code:: python
+
+ async with cursor.copy("COPY data FROM STDIN") as copy:
+ while data := await f.read():
+ await copy.write(data)
+
+The `AsyncCopy` object documentation describes the signature of the
+asynchronous methods and the differences from its sync `Copy` counterpart.
+
+.. seealso:: See :ref:`async` for further info about using async objects.
+
+
+Example: copying a table across servers
+---------------------------------------
+
+In order to copy a table, or a portion of a table, across servers, you can use
+two COPY operations on two different connections, reading from the first and
+writing to the second.
+
+.. code:: python
+
+ with psycopg.connect(dsn_src) as conn1, psycopg.connect(dsn_tgt) as conn2:
+ with conn1.cursor().copy("COPY src TO STDOUT (FORMAT BINARY)") as copy1:
+ with conn2.cursor().copy("COPY tgt FROM STDIN (FORMAT BINARY)") as copy2:
+ for data in copy1:
+ copy2.write(data)
+
+Using :sql:`FORMAT BINARY` usually gives a performance boost, but it only
+works if the source and target schema are *perfectly identical*. If the tables
+are only *compatible* (for example, if you are copying an :sql:`integer` field
+into a :sql:`bigint` destination field) you should omit the `BINARY` option and
+perform a text-based copy. See :ref:`copy-binary` for details.
+
+The same pattern can be adapted to use :ref:`async objects <async>` in order
+to perform an :ref:`async copy <copy-async>`.
diff --git a/docs/basic/from_pg2.rst b/docs/basic/from_pg2.rst
new file mode 100644
index 0000000..0692049
--- /dev/null
+++ b/docs/basic/from_pg2.rst
@@ -0,0 +1,359 @@
+.. index::
+ pair: psycopg2; Differences
+
+.. currentmodule:: psycopg
+
+.. _from-psycopg2:
+
+
+Differences from `!psycopg2`
+============================
+
+Psycopg 3 uses the common DBAPI structure of many other database adapters and
+tries to behave as close as possible to `!psycopg2`. There are however a few
+differences to be aware of.
+
+.. tip::
+ Most of the times, the workarounds suggested here will work with both
+ Psycopg 2 and 3, which could be useful if you are porting a program or
+ writing a program that should work with both Psycopg 2 and 3.
+
+
+.. _server-side-binding:
+
+Server-side binding
+-------------------
+
+Psycopg 3 sends the query and the parameters to the server separately, instead
+of merging them on the client side. Server-side binding works for normal
+:sql:`SELECT` and data manipulation statements (:sql:`INSERT`, :sql:`UPDATE`,
+:sql:`DELETE`), but it doesn't work with many other statements. For instance,
+it doesn't work with :sql:`SET` or with :sql:`NOTIFY`::
+
+ >>> conn.execute("SET TimeZone TO %s", ["UTC"])
+ Traceback (most recent call last):
+ ...
+ psycopg.errors.SyntaxError: syntax error at or near "$1"
+ LINE 1: SET TimeZone TO $1
+ ^
+
+ >>> conn.execute("NOTIFY %s, %s", ["chan", 42])
+ Traceback (most recent call last):
+ ...
+ psycopg.errors.SyntaxError: syntax error at or near "$1"
+ LINE 1: NOTIFY $1, $2
+ ^
+
+and with any data definition statement::
+
+ >>> conn.execute("CREATE TABLE foo (id int DEFAULT %s)", [42])
+ Traceback (most recent call last):
+ ...
+ psycopg.errors.UndefinedParameter: there is no parameter $1
+ LINE 1: CREATE TABLE foo (id int DEFAULT $1)
+ ^
+
+Sometimes, PostgreSQL offers an alternative: for instance the `set_config()`__
+function can be used instead of the :sql:`SET` statement, the `pg_notify()`__
+function can be used instead of :sql:`NOTIFY`::
+
+ >>> conn.execute("SELECT set_config('TimeZone', %s, false)", ["UTC"])
+
+ >>> conn.execute("SELECT pg_notify(%s, %s)", ["chan", "42"])
+
+.. __: https://www.postgresql.org/docs/current/functions-admin.html
+ #FUNCTIONS-ADMIN-SET
+
+.. __: https://www.postgresql.org/docs/current/sql-notify.html
+ #id-1.9.3.157.7.5
+
+If this is not possible, you must merge the query and the parameter on the
+client side. You can do so using the `psycopg.sql` objects::
+
+ >>> from psycopg import sql
+
+ >>> cur.execute(sql.SQL("CREATE TABLE foo (id int DEFAULT {})").format(42))
+
+or creating a :ref:`client-side binding cursor <client-side-binding-cursors>`
+such as `ClientCursor`::
+
+ >>> cur = ClientCursor(conn)
+ >>> cur.execute("CREATE TABLE foo (id int DEFAULT %s)", [42])
+
+If you need `!ClientCursor` often, you can set the `Connection.cursor_factory`
+to have them created by default by `Connection.cursor()`. This way, Psycopg 3
+will behave largely the same way of Psycopg 2.
+
+Note that, both server-side and client-side, you can only specify **values**
+as parameters (i.e. *the strings that go in single quotes*). If you need to
+parametrize different parts of a statement (such as a table name), you must
+use the `psycopg.sql` module::
+
+ >>> from psycopg import sql
+
+ # This will quote the user and the password using the right quotes
+ # e.g.: ALTER USER "foo" SET PASSWORD 'bar'
+ >>> conn.execute(
+ ... sql.SQL("ALTER USER {} SET PASSWORD {}")
+ ... .format(sql.Identifier(username), password))
+
+
+.. _multi-statements:
+
+Multiple statements in the same query
+-------------------------------------
+
+As a consequence of using :ref:`server-side bindings <server-side-binding>`,
+when parameters are used, it is not possible to execute several statements in
+the same `!execute()` call, separating them by semicolon::
+
+ >>> conn.execute(
+ ... "INSERT INTO foo VALUES (%s); INSERT INTO foo VALUES (%s)",
+ ... (10, 20))
+ Traceback (most recent call last):
+ ...
+ psycopg.errors.SyntaxError: cannot insert multiple commands into a prepared statement
+
+One obvious way to work around the problem is to use several `!execute()`
+calls.
+
+**There is no such limitation if no parameters are used**. As a consequence, you
+can compose a multiple query on the client side and run them all in the same
+`!execute()` call, using the `psycopg.sql` objects::
+
+ >>> from psycopg import sql
+ >>> conn.execute(
+ ... sql.SQL("INSERT INTO foo VALUES ({}); INSERT INTO foo values ({})"
+ ... .format(10, 20))
+
+or a :ref:`client-side binding cursor <client-side-binding-cursors>`::
+
+ >>> cur = psycopg.ClientCursor(conn)
+ >>> cur.execute(
+ ... "INSERT INTO foo VALUES (%s); INSERT INTO foo VALUES (%s)",
+ ... (10, 20))
+
+.. warning::
+
+ If a statements must be executed outside a transaction (such as
+ :sql:`CREATE DATABASE`), it cannot be executed in batch with other
+ statements, even if the connection is in autocommit mode::
+
+ >>> conn.autocommit = True
+ >>> conn.execute("CREATE DATABASE foo; SELECT 1")
+ Traceback (most recent call last):
+ ...
+ psycopg.errors.ActiveSqlTransaction: CREATE DATABASE cannot run inside a transaction block
+
+ This happens because PostgreSQL itself will wrap multiple statements in a
+ transaction. Note that your will experience a different behaviour in
+ :program:`psql` (:program:`psql` will split the queries on semicolons and
+ send them to the server separately).
+
+ This is not new in Psycopg 3: the same limitation is present in
+ `!psycopg2` too.
+
+
+.. _multi-results:
+
+Multiple results returned from multiple statements
+--------------------------------------------------
+
+If more than one statement returning results is executed in psycopg2, only the
+result of the last statement is returned::
+
+ >>> cur_pg2.execute("SELECT 1; SELECT 2")
+ >>> cur_pg2.fetchone()
+ (2,)
+
+In Psycopg 3 instead, all the results are available. After running the query,
+the first result will be readily available in the cursor and can be consumed
+using the usual `!fetch*()` methods. In order to access the following
+results, you can use the `Cursor.nextset()` method::
+
+ >>> cur_pg3.execute("SELECT 1; SELECT 2")
+ >>> cur_pg3.fetchone()
+ (1,)
+
+ >>> cur_pg3.nextset()
+ True
+ >>> cur_pg3.fetchone()
+ (2,)
+
+ >>> cur_pg3.nextset()
+ None # no more results
+
+Remember though that you cannot use server-side bindings to :ref:`execute more
+than one statement in the same query <multi-statements>`.
+
+
+.. _difference-cast-rules:
+
+Different cast rules
+--------------------
+
+In rare cases, especially around variadic functions, PostgreSQL might fail to
+find a function candidate for the given data types::
+
+ >>> conn.execute("SELECT json_build_array(%s, %s)", ["foo", "bar"])
+ Traceback (most recent call last):
+ ...
+ psycopg.errors.IndeterminateDatatype: could not determine data type of parameter $1
+
+This can be worked around specifying the argument types explicitly via a cast::
+
+ >>> conn.execute("SELECT json_build_array(%s::text, %s::text)", ["foo", "bar"])
+
+
+.. _in-and-tuple:
+
+You cannot use ``IN %s`` with a tuple
+-------------------------------------
+
+``IN`` cannot be used with a tuple as single parameter, as was possible with
+``psycopg2``::
+
+ >>> conn.execute("SELECT * FROM foo WHERE id IN %s", [(10,20,30)])
+ Traceback (most recent call last):
+ ...
+ psycopg.errors.SyntaxError: syntax error at or near "$1"
+ LINE 1: SELECT * FROM foo WHERE id IN $1
+ ^
+
+What you can do is to use the `= ANY()`__ construct and pass the candidate
+values as a list instead of a tuple, which will be adapted to a PostgreSQL
+array::
+
+ >>> conn.execute("SELECT * FROM foo WHERE id = ANY(%s)", [[10,20,30]])
+
+Note that `ANY()` can be used with `!psycopg2` too, and has the advantage of
+accepting an empty list of values too as argument, which is not supported by
+the :sql:`IN` operator instead.
+
+.. __: https://www.postgresql.org/docs/current/functions-comparisons.html
+ #id-1.5.8.30.16
+
+
+.. _diff-adapt:
+
+Different adaptation system
+---------------------------
+
+The adaptation system has been completely rewritten, in order to address
+server-side parameters adaptation, but also to consider performance,
+flexibility, ease of customization.
+
+The default behaviour with builtin data should be :ref:`what you would expect
+<types-adaptation>`. If you have customised the way to adapt data, or if you
+are managing your own extension types, you should look at the :ref:`new
+adaptation system <adaptation>`.
+
+.. seealso::
+
+ - :ref:`types-adaptation` for the basic behaviour.
+ - :ref:`adaptation` for more advanced use.
+
+
+.. _diff-copy:
+
+Copy is no longer file-based
+----------------------------
+
+`!psycopg2` exposes :ref:`a few copy methods <pg2:copy>` to interact with
+PostgreSQL :sql:`COPY`. Their file-based interface doesn't make it easy to load
+dynamically-generated data into a database.
+
+There is now a single `~Cursor.copy()` method, which is similar to
+`!psycopg2` `!copy_expert()` in accepting a free-form :sql:`COPY` command and
+returns an object to read/write data, block-wise or record-wise. The different
+usage pattern also enables :sql:`COPY` to be used in async interactions.
+
+.. seealso:: See :ref:`copy` for the details.
+
+
+.. _diff-with:
+
+`!with` connection
+------------------
+
+In `!psycopg2`, using the syntax :ref:`with connection <pg2:with>`,
+only the transaction is closed, not the connection. This behaviour is
+surprising for people used to several other Python classes wrapping resources,
+such as files.
+
+In Psycopg 3, using :ref:`with connection <with-connection>` will close the
+connection at the end of the `!with` block, making handling the connection
+resources more familiar.
+
+In order to manage transactions as blocks you can use the
+`Connection.transaction()` method, which allows for finer control, for
+instance to use nested transactions.
+
+.. seealso:: See :ref:`transaction-context` for details.
+
+
+.. _diff-callproc:
+
+`!callproc()` is gone
+---------------------
+
+`cursor.callproc()` is not implemented. The method has a simplistic semantic
+which doesn't account for PostgreSQL positional parameters, procedures,
+set-returning functions... Use a normal `~Cursor.execute()` with :sql:`SELECT
+function_name(...)` or :sql:`CALL procedure_name(...)` instead.
+
+
+.. _diff-client-encoding:
+
+`!client_encoding` is gone
+--------------------------
+
+Psycopg automatically uses the database client encoding to decode data to
+Unicode strings. Use `ConnectionInfo.encoding` if you need to read the
+encoding. You can select an encoding at connection time using the
+`!client_encoding` connection parameter and you can change the encoding of a
+connection by running a :sql:`SET client_encoding` statement... But why would
+you?
+
+
+.. _infinity-datetime:
+
+No default infinity dates handling
+----------------------------------
+
+PostgreSQL can represent a much wider range of dates and timestamps than
+Python. While Python dates are limited to the years between 1 and 9999
+(represented by constants such as `datetime.date.min` and
+`~datetime.date.max`), PostgreSQL dates extend to BC dates and past the year
+10K. Furthermore PostgreSQL can also represent symbolic dates "infinity", in
+both directions.
+
+In psycopg2, by default, `infinity dates and timestamps map to 'date.max'`__
+and similar constants. This has the problem of creating a non-bijective
+mapping (two Postgres dates, infinity and 9999-12-31, both map to the same
+Python date). There is also the perversity that valid Postgres dates, greater
+than Python `!date.max` but arguably lesser than infinity, will still
+overflow.
+
+In Psycopg 3, every date greater than year 9999 will overflow, including
+infinity. If you would like to customize this mapping (for instance flattening
+every date past Y10K on `!date.max`) you can subclass and adapt the
+appropriate loaders: take a look at :ref:`this example
+<adapt-example-inf-date>` to see how.
+
+.. __: https://www.psycopg.org/docs/usage.html#infinite-dates-handling
+
+
+.. _whats-new:
+
+What's new in Psycopg 3
+-----------------------
+
+- :ref:`Asynchronous support <async>`
+- :ref:`Server-side parameters binding <server-side-binding>`
+- :ref:`Prepared statements <prepared-statements>`
+- :ref:`Binary communication <binary-data>`
+- :ref:`Python-based COPY support <copy>`
+- :ref:`Support for static typing <static-typing>`
+- :ref:`A redesigned connection pool <connection-pools>`
+- :ref:`Direct access to the libpq functionalities <psycopg.pq>`
diff --git a/docs/basic/index.rst b/docs/basic/index.rst
new file mode 100644
index 0000000..bf9e27d
--- /dev/null
+++ b/docs/basic/index.rst
@@ -0,0 +1,26 @@
+.. _basic:
+
+Getting started with Psycopg 3
+==============================
+
+This section of the documentation will explain :ref:`how to install Psycopg
+<installation>` and how to perform normal activities such as :ref:`querying
+the database <usage>` or :ref:`loading data using COPY <copy>`.
+
+.. important::
+
+ If you are familiar with psycopg2 please take a look at
+ :ref:`from-psycopg2` to see what is changed.
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+ install
+ usage
+ params
+ adapt
+ pgtypes
+ transactions
+ copy
+ from_pg2
diff --git a/docs/basic/install.rst b/docs/basic/install.rst
new file mode 100644
index 0000000..8e1dc6d
--- /dev/null
+++ b/docs/basic/install.rst
@@ -0,0 +1,172 @@
+.. _installation:
+
+Installation
+============
+
+In short, if you use a :ref:`supported system<supported-systems>`::
+
+ pip install --upgrade pip # upgrade pip to at least 20.3
+ pip install "psycopg[binary]"
+
+and you should be :ref:`ready to start <module-usage>`. Read further for
+alternative ways to install.
+
+
+.. _supported-systems:
+
+Supported systems
+-----------------
+
+The Psycopg version documented here has *official and tested* support for:
+
+- Python: from version 3.7 to 3.11
+
+ - Python 3.6 supported before Psycopg 3.1
+
+- PostgreSQL: from version 10 to 15
+- OS: Linux, macOS, Windows
+
+The tests to verify the supported systems run in `Github workflows`__:
+anything that is not tested there is not officially supported. This includes:
+
+.. __: https://github.com/psycopg/psycopg/actions
+
+- Unofficial Python distributions such as Conda;
+- Alternative PostgreSQL implementation;
+- macOS hardware and releases not available on Github workflows.
+
+If you use an unsupported system, things might work (because, for instance, the
+database may use the same wire protocol as PostgreSQL) but we cannot guarantee
+the correct working or a smooth ride.
+
+
+.. _binary-install:
+
+Binary installation
+-------------------
+
+The quickest way to start developing with Psycopg 3 is to install the binary
+packages by running::
+
+ pip install "psycopg[binary]"
+
+This will install a self-contained package with all the libraries needed.
+**You will need pip 20.3 at least**: please run ``pip install --upgrade pip``
+to update it beforehand.
+
+The above package should work in most situations. It **will not work** in
+some cases though.
+
+If your platform is not supported you should proceed to a :ref:`local
+installation <local-installation>` or a :ref:`pure Python installation
+<pure-python-installation>`.
+
+.. seealso::
+
+ Did Psycopg 3 install ok? Great! You can now move on to the :ref:`basic
+ module usage <module-usage>` to learn how it works.
+
+ Keep on reading if the above method didn't work and you need a different
+ way to install Psycopg 3.
+
+ For further information about the differences between the packages see
+ :ref:`pq-impl`.
+
+
+.. _local-installation:
+
+Local installation
+------------------
+
+A "Local installation" results in a performing and maintainable library. The
+library will include the speed-up C module and will be linked to the system
+libraries (``libpq``, ``libssl``...) so that system upgrade of libraries will
+upgrade the libraries used by Psycopg 3 too. This is the preferred way to
+install Psycopg for a production site.
+
+In order to perform a local installation you need some prerequisites:
+
+- a C compiler,
+- Python development headers (e.g. the ``python3-dev`` package).
+- PostgreSQL client development headers (e.g. the ``libpq-dev`` package).
+- The :program:`pg_config` program available in the :envvar:`PATH`.
+
+You **must be able** to troubleshoot an extension build, for instance you must
+be able to read your compiler's error message. If you are not, please don't
+try this and follow the `binary installation`_ instead.
+
+If your build prerequisites are in place you can run::
+
+ pip install "psycopg[c]"
+
+
+.. _pure-python-installation:
+
+Pure Python installation
+------------------------
+
+If you simply install::
+
+ pip install psycopg
+
+without ``[c]`` or ``[binary]`` extras you will obtain a pure Python
+implementation. This is particularly handy to debug and hack, but it still
+requires the system libpq to operate (which will be imported dynamically via
+`ctypes`).
+
+In order to use the pure Python installation you will need the ``libpq``
+installed in the system: for instance on Debian system you will probably
+need::
+
+ sudo apt install libpq5
+
+.. note::
+
+ The ``libpq`` is the client library used by :program:`psql`, the
+ PostgreSQL command line client, to connect to the database. On most
+ systems, installing :program:`psql` will install the ``libpq`` too as a
+ dependency.
+
+If you are not able to fulfill this requirement please follow the `binary
+installation`_.
+
+
+.. _pool-installation:
+
+Installing the connection pool
+------------------------------
+
+The :ref:`Psycopg connection pools <connection-pools>` are distributed in a
+separate package from the `!psycopg` package itself, in order to allow a
+different release cycle.
+
+In order to use the pool you must install the ``pool`` extra, using ``pip
+install "psycopg[pool]"``, or install the `psycopg_pool` package separately,
+which would allow to specify the release to install more precisely.
+
+
+Handling dependencies
+---------------------
+
+If you need to specify your project dependencies (for instance in a
+``requirements.txt`` file, ``setup.py``, ``pyproject.toml`` dependencies...)
+you should probably specify one of the following:
+
+- If your project is a library, add a dependency on ``psycopg``. This will
+ make sure that your library will have the ``psycopg`` package with the right
+ interface and leaves the possibility of choosing a specific implementation
+ to the end user of your library.
+
+- If your project is a final application (e.g. a service running on a server)
+ you can require a specific implementation, for instance ``psycopg[c]``,
+ after you have made sure that the prerequisites are met (e.g. the depending
+ libraries and tools are installed in the host machine).
+
+In both cases you can specify which version of Psycopg to use using
+`requirement specifiers`__.
+
+.. __: https://pip.pypa.io/en/stable/cli/pip_install/#requirement-specifiers
+
+If you want to make sure that a specific implementation is used you can
+specify the :envvar:`PSYCOPG_IMPL` environment variable: importing the library
+will fail if the implementation specified is not available. See :ref:`pq-impl`.
diff --git a/docs/basic/params.rst b/docs/basic/params.rst
new file mode 100644
index 0000000..a733f07
--- /dev/null
+++ b/docs/basic/params.rst
@@ -0,0 +1,242 @@
+.. currentmodule:: psycopg
+
+.. index::
+ pair: Query; Parameters
+
+.. _query-parameters:
+
+Passing parameters to SQL queries
+=================================
+
+Most of the times, writing a program you will have to mix bits of SQL
+statements with values provided by the rest of the program:
+
+.. code::
+
+ SELECT some, fields FROM some_table WHERE id = ...
+
+:sql:`id` equals what? Probably you will have a Python value you are looking
+for.
+
+
+`!execute()` arguments
+----------------------
+
+Passing parameters to a SQL statement happens in functions such as
+`Cursor.execute()` by using ``%s`` placeholders in the SQL statement, and
+passing a sequence of values as the second argument of the function. For
+example the Python function call:
+
+.. code:: python
+
+ cur.execute("""
+ INSERT INTO some_table (id, created_at, last_name)
+ VALUES (%s, %s, %s);
+ """,
+ (10, datetime.date(2020, 11, 18), "O'Reilly"))
+
+is *roughly* equivalent to the SQL command:
+
+.. code-block:: sql
+
+ INSERT INTO some_table (id, created_at, last_name)
+ VALUES (10, '2020-11-18', 'O''Reilly');
+
+Note that the parameters will not be really merged to the query: query and the
+parameters are sent to the server separately: see :ref:`server-side-binding`
+for details.
+
+Named arguments are supported too using :samp:`%({name})s` placeholders in the
+query and specifying the values into a mapping. Using named arguments allows
+to specify the values in any order and to repeat the same value in several
+places in the query::
+
+ cur.execute("""
+ INSERT INTO some_table (id, created_at, updated_at, last_name)
+ VALUES (%(id)s, %(created)s, %(created)s, %(name)s);
+ """,
+ {'id': 10, 'name': "O'Reilly", 'created': datetime.date(2020, 11, 18)})
+
+Using characters ``%``, ``(``, ``)`` in the argument names is not supported.
+
+When parameters are used, in order to include a literal ``%`` in the query you
+can use the ``%%`` string::
+
+ cur.execute("SELECT (%s % 2) = 0 AS even", (10,)) # WRONG
+ cur.execute("SELECT (%s %% 2) = 0 AS even", (10,)) # correct
+
+While the mechanism resembles regular Python strings manipulation, there are a
+few subtle differences you should care about when passing parameters to a
+query.
+
+- The Python string operator ``%`` *must not be used*: the `~cursor.execute()`
+ method accepts a tuple or dictionary of values as second parameter.
+ |sql-warn|__:
+
+ .. |sql-warn| replace:: **Never** use ``%`` or ``+`` to merge values
+ into queries
+
+ .. code:: python
+
+ cur.execute("INSERT INTO numbers VALUES (%s, %s)" % (10, 20)) # WRONG
+ cur.execute("INSERT INTO numbers VALUES (%s, %s)", (10, 20)) # correct
+
+ .. __: sql-injection_
+
+- For positional variables binding, *the second argument must always be a
+ sequence*, even if it contains a single variable (remember that Python
+ requires a comma to create a single element tuple)::
+
+ cur.execute("INSERT INTO foo VALUES (%s)", "bar") # WRONG
+ cur.execute("INSERT INTO foo VALUES (%s)", ("bar")) # WRONG
+ cur.execute("INSERT INTO foo VALUES (%s)", ("bar",)) # correct
+ cur.execute("INSERT INTO foo VALUES (%s)", ["bar"]) # correct
+
+- The placeholder *must not be quoted*::
+
+ cur.execute("INSERT INTO numbers VALUES ('%s')", ("Hello",)) # WRONG
+ cur.execute("INSERT INTO numbers VALUES (%s)", ("Hello",)) # correct
+
+- The variables placeholder *must always be a* ``%s``, even if a different
+ placeholder (such as a ``%d`` for integers or ``%f`` for floats) may look
+ more appropriate for the type. You may find other placeholders used in
+ Psycopg queries (``%b`` and ``%t``) but they are not related to the
+ type of the argument: see :ref:`binary-data` if you want to read more::
+
+ cur.execute("INSERT INTO numbers VALUES (%d)", (10,)) # WRONG
+ cur.execute("INSERT INTO numbers VALUES (%s)", (10,)) # correct
+
+- Only query values should be bound via this method: it shouldn't be used to
+ merge table or field names to the query. If you need to generate SQL queries
+ dynamically (for instance choosing a table name at runtime) you can use the
+ functionalities provided in the `psycopg.sql` module::
+
+ cur.execute("INSERT INTO %s VALUES (%s)", ('numbers', 10)) # WRONG
+ cur.execute( # correct
+ SQL("INSERT INTO {} VALUES (%s)").format(Identifier('numbers')),
+ (10,))
+
+
+.. index:: Security, SQL injection
+
+.. _sql-injection:
+
+Danger: SQL injection
+---------------------
+
+The SQL representation of many data types is often different from their Python
+string representation. The typical example is with single quotes in strings:
+in SQL single quotes are used as string literal delimiters, so the ones
+appearing inside the string itself must be escaped, whereas in Python single
+quotes can be left unescaped if the string is delimited by double quotes.
+
+Because of the difference, sometimes subtle, between the data types
+representations, a naïve approach to query strings composition, such as using
+Python strings concatenation, is a recipe for *terrible* problems::
+
+ SQL = "INSERT INTO authors (name) VALUES ('%s')" # NEVER DO THIS
+ data = ("O'Reilly", )
+ cur.execute(SQL % data) # THIS WILL FAIL MISERABLY
+ # SyntaxError: syntax error at or near "Reilly"
+
+If the variables containing the data to send to the database come from an
+untrusted source (such as data coming from a form on a web site) an attacker
+could easily craft a malformed string, either gaining access to unauthorized
+data or performing destructive operations on the database. This form of attack
+is called `SQL injection`_ and is known to be one of the most widespread forms
+of attack on database systems. Before continuing, please print `this page`__
+as a memo and hang it onto your desk.
+
+.. _SQL injection: https://en.wikipedia.org/wiki/SQL_injection
+.. __: https://xkcd.com/327/
+
+Psycopg can :ref:`automatically convert Python objects to SQL
+values<types-adaptation>`: using this feature your code will be more robust
+and reliable. We must stress this point:
+
+.. warning::
+
+ - Don't manually merge values to a query: hackers from a foreign country
+ will break into your computer and steal not only your disks, but also
+ your cds, leaving you only with the three most embarrassing records you
+ ever bought. On cassette tapes.
+
+ - If you use the ``%`` operator to merge values to a query, con artists
+ will seduce your cat, who will run away taking your credit card
+ and your sunglasses with them.
+
+ - If you use ``+`` to merge a textual value to a string, bad guys in
+ balaclava will find their way to your fridge, drink all your beer, and
+ leave your toilet seat up and your toilet paper in the wrong orientation.
+
+ - You don't want to manually merge values to a query: :ref:`use the
+ provided methods <query-parameters>` instead.
+
+The correct way to pass variables in a SQL command is using the second
+argument of the `Cursor.execute()` method::
+
+ SQL = "INSERT INTO authors (name) VALUES (%s)" # Note: no quotes
+ data = ("O'Reilly", )
+ cur.execute(SQL, data) # Note: no % operator
+
+.. note::
+
+ Python static code checkers are not quite there yet, but, in the future,
+ it will be possible to check your code for improper use of string
+ expressions in queries. See :ref:`literal-string` for details.
+
+.. seealso::
+
+ Now that you know how to pass parameters to queries, you can take a look
+ at :ref:`how Psycopg converts data types <types-adaptation>`.
+
+
+.. index::
+ pair: Binary; Parameters
+
+.. _binary-data:
+
+Binary parameters and results
+-----------------------------
+
+PostgreSQL has two different ways to transmit data between client and server:
+`~psycopg.pq.Format.TEXT`, always available, and `~psycopg.pq.Format.BINARY`,
+available most of the times but not always. Usually the binary format is more
+efficient to use.
+
+Psycopg can support both formats for each data type. Whenever a value
+is passed to a query using the normal ``%s`` placeholder, the best format
+available is chosen (often, but not always, the binary format is picked as the
+best choice).
+
+If you have a reason to select explicitly the binary format or the text format
+for a value you can use respectively a ``%b`` placeholder or a ``%t``
+placeholder instead of the normal ``%s``. `~Cursor.execute()` will fail if a
+`~psycopg.adapt.Dumper` for the right data type and format is not available.
+
+The same two formats, text or binary, are used by PostgreSQL to return data
+from a query to the client. Unlike with parameters, where you can choose the
+format value-by-value, all the columns returned by a query will have the same
+format. Every type returned by the query should have a `~psycopg.adapt.Loader`
+configured, otherwise the data will be returned as unparsed `!str` (for text
+results) or buffer (for binary results).
+
+.. note::
+ The `pg_type`_ table defines which format is supported for each PostgreSQL
+ data type. Text input/output is managed by the functions declared in the
+ ``typinput`` and ``typoutput`` fields (always present), binary
+ input/output is managed by the ``typsend`` and ``typreceive`` (which are
+ optional).
+
+ .. _pg_type: https://www.postgresql.org/docs/current/catalog-pg-type.html
+
+Because not every PostgreSQL type supports binary output, by default, the data
+will be returned in text format. In order to return data in binary format you
+can create the cursor using `Connection.cursor`\ `!(binary=True)` or execute
+the query using `Cursor.execute`\ `!(binary=True)`. A case in which
+requesting binary results is a clear winner is when you have large binary data
+in the database, such as images::
+
+ cur.execute(
+ "SELECT image_data FROM images WHERE id = %s", [image_id], binary=True)
+ data = cur.fetchone()[0]
diff --git a/docs/basic/pgtypes.rst b/docs/basic/pgtypes.rst
new file mode 100644
index 0000000..14ee5be
--- /dev/null
+++ b/docs/basic/pgtypes.rst
@@ -0,0 +1,389 @@
+.. currentmodule:: psycopg
+
+.. index::
+ single: Adaptation
+ pair: Objects; Adaptation
+ single: Data types; Adaptation
+
+.. _extra-adaptation:
+
+Adapting other PostgreSQL types
+===============================
+
+PostgreSQL offers other data types which don't map to native Python types.
+Psycopg offers wrappers and conversion functions to allow their use.
+
+
+.. index::
+ pair: Composite types; Data types
+ pair: tuple; Adaptation
+ pair: namedtuple; Adaptation
+
+.. _adapt-composite:
+
+Composite types casting
+-----------------------
+
+Psycopg can adapt PostgreSQL composite types (either created with the |CREATE
+TYPE|_ command or implicitly defined after a table row type) to and from
+Python tuples, `~collections.namedtuple`, or any other suitable object
+configured.
+
+.. |CREATE TYPE| replace:: :sql:`CREATE TYPE`
+.. _CREATE TYPE: https://www.postgresql.org/docs/current/static/sql-createtype.html
+
+Before using a composite type it is necessary to get information about it
+using the `~psycopg.types.composite.CompositeInfo` class and to register it
+using `~psycopg.types.composite.register_composite()`.
+
+.. autoclass:: psycopg.types.composite.CompositeInfo
+
+ `!CompositeInfo` is a `~psycopg.types.TypeInfo` subclass: check its
+ documentation for the generic usage, especially the
+ `~psycopg.types.TypeInfo.fetch()` method.
+
+ .. attribute:: python_type
+
+ After `register_composite()` is called, it will contain the python type
+ mapping to the registered composite.
+
+.. autofunction:: psycopg.types.composite.register_composite
+
+ After registering, fetching data of the registered composite will invoke
+ `!factory` to create corresponding Python objects.
+
+ If no factory is specified, a `~collection.namedtuple` is created and used
+ to return data.
+
+ If the `!factory` is a type (and not a generic callable), then dumpers for
+ that type are created and registered too, so that passing objects of that
+ type to a query will adapt them to the registered type.
+
+Example::
+
+ >>> from psycopg.types.composite import CompositeInfo, register_composite
+
+ >>> conn.execute("CREATE TYPE card AS (value int, suit text)")
+
+ >>> info = CompositeInfo.fetch(conn, "card")
+ >>> register_composite(info, conn)
+
+ >>> my_card = info.python_type(8, "hearts")
+ >>> my_card
+ card(value=8, suit='hearts')
+
+ >>> conn.execute(
+ ... "SELECT pg_typeof(%(card)s), (%(card)s).suit", {"card": my_card}
+ ... ).fetchone()
+ ('card', 'hearts')
+
+ >>> conn.execute("SELECT (%s, %s)::card", [1, "spades"]).fetchone()[0]
+ card(value=1, suit='spades')
+
+
+Nested composite types are handled as expected, provided that the type of the
+composite components are registered as well::
+
+ >>> conn.execute("CREATE TYPE card_back AS (face card, back text)")
+
+ >>> info2 = CompositeInfo.fetch(conn, "card_back")
+ >>> register_composite(info2, conn)
+
+ >>> conn.execute("SELECT ((8, 'hearts'), 'blue')::card_back").fetchone()[0]
+ card_back(face=card(value=8, suit='hearts'), back='blue')
+
+
+.. index::
+ pair: range; Data types
+
+.. _adapt-range:
+
+Range adaptation
+----------------
+
+PostgreSQL `range types`__ are a family of data types representing a range of
+values between two elements. The type of the element is called the range
+*subtype*. PostgreSQL offers a few built-in range types and allows the
+definition of custom ones.
+
+.. __: https://www.postgresql.org/docs/current/rangetypes.html
+
+All the PostgreSQL range types are loaded as the `~psycopg.types.range.Range`
+Python type, which is a `~typing.Generic` type and can hold bounds of
+different types.
+
+.. autoclass:: psycopg.types.range.Range
+
+ This Python type is only used to pass and retrieve range values to and
+ from PostgreSQL and doesn't attempt to replicate the PostgreSQL range
+ features: it doesn't perform normalization and doesn't implement all the
+ operators__ supported by the database.
+
+ PostgreSQL will perform normalisation on `!Range` objects used as query
+ parameters, so, when they are fetched back, they will be found in the
+ normal form (for instance ranges on integers will have `[)` bounds).
+
+ .. __: https://www.postgresql.org/docs/current/static/functions-range.html#RANGE-OPERATORS-TABLE
+
+ `!Range` objects are immutable, hashable, and support the `!in` operator
+ (checking if an element is within the range). They can be tested for
+ equivalence. Empty ranges evaluate to `!False` in a boolean context,
+ nonempty ones evaluate to `!True`.
+
+ `!Range` objects have the following attributes:
+
+ .. autoattribute:: isempty
+ .. autoattribute:: lower
+ .. autoattribute:: upper
+ .. autoattribute:: lower_inc
+ .. autoattribute:: upper_inc
+ .. autoattribute:: lower_inf
+ .. autoattribute:: upper_inf
+
+The built-in range objects are adapted automatically: if a `!Range` objects
+contains `~datetime.date` bounds, it is dumped using the :sql:`daterange` OID,
+and of course :sql:`daterange` values are loaded back as `!Range[date]`.
+
+If you create your own range type you can use `~psycopg.types.range.RangeInfo`
+and `~psycopg.types.range.register_range()` to associate the range type with
+its subtype and make it work like the builtin ones.
+
+.. autoclass:: psycopg.types.range.RangeInfo
+
+ `!RangeInfo` is a `~psycopg.types.TypeInfo` subclass: check its
+ documentation for generic details, especially the
+ `~psycopg.types.TypeInfo.fetch()` method.
+
+.. autofunction:: psycopg.types.range.register_range
+
+Example::
+
+ >>> from psycopg.types.range import Range, RangeInfo, register_range
+
+ >>> conn.execute("CREATE TYPE strrange AS RANGE (SUBTYPE = text)")
+ >>> info = RangeInfo.fetch(conn, "strrange")
+ >>> register_range(info, conn)
+
+ >>> conn.execute("SELECT pg_typeof(%s)", [Range("a", "z")]).fetchone()[0]
+ 'strrange'
+
+ >>> conn.execute("SELECT '[a,z]'::strrange").fetchone()[0]
+ Range('a', 'z', '[]')
+
+
+.. index::
+ pair: range; Data types
+
+.. _adapt-multirange:
+
+Multirange adaptation
+---------------------
+
+Since PostgreSQL 14, every range type is associated with a multirange__, a
+type representing a disjoint set of ranges. A multirange is
+automatically available for every range, built-in and user-defined.
+
+.. __: https://www.postgresql.org/docs/current/rangetypes.html
+
+All the PostgreSQL range types are loaded as the
+`~psycopg.types.multirange.Multirange` Python type, which is a mutable
+sequence of `~psycopg.types.range.Range` elements.
+
+.. autoclass:: psycopg.types.multirange.Multirange
+
+ This Python type is only used to pass and retrieve multirange values to
+ and from PostgreSQL and doesn't attempt to replicate the PostgreSQL
+ multirange features: overlapping items are not merged, empty ranges are
+ not discarded, the items are not ordered, the behaviour of `multirange
+ operators`__ is not replicated in Python.
+
+ PostgreSQL will perform normalisation on `!Multirange` objects used as
+ query parameters, so, when they are fetched back, they will be found
+ ordered, with overlapping ranges merged, etc.
+
+ .. __: https://www.postgresql.org/docs/current/static/functions-range.html#MULTIRANGE-OPERATORS-TABLE
+
+ `!Multirange` objects are a `~collections.abc.MutableSequence` and are
+ totally ordered: they behave pretty much like a list of `!Range`. Like
+ Range, they are `~typing.Generic` on the subtype of their range, so you
+ can declare a variable to be `!Multirange[date]` and mypy will complain if
+ you try to add it a `Range[Decimal]`.
+
+Like for `~psycopg.types.range.Range`, built-in multirange objects are adapted
+automatically: if a `!Multirange` object contains `!Range` with
+`~datetime.date` bounds, it is dumped using the :sql:`datemultirange` OID, and
+:sql:`datemultirange` values are loaded back as `!Multirange[date]`.
+
+If you have created your own range type you can use
+`~psycopg.types.multirange.MultirangeInfo` and
+`~psycopg.types.multirange.register_multirange()` to associate the resulting
+multirange type with its subtype and make it work like the builtin ones.
+
+.. autoclass:: psycopg.types.multirange.MultirangeInfo
+
+ `!MultirangeInfo` is a `~psycopg.types.TypeInfo` subclass: check its
+ documentation for generic details, especially the
+ `~psycopg.types.TypeInfo.fetch()` method.
+
+.. autofunction:: psycopg.types.multirange.register_multirange
+
+Example::
+
+ >>> from psycopg.types.multirange import \
+ ... Multirange, MultirangeInfo, register_multirange
+ >>> from psycopg.types.range import Range
+
+ >>> conn.execute("CREATE TYPE strrange AS RANGE (SUBTYPE = text)")
+ >>> info = MultirangeInfo.fetch(conn, "strmultirange")
+ >>> register_multirange(info, conn)
+
+ >>> rec = conn.execute(
+ ... "SELECT pg_typeof(%(mr)s), %(mr)s",
+ ... {"mr": Multirange([Range("a", "q"), Range("l", "z")])}).fetchone()
+
+ >>> rec[0]
+ 'strmultirange'
+ >>> rec[1]
+ Multirange([Range('a', 'z', '[)')])
+
+
+.. index::
+ pair: hstore; Data types
+ pair: dict; Adaptation
+
+.. _adapt-hstore:
+
+Hstore adaptation
+-----------------
+
+The |hstore|_ data type is a key-value store embedded in PostgreSQL. It
+supports GiST or GIN indexes allowing search by keys or key/value pairs as
+well as regular BTree indexes for equality, uniqueness etc.
+
+.. |hstore| replace:: :sql:`hstore`
+.. _hstore: https://www.postgresql.org/docs/current/static/hstore.html
+
+Psycopg can convert Python `!dict` objects to and from |hstore| structures.
+Only dictionaries with string keys and values are supported. `!None` is also
+allowed as value but not as a key.
+
+In order to use the |hstore| data type it is necessary to load it in a
+database using:
+
+.. code:: none
+
+ =# CREATE EXTENSION hstore;
+
+Because |hstore| is distributed as a contrib module, its oid is not well
+known, so it is necessary to use `!TypeInfo`\.\
+`~psycopg.types.TypeInfo.fetch()` to query the database and get its oid. The
+resulting object can be passed to
+`~psycopg.types.hstore.register_hstore()` to configure dumping `!dict` to
+|hstore| and parsing |hstore| back to `!dict`, in the context where the
+adapter is registered.
+
+.. autofunction:: psycopg.types.hstore.register_hstore
+
+Example::
+
+ >>> from psycopg.types import TypeInfo
+ >>> from psycopg.types.hstore import register_hstore
+
+ >>> info = TypeInfo.fetch(conn, "hstore")
+ >>> register_hstore(info, conn)
+
+ >>> conn.execute("SELECT pg_typeof(%s)", [{"a": "b"}]).fetchone()[0]
+ 'hstore'
+
+ >>> conn.execute("SELECT 'foo => bar'::hstore").fetchone()[0]
+ {'foo': 'bar'}
+
+
+.. index::
+ pair: geometry; Data types
+ single: PostGIS; Data types
+
+.. _adapt-shapely:
+
+Geometry adaptation using Shapely
+---------------------------------
+
+When using the PostGIS_ extension, it can be useful to retrieve geometry_
+values and have them automatically converted to Shapely_ instances. Likewise,
+you may want to store such instances in the database and have the conversion
+happen automatically.
+
+.. warning::
+ Psycopg doesn't have a dependency on the ``shapely`` package: you should
+ install the library as an additional dependency of your project.
+
+.. warning::
+ This module is experimental and might be changed in the future according
+ to users' feedback.
+
+.. _PostGIS: https://postgis.net/
+.. _geometry: https://postgis.net/docs/geometry.html
+.. _Shapely: https://github.com/Toblerity/Shapely
+.. _shape: https://shapely.readthedocs.io/en/stable/manual.html#shapely.geometry.shape
+
+Since PostgGIS is an extension, the :sql:`geometry` type oid is not well
+known, so it is necessary to use `!TypeInfo`\.\
+`~psycopg.types.TypeInfo.fetch()` to query the database and find it. The
+resulting object can be passed to `~psycopg.types.shapely.register_shapely()`
+to configure dumping `shape`_ instances to :sql:`geometry` columns and parsing
+:sql:`geometry` data back to `!shape` instances, in the context where the
+adapters are registered.
+
+.. function:: psycopg.types.shapely.register_shapely
+
+ Register Shapely dumper and loaders.
+
+ After invoking this function on an adapter, the queries retrieving
+ PostGIS geometry objects will return Shapely's shape object instances
+ both in text and binary mode.
+
+ Similarly, shape objects can be sent to the database.
+
+ This requires the Shapely library to be installed.
+
+ :param info: The object with the information about the geometry type.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+
+Example::
+
+ >>> from psycopg.types import TypeInfo
+ >>> from psycopg.types.shapely import register_shapely
+ >>> from shapely.geometry import Point
+
+ >>> info = TypeInfo.fetch(conn, "geometry")
+ >>> register_shapely(info, conn)
+
+ >>> conn.execute("SELECT pg_typeof(%s)", [Point(1.2, 3.4)]).fetchone()[0]
+ 'geometry'
+
+ >>> conn.execute("""
+ ... SELECT ST_GeomFromGeoJSON('{
+ ... "type":"Point",
+ ... "coordinates":[-48.23456,20.12345]}')
+ ... """).fetchone()[0]
+ <shapely.geometry.multipolygon.MultiPolygon object at 0x7fb131f3cd90>
+
+Notice that, if the geometry adapters are registered on a specific object (a
+connection or cursor), other connections and cursors will be unaffected::
+
+ >>> conn2 = psycopg.connect(CONN_STR)
+ >>> conn2.execute("""
+ ... SELECT ST_GeomFromGeoJSON('{
+ ... "type":"Point",
+ ... "coordinates":[-48.23456,20.12345]}')
+ ... """).fetchone()[0]
+ '0101000020E61000009279E40F061E48C0F2B0506B9A1F3440'
+
diff --git a/docs/basic/transactions.rst b/docs/basic/transactions.rst
new file mode 100644
index 0000000..b976046
--- /dev/null
+++ b/docs/basic/transactions.rst
@@ -0,0 +1,388 @@
+.. currentmodule:: psycopg
+
+.. index:: Transactions management
+.. index:: InFailedSqlTransaction
+.. index:: idle in transaction
+
+.. _transactions:
+
+Transactions management
+=======================
+
+Psycopg has a behaviour that may seem surprising compared to
+:program:`psql`: by default, any database operation will start a new
+transaction. As a consequence, changes made by any cursor of the connection
+will not be visible until `Connection.commit()` is called, and will be
+discarded by `Connection.rollback()`. The following operation on the same
+connection will start a new transaction.
+
+If a database operation fails, the server will refuse further commands, until
+a `~rollback()` is called.
+
+If the cursor is closed with a transaction open, no COMMIT command is sent to
+the server, which will then discard the connection. Certain middleware (such
+as PgBouncer) will also discard a connection left in transaction state, so, if
+possible you will want to commit or rollback a connection before finishing
+working with it.
+
+An example of what will happen, the first time you will use Psycopg (and to be
+disappointed by it), is likely:
+
+.. code:: python
+
+ conn = psycopg.connect()
+
+ # Creating a cursor doesn't start a transaction or affect the connection
+ # in any way.
+ cur = conn.cursor()
+
+ cur.execute("SELECT count(*) FROM my_table")
+ # This function call executes:
+ # - BEGIN
+ # - SELECT count(*) FROM my_table
+ # So now a transaction has started.
+
+ # If your program spends a long time in this state, the server will keep
+ # a connection "idle in transaction", which is likely something undesired
+
+ cur.execute("INSERT INTO data VALUES (%s)", ("Hello",))
+ # This statement is executed inside the transaction
+
+ conn.close()
+ # No COMMIT was sent: the INSERT was discarded.
+
+There are a few things going wrong here, let's see how they can be improved.
+
+One obvious problem after the run above is that, firing up :program:`psql`,
+you will see no new record in the table ``data``. One way to fix the problem
+is to call `!conn.commit()` before closing the connection. Thankfully, if you
+use the :ref:`connection context <with-connection>`, Psycopg will commit the
+connection at the end of the block (or roll it back if the block is exited
+with an exception):
+
+The code modified using a connection context will result in the following
+sequence of database statements:
+
+.. code-block:: python
+ :emphasize-lines: 1
+
+ with psycopg.connect() as conn:
+
+ cur = conn.cursor()
+
+ cur.execute("SELECT count(*) FROM my_table")
+ # This function call executes:
+ # - BEGIN
+ # - SELECT count(*) FROM my_table
+ # So now a transaction has started.
+
+ cur.execute("INSERT INTO data VALUES (%s)", ("Hello",))
+ # This statement is executed inside the transaction
+
+ # No exception at the end of the block:
+ # COMMIT is executed.
+
+This way we don't have to remember to call neither `!close()` nor `!commit()`
+and the database operations actually have a persistent effect. The code might
+still do something you don't expect: keep a transaction from the first
+operation to the connection closure. You can have a finer control over the
+transactions using an :ref:`autocommit transaction <autocommit>` and/or
+:ref:`transaction contexts <transaction-context>`.
+
+.. warning::
+
+ By default even a simple :sql:`SELECT` will start a transaction: in
+ long-running programs, if no further action is taken, the session will
+ remain *idle in transaction*, an undesirable condition for several
+ reasons (locks are held by the session, tables bloat...). For long lived
+ scripts, either make sure to terminate a transaction as soon as possible or
+ use an `~Connection.autocommit` connection.
+
+.. hint::
+
+ If a database operation fails with an error message such as
+ *InFailedSqlTransaction: current transaction is aborted, commands ignored
+ until end of transaction block*, it means that **a previous operation
+ failed** and the database session is in a state of error. You need to call
+ `~Connection.rollback()` if you want to keep on using the same connection.
+
+
+.. _autocommit:
+
+Autocommit transactions
+-----------------------
+
+The manual commit requirement can be suspended using `~Connection.autocommit`,
+either as connection attribute or as `~psycopg.Connection.connect()`
+parameter. This may be required to run operations that cannot be executed
+inside a transaction, such as :sql:`CREATE DATABASE`, :sql:`VACUUM`,
+:sql:`CALL` on `stored procedures`__ using transaction control.
+
+.. __: https://www.postgresql.org/docs/current/xproc.html
+
+With an autocommit transaction, the above sequence of operation results in:
+
+.. code-block:: python
+ :emphasize-lines: 1
+
+ with psycopg.connect(autocommit=True) as conn:
+
+ cur = conn.cursor()
+
+ cur.execute("SELECT count(*) FROM my_table")
+ # This function call now only executes:
+ # - SELECT count(*) FROM my_table
+ # and no transaction starts.
+
+ cur.execute("INSERT INTO data VALUES (%s)", ("Hello",))
+ # The result of this statement is persisted immediately by the database
+
+ # The connection is closed at the end of the block but, because it is not
+ # in a transaction state, no COMMIT is executed.
+
+An autocommit transaction behaves more as someone coming from :program:`psql`
+would expect. This has a beneficial performance effect, because less queries
+are sent and less operations are performed by the database. The statements,
+however, are not executed in an atomic transaction; if you need to execute
+certain operations inside a transaction, you can achieve that with an
+autocommit connection too, using an explicit :ref:`transaction block
+<transaction-context>`.
+
+
+.. _transaction-context:
+
+Transaction contexts
+--------------------
+
+A more transparent way to make sure that transactions are finalised at the
+right time is to use `!with` `Connection.transaction()` to create a
+transaction context. When the context is entered, a transaction is started;
+when leaving the context the transaction is committed, or it is rolled back if
+an exception is raised inside the block.
+
+Continuing the example above, if you want to use an autocommit connection but
+still wrap selected groups of commands inside an atomic transaction, you can
+use a `!transaction()` context:
+
+.. code-block:: python
+ :emphasize-lines: 8
+
+ with psycopg.connect(autocommit=True) as conn:
+
+ cur = conn.cursor()
+
+ cur.execute("SELECT count(*) FROM my_table")
+ # The connection is autocommit, so no BEGIN executed.
+
+ with conn.transaction():
+ # BEGIN is executed, a transaction started
+
+ cur.execute("INSERT INTO data VALUES (%s)", ("Hello",))
+ cur.execute("INSERT INTO times VALUES (now())")
+ # These two operation run atomically in the same transaction
+
+ # COMMIT is executed at the end of the block.
+ # The connection is in idle state again.
+
+ # The connection is closed at the end of the block.
+
+
+Note that connection blocks can also be used with non-autocommit connections:
+in this case you still need to pay attention to eventual transactions started
+automatically. If an operation starts an implicit transaction, a
+`!transaction()` block will only manage :ref:`a savepoint sub-transaction
+<nested-transactions>`, leaving the caller to deal with the main transaction,
+as explained in :ref:`transactions`:
+
+.. code:: python
+
+ conn = psycopg.connect()
+
+ cur = conn.cursor()
+
+ cur.execute("SELECT count(*) FROM my_table")
+ # This function call executes:
+ # - BEGIN
+ # - SELECT count(*) FROM my_table
+ # So now a transaction has started.
+
+ with conn.transaction():
+ # The block starts with a transaction already open, so it will execute
+ # - SAVEPOINT
+
+ cur.execute("INSERT INTO data VALUES (%s)", ("Hello",))
+
+ # The block was executing a sub-transaction so on exit it will only run:
+ # - RELEASE SAVEPOINT
+ # The transaction is still on.
+
+ conn.close()
+ # No COMMIT was sent: the INSERT was discarded.
+
+If a `!transaction()` block starts when no transaction is active then it will
+manage a proper transaction. In essence, a transaction context tries to leave
+a connection in the state it found it, and leaves you to deal with the wider
+context.
+
+.. hint::
+ The interaction between non-autocommit transactions and transaction
+ contexts is probably surprising. Although the non-autocommit default is
+ what's demanded by the DBAPI, the personal preference of several experienced
+ developers is to:
+
+ - use a connection block: ``with psycopg.connect(...) as conn``;
+ - use an autocommit connection, either passing `!autocommit=True` as
+ `!connect()` parameter or setting the attribute ``conn.autocommit =
+ True``;
+ - use `!with conn.transaction()` blocks to manage transactions only where
+ needed.
+
+
+.. _nested-transactions:
+
+Nested transactions
+^^^^^^^^^^^^^^^^^^^
+
+Transaction blocks can be also nested (internal transaction blocks are
+implemented using SAVEPOINT__): an exception raised inside an inner block
+has a chance of being handled and not completely fail outer operations. The
+following is an example where a series of operations interact with the
+database: operations are allowed to fail; at the end we also want to store the
+number of operations successfully processed.
+
+.. __: https://www.postgresql.org/docs/current/sql-savepoint.html
+
+.. code:: python
+
+ with conn.transaction() as tx1:
+ num_ok = 0
+ for operation in operations:
+ try:
+ with conn.transaction() as tx2:
+ unreliable_operation(conn, operation)
+ except Exception:
+ logger.exception(f"{operation} failed")
+ else:
+ num_ok += 1
+
+ save_number_of_successes(conn, num_ok)
+
+If `!unreliable_operation()` causes an error, including an operation causing a
+database error, all its changes will be reverted. The exception bubbles up
+outside the block: in the example it is intercepted by the `!try` so that the
+loop can complete. The outermost block is unaffected (unless other errors
+happen there).
+
+You can also write code to explicitly roll back any currently active
+transaction block, by raising the `Rollback` exception. The exception "jumps"
+to the end of a transaction block, rolling back its transaction but allowing
+the program execution to continue from there. By default the exception rolls
+back the innermost transaction block, but any current block can be specified
+as the target. In the following example, a hypothetical `!CancelCommand`
+may stop the processing and cancel any operation previously performed,
+but not entirely committed yet.
+
+.. code:: python
+
+ from psycopg import Rollback
+
+ with conn.transaction() as outer_tx:
+ for command in commands():
+ with conn.transaction() as inner_tx:
+ if isinstance(command, CancelCommand):
+ raise Rollback(outer_tx)
+ process_command(command)
+
+ # If `Rollback` is raised, it would propagate only up to this block,
+ # and the program would continue from here with no exception.
+
+
+.. _transaction-characteristics:
+
+Transaction characteristics
+---------------------------
+
+You can set `transaction parameters`__ for the transactions that Psycopg
+handles. They affect the transactions started implicitly by non-autocommit
+transactions and the ones started explicitly by `Connection.transaction()` for
+both autocommit and non-autocommit transactions. Leaving these parameters as
+`!None` will use the server's default behaviour (which is controlled
+by server settings such as default_transaction_isolation__).
+
+.. __: https://www.postgresql.org/docs/current/sql-set-transaction.html
+.. __: https://www.postgresql.org/docs/current/runtime-config-client.html
+ #GUC-DEFAULT-TRANSACTION-ISOLATION
+
+In order to set these parameters you can use the connection attributes
+`~Connection.isolation_level`, `~Connection.read_only`,
+`~Connection.deferrable`. For async connections you must use the equivalent
+`~AsyncConnection.set_isolation_level()` method and similar. The parameters
+can only be changed if there isn't a transaction already active on the
+connection.
+
+.. warning::
+
+ Applications running at `~IsolationLevel.REPEATABLE_READ` or
+ `~IsolationLevel.SERIALIZABLE` isolation level are exposed to serialization
+ failures. `In certain concurrent update cases`__, PostgreSQL will raise an
+ exception looking like::
+
+ psycopg2.errors.SerializationFailure: could not serialize access
+ due to concurrent update
+
+ In this case the application must be prepared to repeat the operation that
+ caused the exception.
+
+ .. __: https://www.postgresql.org/docs/current/transaction-iso.html
+ #XACT-REPEATABLE-READ
+
+
+.. index::
+ pair: Two-phase commit; Transaction
+
+.. _two-phase-commit:
+
+Two-Phase Commit protocol support
+---------------------------------
+
+.. versionadded:: 3.1
+
+Psycopg exposes the two-phase commit features available in PostgreSQL
+implementing the `two-phase commit extensions`__ proposed by the DBAPI.
+
+The DBAPI model of two-phase commit is inspired by the `XA specification`__,
+according to which transaction IDs are formed from three components:
+
+- a format ID (non-negative 32 bit integer)
+- a global transaction ID (string not longer than 64 bytes)
+- a branch qualifier (string not longer than 64 bytes)
+
+For a particular global transaction, the first two components will be the same
+for all the resources. Every resource will be assigned a different branch
+qualifier.
+
+According to the DBAPI specification, a transaction ID is created using the
+`Connection.xid()` method. Once you have a transaction id, a distributed
+transaction can be started with `Connection.tpc_begin()`, prepared using
+`~Connection.tpc_prepare()` and completed using `~Connection.tpc_commit()` or
+`~Connection.tpc_rollback()`. Transaction IDs can also be retrieved from the
+database using `~Connection.tpc_recover()` and completed using the above
+`!tpc_commit()` and `!tpc_rollback()`.
+
+PostgreSQL doesn't follow the XA standard though, and the ID for a PostgreSQL
+prepared transaction can be any string up to 200 characters long. Psycopg's
+`Xid` objects can represent both XA-style transactions IDs (such as the ones
+created by the `!xid()` method) and PostgreSQL transaction IDs identified by
+an unparsed string.
+
+The format in which the Xids are converted into strings passed to the
+database is the same employed by the `PostgreSQL JDBC driver`__: this should
+allow interoperation between tools written in Python and in Java. For example
+a recovery tool written in Python would be able to recognize the components of
+transactions produced by a Java program.
+
+For further details see the documentation for the :ref:`tpc-methods`.
+
+.. __: https://www.python.org/dev/peps/pep-0249/#optional-two-phase-commit-extensions
+.. __: https://publications.opengroup.org/c193
+.. __: https://jdbc.postgresql.org/
diff --git a/docs/basic/usage.rst b/docs/basic/usage.rst
new file mode 100644
index 0000000..6c69fe8
--- /dev/null
+++ b/docs/basic/usage.rst
@@ -0,0 +1,232 @@
+.. currentmodule:: psycopg
+
+.. _module-usage:
+
+Basic module usage
+==================
+
+The basic Psycopg usage is common to all the database adapters implementing
+the `DB-API`__ protocol. Other database adapters, such as the builtin
+`sqlite3` or `psycopg2`, have roughly the same pattern of interaction.
+
+.. __: https://www.python.org/dev/peps/pep-0249/
+
+
+.. index::
+ pair: Example; Usage
+
+.. _usage:
+
+Main objects in Psycopg 3
+-------------------------
+
+Here is an interactive session showing some of the basic commands:
+
+.. code:: python
+
+ # Note: the module name is psycopg, not psycopg3
+ import psycopg
+
+ # Connect to an existing database
+ with psycopg.connect("dbname=test user=postgres") as conn:
+
+ # Open a cursor to perform database operations
+ with conn.cursor() as cur:
+
+ # Execute a command: this creates a new table
+ cur.execute("""
+ CREATE TABLE test (
+ id serial PRIMARY KEY,
+ num integer,
+ data text)
+ """)
+
+ # Pass data to fill a query placeholders and let Psycopg perform
+ # the correct conversion (no SQL injections!)
+ cur.execute(
+ "INSERT INTO test (num, data) VALUES (%s, %s)",
+ (100, "abc'def"))
+
+ # Query the database and obtain data as Python objects.
+ cur.execute("SELECT * FROM test")
+ cur.fetchone()
+ # will return (1, 100, "abc'def")
+
+ # You can use `cur.fetchmany()`, `cur.fetchall()` to return a list
+ # of several records, or even iterate on the cursor
+ for record in cur:
+ print(record)
+
+ # Make the changes to the database persistent
+ conn.commit()
+
+
+In the example you can see some of the main objects and methods and how they
+relate to each other:
+
+- The function `~Connection.connect()` creates a new database session and
+ returns a new `Connection` instance. `AsyncConnection.connect()`
+ creates an `asyncio` connection instead.
+
+- The `~Connection` class encapsulates a database session. It allows to:
+
+ - create new `~Cursor` instances using the `~Connection.cursor()` method to
+ execute database commands and queries,
+
+ - terminate transactions using the methods `~Connection.commit()` or
+ `~Connection.rollback()`.
+
+- The class `~Cursor` allows interaction with the database:
+
+ - send commands to the database using methods such as `~Cursor.execute()`
+ and `~Cursor.executemany()`,
+
+ - retrieve data from the database, iterating on the cursor or using methods
+ such as `~Cursor.fetchone()`, `~Cursor.fetchmany()`, `~Cursor.fetchall()`.
+
+- Using these objects as context managers (i.e. using `!with`) will make sure
+ to close them and free their resources at the end of the block (notice that
+ :ref:`this is different from psycopg2 <diff-with>`).
+
+
+.. seealso::
+
+ A few important topics you will have to deal with are:
+
+ - :ref:`query-parameters`.
+ - :ref:`types-adaptation`.
+ - :ref:`transactions`.
+
+
+Shortcuts
+---------
+
+The pattern above is familiar to `!psycopg2` users. However, Psycopg 3 also
+exposes a few simple extensions which make the above pattern leaner:
+
+- the `Connection` objects exposes an `~Connection.execute()` method,
+ equivalent to creating a cursor, calling its `~Cursor.execute()` method, and
+ returning it.
+
+ .. code::
+
+ # In Psycopg 2
+ cur = conn.cursor()
+ cur.execute(...)
+
+ # In Psycopg 3
+ cur = conn.execute(...)
+
+- The `Cursor.execute()` method returns `!self`. This means that you can chain
+ a fetch operation, such as `~Cursor.fetchone()`, to the `!execute()` call:
+
+ .. code::
+
+ # In Psycopg 2
+ cur.execute(...)
+ record = cur.fetchone()
+
+ cur.execute(...)
+ for record in cur:
+ ...
+
+ # In Psycopg 3
+ record = cur.execute(...).fetchone()
+
+ for record in cur.execute(...):
+ ...
+
+Using them together, in simple cases, you can go from creating a connection to
+using a result in a single expression:
+
+.. code::
+
+ print(psycopg.connect(DSN).execute("SELECT now()").fetchone()[0])
+ # 2042-07-12 18:15:10.706497+01:00
+
+
+.. index::
+ pair: Connection; `!with`
+
+.. _with-connection:
+
+Connection context
+------------------
+
+Psycopg 3 `Connection` can be used as a context manager:
+
+.. code:: python
+
+ with psycopg.connect() as conn:
+ ... # use the connection
+
+ # the connection is now closed
+
+When the block is exited, if there is a transaction open, it will be
+committed. If an exception is raised within the block the transaction is
+rolled back. In both cases the connection is closed. It is roughly the
+equivalent of:
+
+.. code:: python
+
+ conn = psycopg.connect()
+ try:
+ ... # use the connection
+ except BaseException:
+ conn.rollback()
+ else:
+ conn.commit()
+ finally:
+ conn.close()
+
+.. note::
+ This behaviour is not what `!psycopg2` does: in `!psycopg2` :ref:`there is
+ no final close() <pg2:with>` and the connection can be used in several
+ `!with` statements to manage different transactions. This behaviour has
+ been considered non-standard and surprising so it has been replaced by the
+ more explicit `~Connection.transaction()` block.
+
+Note that, while the above pattern is what most people would use, `connect()`
+doesn't enter a block itself, but returns an "un-entered" connection, so that
+it is still possible to use a connection regardless of the code scope and the
+developer is free to use (and responsible for calling) `~Connection.commit()`,
+`~Connection.rollback()`, `~Connection.close()` as and where needed.
+
+.. warning::
+ If a connection is just left to go out of scope, the way it will behave
+ with or without the use of a `!with` block is different:
+
+ - if the connection is used without a `!with` block, the server will find
+ a connection closed INTRANS and roll back the current transaction;
+
+ - if the connection is used with a `!with` block, there will be an
+ explicit COMMIT and the operations will be finalised.
+
+ You should use a `!with` block when your intention is just to execute a
+ set of operations and then committing the result, which is the most usual
+ thing to do with a connection. If your connection life cycle and
+ transaction pattern is different, and want more control on it, the use
+ without `!with` might be more convenient.
+
+ See :ref:`transactions` for more information.
+
+`AsyncConnection` can be also used as context manager, using ``async with``,
+but be careful about its quirkiness: see :ref:`async-with` for details.
+
+
+Adapting pyscopg to your program
+--------------------------------
+
+The above :ref:`pattern of use <usage>` only shows the default behaviour of
+the adapter. Psycopg can be customised in several ways, to allow the smoothest
+integration between your Python program and your PostgreSQL database:
+
+- If your program is concurrent and based on `asyncio` instead of on
+ threads/processes, you can use :ref:`async connections and cursors <async>`.
+
+- If you want to customise the objects that the cursor returns, instead of
+ receiving tuples, you can specify your :ref:`row factories <row-factories>`.
+
+- If you want to customise how Python values and PostgreSQL types are mapped
+ into each other, beside the :ref:`basic type mapping <types-adaptation>`,
+ you can :ref:`configure your types <adaptation>`.
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..a20894b
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,110 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# import sys
+# sys.path.insert(0, os.path.abspath('.'))
+
+import sys
+from pathlib import Path
+
+import psycopg
+
+docs_dir = Path(__file__).parent
+sys.path.append(str(docs_dir / "lib"))
+
+
+# -- Project information -----------------------------------------------------
+
+project = "psycopg"
+copyright = "2020, Daniele Varrazzo and The Psycopg Team"
+author = "Daniele Varrazzo"
+release = psycopg.__version__
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx.ext.intersphinx",
+ "sphinx_autodoc_typehints",
+ "sql_role",
+ "ticket_role",
+ "pg3_docs",
+ "libpq_docs",
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"]
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The announcement may be in the website but not shipped with the docs
+ann_file = docs_dir / "../../templates/docs3-announcement.html"
+if ann_file.exists():
+ with ann_file.open() as f:
+ announcement = f.read()
+else:
+ announcement = ""
+
+html_css_files = ["psycopg.css"]
+
+# The name of the Pygments (syntax highlighting) style to use.
+# Some that I've check don't suck:
+# default lovelace tango algol_nu
+# list: from pygments.styles import STYLE_MAP; print(sorted(STYLE_MAP.keys()))
+pygments_style = "tango"
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+html_theme = "furo"
+html_show_sphinx = True
+html_show_sourcelink = False
+html_theme_options = {
+ "announcement": announcement,
+ "sidebar_hide_name": False,
+ "light_logo": "psycopg.svg",
+ "dark_logo": "psycopg.svg",
+ "light_css_variables": {
+ "admonition-font-size": "1rem",
+ },
+}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+
+# The reST default role (used for this markup: `text`) to use for all documents.
+default_role = "obj"
+
+intersphinx_mapping = {
+ "py": ("https://docs.python.org/3", None),
+ "pg2": ("https://www.psycopg.org/docs/", None),
+}
+
+autodoc_member_order = "bysource"
+
+# PostgreSQL docs version to link libpq functions to
+libpq_docs_version = "14"
+
+# Where to point on :ticket: role
+ticket_url = "https://github.com/psycopg/psycopg/issues/%s"
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..916eeb0
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,52 @@
+===================================================
+Psycopg 3 -- PostgreSQL database adapter for Python
+===================================================
+
+Psycopg 3 is a newly designed PostgreSQL_ database adapter for the Python_
+programming language.
+
+Psycopg 3 presents a familiar interface for everyone who has used
+`Psycopg 2`_ or any other `DB-API 2.0`_ database adapter, but allows to use
+more modern PostgreSQL and Python features, such as:
+
+- :ref:`Asynchronous support <async>`
+- :ref:`COPY support from Python objects <copy>`
+- :ref:`A redesigned connection pool <connection-pools>`
+- :ref:`Support for static typing <static-typing>`
+- :ref:`Server-side parameters binding <server-side-binding>`
+- :ref:`Prepared statements <prepared-statements>`
+- :ref:`Statements pipeline <pipeline-mode>`
+- :ref:`Binary communication <binary-data>`
+- :ref:`Direct access to the libpq functionalities <psycopg.pq>`
+
+.. _Python: https://www.python.org/
+.. _PostgreSQL: https://www.postgresql.org/
+.. _Psycopg 2: https://www.psycopg.org/docs/
+.. _DB-API 2.0: https://www.python.org/dev/peps/pep-0249/
+
+
+Documentation
+=============
+
+.. toctree::
+ :maxdepth: 2
+
+ basic/index
+ advanced/index
+ api/index
+
+Release notes
+-------------
+
+.. toctree::
+ :maxdepth: 1
+
+ news
+ news_pool
+
+
+Indices and tables
+------------------
+
+* :ref:`genindex`
+* :ref:`modindex`
diff --git a/docs/lib/libpq_docs.py b/docs/lib/libpq_docs.py
new file mode 100644
index 0000000..b8e01f0
--- /dev/null
+++ b/docs/lib/libpq_docs.py
@@ -0,0 +1,182 @@
+"""
+Sphinx plugin to link to the libpq documentation.
+
+Add the ``:pq:`` role, to create a link to a libpq function, e.g. ::
+
+ :pq:`PQlibVersion()`
+
+will link to::
+
+ https://www.postgresql.org/docs/current/libpq-misc.html #LIBPQ-PQLIBVERSION
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import logging
+import urllib.request
+from pathlib import Path
+from functools import lru_cache
+from html.parser import HTMLParser
+
+from docutils import nodes, utils
+from docutils.parsers.rst import roles
+
+logger = logging.getLogger("sphinx.libpq_docs")
+
+
+class LibpqParser(HTMLParser):
+ def __init__(self, data, version="current"):
+ super().__init__()
+ self.data = data
+ self.version = version
+
+ self.section_id = None
+ self.varlist_id = None
+ self.in_term = False
+ self.in_func = False
+
+ def handle_starttag(self, tag, attrs):
+ if tag == "sect1":
+ self.handle_sect1(tag, attrs)
+ elif tag == "varlistentry":
+ self.handle_varlistentry(tag, attrs)
+ elif tag == "term":
+ self.in_term = True
+ elif tag == "function":
+ self.in_func = True
+
+ def handle_endtag(self, tag):
+ if tag == "term":
+ self.in_term = False
+ elif tag == "function":
+ self.in_func = False
+
+ def handle_data(self, data):
+ if not (self.in_term and self.in_func):
+ return
+
+ self.add_function(data)
+
+ def handle_sect1(self, tag, attrs):
+ attrs = dict(attrs)
+ if "id" in attrs:
+ self.section_id = attrs["id"]
+
+ def handle_varlistentry(self, tag, attrs):
+ attrs = dict(attrs)
+ if "id" in attrs:
+ self.varlist_id = attrs["id"]
+
+ def add_function(self, func_name):
+ self.data[func_name] = self.get_func_url()
+
+ def get_func_url(self):
+ assert self.section_id, "<sect1> tag not found"
+ assert self.varlist_id, "<varlistentry> tag not found"
+ return self._url_pattern.format(
+ version=self.version,
+ section=self.section_id,
+ func_id=self.varlist_id.upper(),
+ )
+
+ _url_pattern = "https://www.postgresql.org/docs/{version}/{section}.html#{func_id}"
+
+
+class LibpqReader:
+ # must be set before using the rest of the class.
+ app = None
+
+ _url_pattern = (
+ "https://raw.githubusercontent.com/postgres/postgres/REL_{ver}_STABLE"
+ "/doc/src/sgml/libpq.sgml"
+ )
+
+ data = None
+
+ def get_url(self, func):
+ if not self.data:
+ self.parse()
+
+ return self.data[func]
+
+ def parse(self):
+ if not self.local_file.exists():
+ self.download()
+
+ logger.info("parsing libpq docs from %s", self.local_file)
+ self.data = {}
+ parser = LibpqParser(self.data, version=self.version)
+ with self.local_file.open("r") as f:
+ parser.feed(f.read())
+
+ def download(self):
+ filename = os.environ.get("LIBPQ_DOCS_FILE")
+ if filename:
+ logger.info("reading postgres libpq docs from %s", filename)
+ with open(filename, "rb") as f:
+ data = f.read()
+ else:
+ logger.info("downloading postgres libpq docs from %s", self.sgml_url)
+ data = urllib.request.urlopen(self.sgml_url).read()
+
+ with self.local_file.open("wb") as f:
+ f.write(data)
+
+ @property
+ def local_file(self):
+ return Path(self.app.doctreedir) / f"libpq-{self.version}.sgml"
+
+ @property
+ def sgml_url(self):
+ return self._url_pattern.format(ver=self.version)
+
+ @property
+ def version(self):
+ return self.app.config.libpq_docs_version
+
+
+@lru_cache()
+def get_reader():
+ return LibpqReader()
+
+
+def pq_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
+ text = utils.unescape(text)
+
+ reader = get_reader()
+ if "(" in text:
+ func, noise = text.split("(", 1)
+ noise = "(" + noise
+
+ else:
+ func = text
+ noise = ""
+
+ try:
+ url = reader.get_url(func)
+ except KeyError:
+ msg = inliner.reporter.warning(
+ f"function {func} not found in libpq {reader.version} docs"
+ )
+ prb = inliner.problematic(rawtext, rawtext, msg)
+ return [prb], [msg]
+
+ # For a function f(), include the () in the signature for consistency
+ # with a normal `thing()`
+ if noise == "()":
+ func, noise = func + noise, ""
+
+ the_nodes = []
+ the_nodes.append(nodes.reference(func, func, refuri=url))
+ if noise:
+ the_nodes.append(nodes.Text(noise))
+
+ return [nodes.literal("", "", *the_nodes, **options)], []
+
+
+def setup(app):
+ app.add_config_value("libpq_docs_version", "14", "html")
+ roles.register_local_role("pq", pq_role)
+ get_reader().app = app
diff --git a/docs/lib/pg3_docs.py b/docs/lib/pg3_docs.py
new file mode 100644
index 0000000..05a6876
--- /dev/null
+++ b/docs/lib/pg3_docs.py
@@ -0,0 +1,197 @@
+"""
+Customisation for docs generation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+import logging
+import importlib
+from typing import Dict
+from collections import deque
+
+
+def process_docstring(app, what, name, obj, options, lines):
+ pass
+
+
+def before_process_signature(app, obj, bound_method):
+ ann = getattr(obj, "__annotations__", {})
+ if "return" in ann:
+ # Drop "return: None" from the function signatures
+ if ann["return"] is None:
+ del ann["return"]
+
+
+def process_signature(app, what, name, obj, options, signature, return_annotation):
+ pass
+
+
+def setup(app):
+ app.connect("autodoc-process-docstring", process_docstring)
+ app.connect("autodoc-process-signature", process_signature)
+ app.connect("autodoc-before-process-signature", before_process_signature)
+
+ import psycopg # type: ignore
+
+ recover_defined_module(
+ psycopg, skip_modules=["psycopg._dns", "psycopg.types.shapely"]
+ )
+ monkeypatch_autodoc()
+
+ # Disable warnings in sphinx_autodoc_typehints because it doesn't seem that
+ # there is a workaround for: "WARNING: Cannot resolve forward reference in
+ # type annotations"
+ logger = logging.getLogger("sphinx.sphinx_autodoc_typehints")
+ logger.setLevel(logging.ERROR)
+
+
+# Classes which may have __module__ overwritten
+recovered_classes: Dict[type, str] = {}
+
+
+def recover_defined_module(m, skip_modules=()):
+ """
+ Find the module where classes with __module__ attribute hacked were defined.
+
+ Autodoc will get confused and will fail to inspect attribute docstrings
+ (e.g. from enums and named tuples).
+
+ Save the classes recovered in `recovered_classes`, to be used by
+ `monkeypatch_autodoc()`.
+
+ """
+ mdir = os.path.split(m.__file__)[0]
+ for fn in walk_modules(mdir):
+ assert fn.startswith(mdir)
+ modname = os.path.splitext(fn[len(mdir) + 1 :])[0].replace("/", ".")
+ modname = f"{m.__name__}.{modname}"
+ if modname in skip_modules:
+ continue
+ with open(fn) as f:
+ classnames = re.findall(r"^class\s+([^(:]+)", f.read(), re.M)
+ for cls in classnames:
+ cls = deep_import(f"{modname}.{cls}")
+ if cls.__module__ != modname:
+ recovered_classes[cls] = modname
+
+
+def monkeypatch_autodoc():
+ """
+ Patch autodoc in order to use information found by `recover_defined_module`.
+ """
+ from sphinx.ext.autodoc import Documenter, AttributeDocumenter
+
+ orig_doc_get_real_modname = Documenter.get_real_modname
+ orig_attr_get_real_modname = AttributeDocumenter.get_real_modname
+ orig_attr_add_content = AttributeDocumenter.add_content
+
+ def fixed_doc_get_real_modname(self):
+ if self.object in recovered_classes:
+ return recovered_classes[self.object]
+ return orig_doc_get_real_modname(self)
+
+ def fixed_attr_get_real_modname(self):
+ if self.parent in recovered_classes:
+ return recovered_classes[self.parent]
+ return orig_attr_get_real_modname(self)
+
+ def fixed_attr_add_content(self, more_content):
+ """
+ Replace a docstring such as::
+
+ .. py:attribute:: ConnectionInfo.dbname
+ :module: psycopg
+
+ The database name of the connection.
+
+ :rtype: :py:class:`str`
+
+ into:
+
+ .. py:attribute:: ConnectionInfo.dbname
+ :type: str
+ :module: psycopg
+
+ The database name of the connection.
+
+ which creates a more compact representation of a property.
+
+ """
+ orig_attr_add_content(self, more_content)
+ if not isinstance(self.object, property):
+ return
+ iret, mret = match_in_lines(r"\s*:rtype: (.*)", self.directive.result)
+ iatt, matt = match_in_lines(r"\.\.", self.directive.result)
+ if not (mret and matt):
+ return
+ self.directive.result.pop(iret)
+ self.directive.result.insert(
+ iatt + 1,
+ f"{self.indent}:type: {unrest(mret.group(1))}",
+ source=self.get_sourcename(),
+ )
+
+ Documenter.get_real_modname = fixed_doc_get_real_modname
+ AttributeDocumenter.get_real_modname = fixed_attr_get_real_modname
+ AttributeDocumenter.add_content = fixed_attr_add_content
+
+
+def match_in_lines(pattern, lines):
+ """Match a regular expression against a list of strings.
+
+ Return the index of the first matched line and the match object.
+ None, None if nothing matched.
+ """
+ for i, line in enumerate(lines):
+ m = re.match(pattern, line)
+ if m:
+ return i, m
+ else:
+ return None, None
+
+
+def unrest(s):
+ r"""remove the reST markup from a string
+
+ e.g. :py:data:`~typing.Optional`\[:py:class:`int`] -> Optional[int]
+
+ required because :type: does the types lookup itself apparently.
+ """
+ s = re.sub(r":[^`]*:`~?([^`]*)`", r"\1", s) # drop role
+ s = re.sub(r"\\(.)", r"\1", s) # drop escape
+
+ # note that ~psycopg.pq.ConnStatus is converted to pq.ConnStatus
+ # which should be interpreted well if currentmodule is set ok.
+ s = re.sub(r"(?:typing|psycopg)\.", "", s) # drop unneeded modules
+ s = re.sub(r"~", "", s) # drop the tilde
+
+ return s
+
+
+def walk_modules(d):
+ for root, dirs, files in os.walk(d):
+ for f in files:
+ if f.endswith(".py"):
+ yield f"{root}/{f}"
+
+
+def deep_import(name):
+ parts = deque(name.split("."))
+ seen = []
+ if not parts:
+ raise ValueError("name must be a dot-separated name")
+
+ seen.append(parts.popleft())
+ thing = importlib.import_module(seen[-1])
+ while parts:
+ attr = parts.popleft()
+ seen.append(attr)
+
+ if hasattr(thing, attr):
+ thing = getattr(thing, attr)
+ else:
+ thing = importlib.import_module(".".join(seen))
+
+ return thing
diff --git a/docs/lib/sql_role.py b/docs/lib/sql_role.py
new file mode 100644
index 0000000..a40c9f4
--- /dev/null
+++ b/docs/lib/sql_role.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+"""
+ sql role
+ ~~~~~~~~
+
+ An interpreted text role to style SQL syntax in Psycopg documentation.
+
+ :copyright: Copyright 2010 by Daniele Varrazzo.
+ :copyright: Copyright 2020 The Psycopg Team.
+"""
+
+from docutils import nodes, utils
+from docutils.parsers.rst import roles
+
+
+def sql_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
+ text = utils.unescape(text)
+ options["classes"] = ["sql"]
+ return [nodes.literal(rawtext, text, **options)], []
+
+
+def setup(app):
+ roles.register_local_role("sql", sql_role)
diff --git a/docs/lib/ticket_role.py b/docs/lib/ticket_role.py
new file mode 100644
index 0000000..24ec873
--- /dev/null
+++ b/docs/lib/ticket_role.py
@@ -0,0 +1,50 @@
+# type: ignore
+"""
+ ticket role
+ ~~~~~~~~~~~
+
+ An interpreted text role to link docs to tickets issues.
+
+ :copyright: Copyright 2013 by Daniele Varrazzo.
+ :copyright: Copyright 2021 The Psycopg Team
+"""
+
+import re
+from docutils import nodes, utils
+from docutils.parsers.rst import roles
+
+
+def ticket_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
+ cfg = inliner.document.settings.env.app.config
+ if cfg.ticket_url is None:
+ msg = inliner.reporter.warning(
+ "ticket not configured: please configure ticket_url in conf.py"
+ )
+ prb = inliner.problematic(rawtext, rawtext, msg)
+ return [prb], [msg]
+
+ rv = [nodes.Text(name + " ")]
+ tokens = re.findall(r"(#?\d+)|([^\d#]+)", text)
+ for ticket, noise in tokens:
+ if ticket:
+ num = int(ticket.replace("#", ""))
+
+ url = cfg.ticket_url % num
+ roles.set_classes(options)
+ node = nodes.reference(
+ ticket, utils.unescape(ticket), refuri=url, **options
+ )
+
+ rv.append(node)
+
+ else:
+ assert noise
+ rv.append(nodes.Text(noise))
+
+ return rv, []
+
+
+def setup(app):
+ app.add_config_value("ticket_url", None, "env")
+ app.add_role("ticket", ticket_role)
+ app.add_role("tickets", ticket_role)
diff --git a/docs/news.rst b/docs/news.rst
new file mode 100644
index 0000000..46dfbe7
--- /dev/null
+++ b/docs/news.rst
@@ -0,0 +1,285 @@
+.. currentmodule:: psycopg
+
+.. index::
+ single: Release notes
+ single: News
+
+``psycopg`` release notes
+=========================
+
+Current release
+---------------
+
+Psycopg 3.1.7
+^^^^^^^^^^^^^
+
+- Fix server-side cursors using row factories (:ticket:`#464`).
+
+
+Psycopg 3.1.6
+^^^^^^^^^^^^^
+
+- Fix `cursor.copy()` with cursors using row factories (:ticket:`#460`).
+
+
+Psycopg 3.1.5
+^^^^^^^^^^^^^
+
+- Fix array loading slowness compared to psycopg2 (:ticket:`#359`).
+- Improve performance around network communication (:ticket:`#414`).
+- Return `!bytes` instead of `!memoryview` from `pq.Encoding` methods
+ (:ticket:`#422`).
+- Fix `Cursor.rownumber` to return `!None` when the result has no row to fetch
+ (:ticket:`#437`).
+- Avoid error in Pyright caused by aliasing `!TypeAlias` (:ticket:`#439`).
+- Fix `Copy.set_types()` used with `varchar` and `name` types (:ticket:`#452`).
+- Improve performance using :ref:`row-factories` (:ticket:`#457`).
+
+
+Psycopg 3.1.4
+^^^^^^^^^^^^^
+
+- Include :ref:`error classes <sqlstate-exceptions>` defined in PostgreSQL 15.
+- Add support for Python 3.11 (:ticket:`#305`).
+- Build binary packages with libpq from PostgreSQL 15.0.
+
+
+Psycopg 3.1.3
+^^^^^^^^^^^^^
+
+- Restore the state of the connection if `Cursor.stream()` is terminated
+ prematurely (:ticket:`#382`).
+- Fix regression introduced in 3.1 with different named tuples mangling rules
+ for non-ascii attribute names (:ticket:`#386`).
+- Fix handling of queries with escaped percent signs (``%%``) in `ClientCursor`
+ (:ticket:`#399`).
+- Fix possible duplicated BEGIN statements emitted in pipeline mode
+ (:ticket:`#401`).
+
+
+Psycopg 3.1.2
+^^^^^^^^^^^^^
+
+- Fix handling of certain invalid time zones causing problems on Windows
+ (:ticket:`#371`).
+- Fix segfault occurring when a loader fails initialization (:ticket:`#372`).
+- Fix invalid SAVEPOINT issued when entering `Connection.transaction()` within
+ a pipeline using an implicit transaction (:ticket:`#374`).
+- Fix queries with repeated named parameters in `ClientCursor` (:ticket:`#378`).
+- Distribute macOS arm64 (Apple M1) binary packages (:ticket:`#344`).
+
+
+Psycopg 3.1.1
+^^^^^^^^^^^^^
+
+- Work around broken Homebrew installation of the libpq in a non-standard path
+ (:ticket:`#364`)
+- Fix possible "unrecognized service" error in async connection when no port
+ is specified (:ticket:`#366`).
+
+
+Psycopg 3.1
+-----------
+
+- Add :ref:`Pipeline mode <pipeline-mode>` (:ticket:`#74`).
+- Add :ref:`client-side-binding-cursors` (:ticket:`#101`).
+- Add `CockroachDB <https://www.cockroachlabs.com/>`__ support in `psycopg.crdb`
+ (:ticket:`#313`).
+- Add :ref:`Two-Phase Commit <two-phase-commit>` support (:ticket:`#72`).
+- Add :ref:`adapt-enum` (:ticket:`#274`).
+- Add ``returning`` parameter to `~Cursor.executemany()` to retrieve query
+ results (:ticket:`#164`).
+- `~Cursor.executemany()` performance improved by using batch mode internally
+ (:ticket:`#145`).
+- Add parameters to `~Cursor.copy()`.
+- Add :ref:`COPY Writer objects <copy-writers>`.
+- Resolve domain names asynchronously in `AsyncConnection.connect()`
+ (:ticket:`#259`).
+- Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`).
+- Add ``prepare_threshold`` parameter to `Connection` init (:ticket:`#200`).
+- Add ``cursor_factory`` parameter to `Connection` init.
+- Add `Error.pgconn` and `Error.pgresult` attributes (:ticket:`#242`).
+- Restrict queries to be `~typing.LiteralString` as per :pep:`675`
+ (:ticket:`#323`).
+- Add explicit type cast to values converted by `sql.Literal` (:ticket:`#205`).
+- Drop support for Python 3.6.
+
+
+Psycopg 3.0.17
+^^^^^^^^^^^^^^
+
+- Fix segfaults on fork on some Linux systems using `ctypes` implementation
+ (:ticket:`#300`).
+- Load bytea as bytes, not memoryview, using `ctypes` implementation.
+
+
+Psycopg 3.0.16
+^^^^^^^^^^^^^^
+
+- Fix missing `~Cursor.rowcount` after SHOW (:ticket:`#343`).
+- Add scripts to build macOS arm64 packages (:ticket:`#162`).
+
+
+Psycopg 3.0.15
+^^^^^^^^^^^^^^
+
+- Fix wrong escaping of unprintable chars in COPY (nonetheless correctly
+ interpreted by PostgreSQL).
+- Restore the connection to usable state after an error in `~Cursor.stream()`.
+- Raise `DataError` instead of `OverflowError` loading binary intervals
+ out-of-range.
+- Distribute ``manylinux2014`` wheel packages (:ticket:`#124`).
+
+
+Psycopg 3.0.14
+^^^^^^^^^^^^^^
+
+- Raise `DataError` dumping arrays of mixed types (:ticket:`#301`).
+- Fix handling of incorrect server results, with blank sqlstate (:ticket:`#303`).
+- Fix bad Float4 conversion on ppc64le/musllinux (:ticket:`#304`).
+
+
+Psycopg 3.0.13
+^^^^^^^^^^^^^^
+
+- Fix `Cursor.stream()` slowness (:ticket:`#286`).
+- Fix oid for lists of integers, which might cause the server choosing
+ bad plans (:ticket:`#293`).
+- Make `Connection.cancel()` on a closed connection a no-op instead of an
+ error.
+
+
+Psycopg 3.0.12
+^^^^^^^^^^^^^^
+
+- Allow `bytearray`/`memoryview` data too as `Copy.write()` input
+ (:ticket:`#254`).
+- Fix dumping `~enum.IntEnum` in text mode, Python implementation.
+
+
+Psycopg 3.0.11
+^^^^^^^^^^^^^^
+
+- Fix `DataError` loading arrays with dimensions information (:ticket:`#253`).
+- Fix hanging during COPY in case of memory error (:ticket:`#255`).
+- Fix error propagation from COPY worker thread (mentioned in :ticket:`#255`).
+
+
+Psycopg 3.0.10
+^^^^^^^^^^^^^^
+
+- Leave the connection in working state after interrupting a query with Ctrl-C
+ (:ticket:`#231`).
+- Fix `Cursor.description` after a COPY ... TO STDOUT operation
+ (:ticket:`#235`).
+- Fix building on FreeBSD and likely other BSD flavours (:ticket:`#241`).
+
+
+Psycopg 3.0.9
+^^^^^^^^^^^^^
+
+- Set `Error.sqlstate` when an unknown code is received (:ticket:`#225`).
+- Add the `!tzdata` package as a dependency on Windows in order to handle time
+ zones (:ticket:`#223`).
+
+
+Psycopg 3.0.8
+^^^^^^^^^^^^^
+
+- Decode connection errors in the ``client_encoding`` specified in the
+ connection string, if available (:ticket:`#194`).
+- Fix possible warnings in objects deletion on interpreter shutdown
+ (:ticket:`#198`).
+- Don't leave connections in ACTIVE state in case of error during COPY ... TO
+ STDOUT (:ticket:`#203`).
+
+
+Psycopg 3.0.7
+^^^^^^^^^^^^^
+
+- Fix crash in `~Cursor.executemany()` with no input sequence
+ (:ticket:`#179`).
+- Fix wrong `~Cursor.rowcount` after an `~Cursor.executemany()` returning no
+ rows (:ticket:`#178`).
+
+
+Psycopg 3.0.6
+^^^^^^^^^^^^^
+
+- Allow to use `Cursor.description` if the connection is closed
+ (:ticket:`#172`).
+- Don't raise exceptions on `ServerCursor.close()` if the connection is closed
+ (:ticket:`#173`).
+- Fail on `Connection.cursor()` if the connection is closed (:ticket:`#174`).
+- Raise `ProgrammingError` if out-of-order exit from transaction contexts is
+ detected (:tickets:`#176, #177`).
+- Add `!CHECK_STANDBY` value to `~pq.ConnStatus` enum.
+
+
+Psycopg 3.0.5
+^^^^^^^^^^^^^
+
+- Fix possible "Too many open files" OS error, reported on macOS but possible
+ on other platforms too (:ticket:`#158`).
+- Don't clobber exceptions if a transaction block exit with error and rollback
+ fails (:ticket:`#165`).
+
+
+Psycopg 3.0.4
+^^^^^^^^^^^^^
+
+- Allow to use the module with strict strings comparison (:ticket:`#147`).
+- Fix segfault on Python 3.6 running in ``-W error`` mode, related to
+ `!backport.zoneinfo` `ticket #109
+ <https://github.com/pganssle/zoneinfo/issues/109>`__.
+- Build binary package with libpq versions not affected by `CVE-2021-23222
+ <https://www.postgresql.org/support/security/CVE-2021-23222/>`__
+ (:ticket:`#149`).
+
+
+Psycopg 3.0.3
+^^^^^^^^^^^^^
+
+- Release musllinux binary packages, compatible with Alpine Linux
+ (:ticket:`#141`).
+- Reduce size of binary package by stripping debug symbols (:ticket:`#142`).
+- Include typing information in the `!psycopg_binary` package.
+
+
+Psycopg 3.0.2
+^^^^^^^^^^^^^
+
+- Fix type hint for `sql.SQL.join()` (:ticket:`#127`).
+- Fix type hint for `Connection.notifies()` (:ticket:`#128`).
+- Fix call to `MultiRange.__setitem__()` with a non-iterable value and a
+ slice, now raising a `TypeError` (:ticket:`#129`).
+- Fix disable cursors methods after close() (:ticket:`#125`).
+
+
+Psycopg 3.0.1
+^^^^^^^^^^^^^
+
+- Fix use of the wrong dumper reusing cursors with the same query but different
+ parameter types (:ticket:`#112`).
+
+
+Psycopg 3.0
+-----------
+
+First stable release. Changed from 3.0b1:
+
+- Add :ref:`adapt-shapely` (:ticket:`#80`).
+- Add :ref:`adapt-multirange` (:ticket:`#75`).
+- Add `pq.__build_version__` constant.
+- Don't use the extended protocol with COPY, (:tickets:`#78, #82`).
+- Add ``context`` parameter to `~Connection.connect()` (:ticket:`#83`).
+- Fix selection of dumper by oid after `~Copy.set_types()`.
+- Drop `!Connection.client_encoding`. Use `ConnectionInfo.encoding` to read
+ it, and a :sql:`SET` statement to change it.
+- Add binary packages for Python 3.10 (:ticket:`#103`).
+
+
+Psycopg 3.0b1
+^^^^^^^^^^^^^
+
+- First public release on PyPI.
diff --git a/docs/news_pool.rst b/docs/news_pool.rst
new file mode 100644
index 0000000..7f212e0
--- /dev/null
+++ b/docs/news_pool.rst
@@ -0,0 +1,81 @@
+.. currentmodule:: psycopg_pool
+
+.. index::
+ single: Release notes
+ single: News
+
+``psycopg_pool`` release notes
+==============================
+
+Current release
+---------------
+
+psycopg_pool 3.1.5
+^^^^^^^^^^^^^^^^^^
+
+- Make sure that `!ConnectionPool.check()` refills an empty pool
+ (:ticket:`#438`).
+- Avoid error in Pyright caused by aliasing `!TypeAlias` (:ticket:`#439`).
+
+
+psycopg_pool 3.1.4
+^^^^^^^^^^^^^^^^^^
+
+- Fix async pool exhausting connections, happening if the pool is created
+ before the event loop is started (:ticket:`#219`).
+
+
+psycopg_pool 3.1.3
+^^^^^^^^^^^^^^^^^^
+
+- Add support for Python 3.11 (:ticket:`#305`).
+
+
+psycopg_pool 3.1.2
+^^^^^^^^^^^^^^^^^^
+
+- Fix possible failure to reconnect after losing connection from the server
+ (:ticket:`#370`).
+
+
+psycopg_pool 3.1.1
+^^^^^^^^^^^^^^^^^^
+
+- Fix race condition on pool creation which might result in the pool not
+ filling (:ticket:`#230`).
+
+
+psycopg_pool 3.1.0
+------------------
+
+- Add :ref:`null-pool` (:ticket:`#148`).
+- Add `ConnectionPool.open()` and ``open`` parameter to the pool init
+ (:ticket:`#151`).
+- Drop support for Python 3.6.
+
+
+psycopg_pool 3.0.3
+^^^^^^^^^^^^^^^^^^
+
+- Raise `!ValueError` if `ConnectionPool` `!min_size` and `!max_size` are both
+ set to 0 (instead of hanging).
+- Raise `PoolClosed` calling `~ConnectionPool.wait()` on a closed pool.
+
+
+psycopg_pool 3.0.2
+^^^^^^^^^^^^^^^^^^
+
+- Remove dependency on the internal `!psycopg._compat` module.
+
+
+psycopg_pool 3.0.1
+^^^^^^^^^^^^^^^^^^
+
+- Don't leave connections idle in transaction after calling
+ `~ConnectionPool.check()` (:ticket:`#144`).
+
+
+psycopg_pool 3.0
+----------------
+
+- First release on PyPI.
diff --git a/docs/pictures/adapt.drawio b/docs/pictures/adapt.drawio
new file mode 100644
index 0000000..75f61ed
--- /dev/null
+++ b/docs/pictures/adapt.drawio
@@ -0,0 +1,107 @@
+<mxfile host="Electron" modified="2021-07-12T13:26:05.192Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/14.6.13 Chrome/89.0.4389.128 Electron/12.0.7 Safari/537.36" etag="kKU1DyIkJcQFc1Rxt__U" compressed="false" version="14.6.13" type="device">
+ <diagram id="THISp3X85jFCtBEH0bao" name="Page-1">
+ <mxGraphModel dx="675" dy="400" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
+ <root>
+ <mxCell id="0" />
+ <mxCell id="1" parent="0" />
+ <mxCell id="uy255Msn6vtulWmyCIR1-12" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontFamily=Courier New;exitX=1;exitY=0.5;exitDx=0;exitDy=0;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-29" target="uy255Msn6vtulWmyCIR1-11">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="280" y="210" as="sourcePoint" />
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Courier New;exitX=1;exitY=0.5;exitDx=0;exitDy=0;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-30" target="uy255Msn6vtulWmyCIR1-14">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="280" y="320" as="sourcePoint" />
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-39" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-11" target="uy255Msn6vtulWmyCIR1-14">
+ <mxGeometry relative="1" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-11" value=".adapters" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="330" y="185" width="80" height="20" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-40" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-14" target="uy255Msn6vtulWmyCIR1-27">
+ <mxGeometry relative="1" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-14" value=".adapters" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="330" y="285" width="80" height="20" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-28" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontFamily=Courier New;exitX=1;exitY=0.5;exitDx=0;exitDy=0;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-31" target="uy255Msn6vtulWmyCIR1-27">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="280" y="440" as="sourcePoint" />
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-18" value=".cursor()" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="220" y="220" width="80" height="20" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-26" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-19" target="uy255Msn6vtulWmyCIR1-25">
+ <mxGeometry relative="1" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-34" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-19" target="uy255Msn6vtulWmyCIR1-29">
+ <mxGeometry relative="1" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-19" value="&lt;b&gt;psycopg&lt;/b&gt;&lt;br&gt;&lt;font face=&quot;Helvetica&quot;&gt;module&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="160" y="75" width="120" height="50" as="geometry" />
+ </mxCell>
+ <UserObject label=".connect()" link="../api/connections.html" id="uy255Msn6vtulWmyCIR1-20">
+ <mxCell style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="220" y="125" width="80" height="20" as="geometry" />
+ </mxCell>
+ </UserObject>
+ <mxCell id="uy255Msn6vtulWmyCIR1-21" value=".execute()" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="220" y="320" width="80" height="20" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-37" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-25" target="uy255Msn6vtulWmyCIR1-11">
+ <mxGeometry relative="1" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-25" value=".adapters" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="330" y="90" width="80" height="20" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-27" value=".adapters" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="330" y="385" width="80" height="20" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-35" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-29" target="uy255Msn6vtulWmyCIR1-30">
+ <mxGeometry relative="1" as="geometry" />
+ </mxCell>
+ <UserObject label="&lt;b&gt;Connection&lt;/b&gt;&lt;br&gt;&lt;font face=&quot;Helvetica&quot;&gt;object&lt;/font&gt;" link="../api/connections.html" id="uy255Msn6vtulWmyCIR1-29">
+ <mxCell style="rounded=1;whiteSpace=wrap;html=1;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="160" y="170" width="120" height="50" as="geometry" />
+ </mxCell>
+ </UserObject>
+ <mxCell id="uy255Msn6vtulWmyCIR1-36" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-30" target="uy255Msn6vtulWmyCIR1-31">
+ <mxGeometry relative="1" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-30" value="&lt;b&gt;Cursor&lt;/b&gt;&lt;br&gt;&lt;font face=&quot;Helvetica&quot;&gt;object&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="160" y="270" width="120" height="50" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-31" value="&lt;b&gt;Transformer&lt;/b&gt;&lt;br&gt;&lt;font face=&quot;Helvetica&quot;&gt;object&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fontFamily=Courier New;" vertex="1" parent="1">
+ <mxGeometry x="160" y="370" width="120" height="50" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-46" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Helvetica;endArrow=none;endFill=0;dashed=1;dashPattern=1 1;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-41">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="310" y="100" as="targetPoint" />
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-41" value="Has a" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Helvetica;" vertex="1" parent="1">
+ <mxGeometry x="300" y="55" width="40" height="20" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-45" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Helvetica;endArrow=none;endFill=0;dashed=1;dashPattern=1 1;startSize=4;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-42">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="220" y="150" as="targetPoint" />
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-42" value="Create" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Helvetica;" vertex="1" parent="1">
+ <mxGeometry x="150" y="130" width="40" height="20" as="geometry" />
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-47" style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Helvetica;endArrow=none;endFill=0;dashed=1;dashPattern=1 1;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-43">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="370" y="150" as="targetPoint" />
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="uy255Msn6vtulWmyCIR1-43" value="Copy" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Helvetica;" vertex="1" parent="1">
+ <mxGeometry x="394" y="130" width="40" height="20" as="geometry" />
+ </mxCell>
+ </root>
+ </mxGraphModel>
+ </diagram>
+</mxfile>
diff --git a/docs/pictures/adapt.svg b/docs/pictures/adapt.svg
new file mode 100644
index 0000000..2c39755
--- /dev/null
+++ b/docs/pictures/adapt.svg
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" width="285px" height="366px" viewBox="-0.5 -0.5 285 366" style="background-color: rgb(255, 255, 255);"><defs/><g><path d="M 130 140 L 173.63 140" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 178.88 140 L 171.88 143.5 L 173.63 140 L 171.88 136.5 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><path d="M 130 240 L 173.63 240" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 178.88 240 L 171.88 243.5 L 173.63 240 L 171.88 236.5 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><path d="M 220 150 L 220 223.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 220 228.88 L 216.5 221.88 L 220 223.63 L 223.5 221.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="180" y="130" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 140px; margin-left: 182px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.adapters</div></div></div></foreignObject><text x="182" y="144" fill="#000000" font-family="Courier New" font-size="12px">.adapters</text></switch></g><path d="M 220 250 L 220 323.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 220 328.88 L 216.5 321.88 L 220 323.63 L 223.5 321.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="180" y="230" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 240px; margin-left: 182px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.adapters</div></div></div></foreignObject><text x="182" y="244" fill="#000000" font-family="Courier New" font-size="12px">.adapters</text></switch></g><path d="M 130 340 L 173.63 340" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 178.88 340 L 171.88 343.5 L 173.63 340 L 171.88 336.5 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="70" y="165" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 175px; margin-left: 72px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.cursor()</div></div></div></foreignObject><text x="72" y="179" fill="#000000" font-family="Courier New" font-size="12px">.cursor()</text></switch></g><path d="M 130 45 L 173.63 45" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 178.88 45 L 171.88 48.5 L 173.63 45 L 171.88 41.5 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><path d="M 70 70 L 70 108.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 70 113.88 L 66.5 106.88 L 70 108.63 L 73.5 106.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="10" y="20" width="120" height="50" rx="7.5" ry="7.5" fill="#ffffff" stroke="#000000" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 45px; margin-left: 11px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; "><b>psycopg</b><br /><font face="Helvetica">module</font></div></div></div></foreignObject><text x="70" y="49" fill="#000000" font-family="Courier New" font-size="12px" text-anchor="middle">psycopg...</text></switch></g><a xlink:href="../api/connections.html"><rect x="70" y="70" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 80px; margin-left: 72px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.connect()</div></div></div></foreignObject><text x="72" y="84" fill="#000000" font-family="Courier New" font-size="12px">.connect()</text></switch></g></a><rect x="70" y="265" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 275px; margin-left: 72px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.execute()</div></div></div></foreignObject><text x="72" y="279" fill="#000000" font-family="Courier New" font-size="12px">.execute()</text></switch></g><path d="M 220 55 L 220 123.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 220 128.88 L 216.5 121.88 L 220 123.63 L 223.5 121.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="180" y="35" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 45px; margin-left: 182px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.adapters</div></div></div></foreignObject><text x="182" y="49" fill="#000000" font-family="Courier New" font-size="12px">.adapters</text></switch></g><rect x="180" y="330" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 340px; margin-left: 182px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.adapters</div></div></div></foreignObject><text x="182" y="344" fill="#000000" font-family="Courier New" font-size="12px">.adapters</text></switch></g><path d="M 70 165 L 70 208.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 70 213.88 L 66.5 206.88 L 70 208.63 L 73.5 206.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><a xlink:href="../api/connections.html"><rect x="10" y="115" width="120" height="50" rx="7.5" ry="7.5" fill="#ffffff" stroke="#000000" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 140px; margin-left: 11px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; "><b>Connection</b><br /><font face="Helvetica">object</font></div></div></div></foreignObject><text x="70" y="144" fill="#000000" font-family="Courier New" font-size="12px" text-anchor="middle">Connection...</text></switch></g></a><path d="M 70 265 L 70 308.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 70 313.88 L 66.5 306.88 L 70 308.63 L 73.5 306.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="10" y="215" width="120" height="50" rx="7.5" ry="7.5" fill="#ffffff" stroke="#000000" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 240px; margin-left: 11px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; "><b>Cursor</b><br /><font face="Helvetica">object</font></div></div></div></foreignObject><text x="70" y="244" fill="#000000" font-family="Courier New" font-size="12px" text-anchor="middle">Cursor...</text></switch></g><rect x="10" y="315" width="120" height="50" rx="7.5" ry="7.5" fill="#ffffff" stroke="#000000" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 340px; margin-left: 11px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; "><b>Transformer</b><br /><font face="Helvetica">object</font></div></div></div></foreignObject><text x="70" y="344" fill="#000000" font-family="Courier New" font-size="12px" text-anchor="middle">Transformer...</text></switch></g><path d="M 167.14 20 L 160 45" fill="none" stroke="#000000" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke"/><rect x="150" y="0" width="40" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 10px; margin-left: 151px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">Has a</div></div></div></foreignObject><text x="170" y="14" fill="#000000" font-family="Helvetica" font-size="12px" text-anchor="middle">Has a</text></switch></g><path d="M 40 89 L 70 95" fill="none" stroke="#000000" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke"/><rect x="0" y="75" width="40" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 85px; margin-left: 1px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">Create</div></div></div></foreignObject><text x="20" y="89" fill="#000000" font-family="Helvetica" font-size="12px" text-anchor="middle">Create</text></switch></g><path d="M 244 89.55 L 220 95" fill="none" stroke="#000000" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke"/><rect x="244" y="75" width="40" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 85px; margin-left: 245px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">Copy</div></div></div></foreignObject><text x="264" y="89" fill="#000000" font-family="Helvetica" font-size="12px" text-anchor="middle">Copy</text></switch></g></g><switch><g requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"/><a transform="translate(0,-5)" xlink:href="https://www.diagrams.net/doc/faq/svg-export-text-problems" target="_blank"><text text-anchor="middle" font-size="10px" x="50%" y="100%">Viewer does not support full SVG 1.1</text></a></switch></svg> \ No newline at end of file
diff --git a/docs/release.rst b/docs/release.rst
new file mode 100644
index 0000000..8fcadaf
--- /dev/null
+++ b/docs/release.rst
@@ -0,0 +1,39 @@
+:orphan:
+
+How to make a psycopg release
+=============================
+
+- Change version number in:
+
+ - ``psycopg_c/psycopg_c/version.py``
+ - ``psycopg/psycopg/version.py``
+ - ``psycopg_pool/psycopg_pool/version.py``
+
+- Change docs/news.rst to drop the "unreleased" mark from the version
+
+- Push to GitHub to run `the tests workflow`__.
+
+ .. __: https://github.com/psycopg/psycopg/actions/workflows/tests.yml
+
+- Build the packages by triggering manually the `Build packages workflow`__.
+
+ .. __: https://github.com/psycopg/psycopg/actions/workflows/packages.yml
+
+- If all went fine, create a tag named after the version::
+
+ git tag -a -s 3.0.dev1
+ git push --tags
+
+- Download the ``artifacts.zip`` package from the last Packages workflow run.
+
+- Unpack the packages locally::
+
+ mkdir tmp
+ cd tmp
+ unzip ~/Downloads/artifact.zip
+
+- If the package is a testing one, upload it on TestPyPI with::
+
+ $ twine upload -s -r testpypi *
+
+- If the package is stable, omit ``-r testpypi``.
diff --git a/psycopg/.flake8 b/psycopg/.flake8
new file mode 100644
index 0000000..67fb024
--- /dev/null
+++ b/psycopg/.flake8
@@ -0,0 +1,6 @@
+[flake8]
+max-line-length = 88
+ignore = W503, E203
+per-file-ignores =
+ # Autogenerated section
+ psycopg/errors.py: E125, E128, E302
diff --git a/psycopg/LICENSE.txt b/psycopg/LICENSE.txt
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/psycopg/LICENSE.txt
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/psycopg/README.rst b/psycopg/README.rst
new file mode 100644
index 0000000..45eeac3
--- /dev/null
+++ b/psycopg/README.rst
@@ -0,0 +1,31 @@
+Psycopg 3: PostgreSQL database adapter for Python
+=================================================
+
+Psycopg 3 is a modern implementation of a PostgreSQL adapter for Python.
+
+This distribution contains the pure Python package ``psycopg``.
+
+
+Installation
+------------
+
+In short, run the following::
+
+ pip install --upgrade pip # to upgrade pip
+ pip install "psycopg[binary,pool]" # to install package and dependencies
+
+If something goes wrong, and for more information about installation, please
+check out the `Installation documentation`__.
+
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html#
+
+
+Hacking
+-------
+
+For development information check out `the project readme`__.
+
+.. __: https://github.com/psycopg/psycopg#readme
+
+
+Copyright (C) 2020 The Psycopg Team
diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py
new file mode 100644
index 0000000..baadf30
--- /dev/null
+++ b/psycopg/psycopg/__init__.py
@@ -0,0 +1,110 @@
+"""
+psycopg -- PostgreSQL database adapter for Python
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from . import pq # noqa: F401 import early to stabilize side effects
+from . import types
+from . import postgres
+from ._tpc import Xid
+from .copy import Copy, AsyncCopy
+from ._enums import IsolationLevel
+from .cursor import Cursor
+from .errors import Warning, Error, InterfaceError, DatabaseError
+from .errors import DataError, OperationalError, IntegrityError
+from .errors import InternalError, ProgrammingError, NotSupportedError
+from ._column import Column
+from .conninfo import ConnectionInfo
+from ._pipeline import Pipeline, AsyncPipeline
+from .connection import BaseConnection, Connection, Notify
+from .transaction import Rollback, Transaction, AsyncTransaction
+from .cursor_async import AsyncCursor
+from .server_cursor import AsyncServerCursor, ServerCursor
+from .client_cursor import AsyncClientCursor, ClientCursor
+from .connection_async import AsyncConnection
+
+from . import dbapi20
+from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
+from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
+from .dbapi20 import Timestamp, TimestampFromTicks
+
+from .version import __version__ as __version__ # noqa: F401
+
+# Set the logger to a quiet default, can be enabled if needed
+logger = logging.getLogger("psycopg")
+if logger.level == logging.NOTSET:
+ logger.setLevel(logging.WARNING)
+
+# DBAPI compliance
+connect = Connection.connect
+apilevel = "2.0"
+threadsafety = 2
+paramstyle = "pyformat"
+
+# register default adapters for PostgreSQL
+adapters = postgres.adapters # exposed by the package
+postgres.register_default_adapters(adapters)
+
+# After the default ones, because these can deal with the bytea oid better
+dbapi20.register_dbapi20_adapters(adapters)
+
+# Must come after all the types have been registered
+types.array.register_all_arrays(adapters)
+
+# Note: defining the exported methods helps both Sphynx in documenting that
+# this is the canonical place to obtain them and should be used by MyPy too,
+# so that function signatures are consistent with the documentation.
+__all__ = [
+ "AsyncClientCursor",
+ "AsyncConnection",
+ "AsyncCopy",
+ "AsyncCursor",
+ "AsyncPipeline",
+ "AsyncServerCursor",
+ "AsyncTransaction",
+ "BaseConnection",
+ "ClientCursor",
+ "Column",
+ "Connection",
+ "ConnectionInfo",
+ "Copy",
+ "Cursor",
+ "IsolationLevel",
+ "Notify",
+ "Pipeline",
+ "Rollback",
+ "ServerCursor",
+ "Transaction",
+ "Xid",
+ # DBAPI exports
+ "connect",
+ "apilevel",
+ "threadsafety",
+ "paramstyle",
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DatabaseError",
+ "DataError",
+ "OperationalError",
+ "IntegrityError",
+ "InternalError",
+ "ProgrammingError",
+ "NotSupportedError",
+ # DBAPI type constructors and singletons
+ "Binary",
+ "Date",
+ "DateFromTicks",
+ "Time",
+ "TimeFromTicks",
+ "Timestamp",
+ "TimestampFromTicks",
+ "BINARY",
+ "DATETIME",
+ "NUMBER",
+ "ROWID",
+ "STRING",
+]
diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py
new file mode 100644
index 0000000..a3a6ef8
--- /dev/null
+++ b/psycopg/psycopg/_adapters_map.py
@@ -0,0 +1,289 @@
+"""
+Mapping from types/oids to Dumpers/Loaders
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union
+from typing import cast, TYPE_CHECKING
+
+from . import pq
+from . import errors as e
+from .abc import Dumper, Loader
+from ._enums import PyFormat as PyFormat
+from ._cmodule import _psycopg
+from ._typeinfo import TypesRegistry
+
+if TYPE_CHECKING:
+ from .connection import BaseConnection
+
+RV = TypeVar("RV")
+
+
+class AdaptersMap:
+ r"""
+ Establish how types should be converted between Python and PostgreSQL in
+ an `~psycopg.abc.AdaptContext`.
+
+ `!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to
+ define how Python types are converted to PostgreSQL, and maps OIDs to
+ `~psycopg.adapt.Loader` classes to establish how query results are
+ converted to Python.
+
+ Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how
+ types are converted in that context, exposed as the
+ `~psycopg.abc.AdaptContext.adapters` attribute: changing such map allows
+ to customise adaptation in a context without changing separated contexts.
+
+ When a context is created from another context (for instance when a
+ `~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's
+ `!adapters` are used as template for the child's `!adapters`, so that every
+ cursor created from the same connection use the connection's types
+ configuration, but separate connections have independent mappings.
+
+ Once created, `!AdaptersMap` are independent. This means that objects
+ already created are not affected if a wider scope (e.g. the global one) is
+ changed.
+
+ The connections adapters are initialised using a global `!AdptersMap`
+ template, exposed as `psycopg.adapters`: changing such mapping allows to
+ customise the type mapping for every connections created afterwards.
+
+ The object can start empty or copy from another object of the same class.
+ Copies are copy-on-write: if the maps are updated make a copy. This way
+ extending e.g. global map by a connection or a connection map from a cursor
+ is cheap: a copy is only made on customisation.
+ """
+
+ __module__ = "psycopg.adapt"
+
+ types: TypesRegistry
+
+ _dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]]
+ _dumpers_by_oid: List[Dict[int, Type[Dumper]]]
+ _loaders: List[Dict[int, Type[Loader]]]
+
+ # Record if a dumper or loader has an optimised version.
+ _optimised: Dict[type, type] = {}
+
+ def __init__(
+ self,
+ template: Optional["AdaptersMap"] = None,
+ types: Optional[TypesRegistry] = None,
+ ):
+ if template:
+ self._dumpers = template._dumpers.copy()
+ self._own_dumpers = _dumpers_shared.copy()
+ template._own_dumpers = _dumpers_shared.copy()
+
+ self._dumpers_by_oid = template._dumpers_by_oid[:]
+ self._own_dumpers_by_oid = [False, False]
+ template._own_dumpers_by_oid = [False, False]
+
+ self._loaders = template._loaders[:]
+ self._own_loaders = [False, False]
+ template._own_loaders = [False, False]
+
+ self.types = TypesRegistry(template.types)
+
+ else:
+ self._dumpers = {fmt: {} for fmt in PyFormat}
+ self._own_dumpers = _dumpers_owned.copy()
+
+ self._dumpers_by_oid = [{}, {}]
+ self._own_dumpers_by_oid = [True, True]
+
+ self._loaders = [{}, {}]
+ self._own_loaders = [True, True]
+
+ self.types = types or TypesRegistry()
+
+ # implement the AdaptContext protocol too
+ @property
+ def adapters(self) -> "AdaptersMap":
+ return self
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ return None
+
+ def register_dumper(
+ self, cls: Union[type, str, None], dumper: Type[Dumper]
+ ) -> None:
+ """
+ Configure the context to use `!dumper` to convert objects of type `!cls`.
+
+ If two dumpers with different `~Dumper.format` are registered for the
+ same type, the last one registered will be chosen when the query
+ doesn't specify a format (i.e. when the value is used with a ``%s``
+ "`~PyFormat.AUTO`" placeholder).
+
+ :param cls: The type to manage.
+ :param dumper: The dumper to register for `!cls`.
+
+ If `!cls` is specified as string it will be lazy-loaded, so that it
+ will be possible to register it without importing it before. In this
+ case it should be the fully qualified name of the object (e.g.
+ ``"uuid.UUID"``).
+
+ If `!cls` is None, only use the dumper when looking up using
+ `get_dumper_by_oid()`, which happens when we know the Postgres type to
+ adapt to, but not the Python type that will be adapted (e.g. in COPY
+ after using `~psycopg.Copy.set_types()`).
+
+ """
+ if not (cls is None or isinstance(cls, (str, type))):
+ raise TypeError(
+ f"dumpers should be registered on classes, got {cls} instead"
+ )
+
+ if _psycopg:
+ dumper = self._get_optimised(dumper)
+
+ # Register the dumper both as its format and as auto
+ # so that the last dumper registered is used in auto (%s) format
+ if cls:
+ for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO):
+ if not self._own_dumpers[fmt]:
+ self._dumpers[fmt] = self._dumpers[fmt].copy()
+ self._own_dumpers[fmt] = True
+
+ self._dumpers[fmt][cls] = dumper
+
+ # Register the dumper by oid, if the oid of the dumper is fixed
+ if dumper.oid:
+ if not self._own_dumpers_by_oid[dumper.format]:
+ self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[
+ dumper.format
+ ].copy()
+ self._own_dumpers_by_oid[dumper.format] = True
+
+ self._dumpers_by_oid[dumper.format][dumper.oid] = dumper
+
+ def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None:
+ """
+ Configure the context to use `!loader` to convert data of oid `!oid`.
+
+ :param oid: The PostgreSQL OID or type name to manage.
+ :param loader: The loar to register for `!oid`.
+
+ If `oid` is specified as string, it refers to a type name, which is
+ looked up in the `types` registry. `
+
+ """
+ if isinstance(oid, str):
+ oid = self.types[oid].oid
+ if not isinstance(oid, int):
+ raise TypeError(f"loaders should be registered on oid, got {oid} instead")
+
+ if _psycopg:
+ loader = self._get_optimised(loader)
+
+ fmt = loader.format
+ if not self._own_loaders[fmt]:
+ self._loaders[fmt] = self._loaders[fmt].copy()
+ self._own_loaders[fmt] = True
+
+ self._loaders[fmt][oid] = loader
+
+ def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]:
+ """
+ Return the dumper class for the given type and format.
+
+ Raise `~psycopg.ProgrammingError` if a class is not available.
+
+ :param cls: The class to adapt.
+ :param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`,
+ use the last one of the dumpers registered on `!cls`.
+ """
+ try:
+ dmap = self._dumpers[format]
+ except KeyError:
+ raise ValueError(f"bad dumper format: {format}")
+
+ # Look for the right class, including looking at superclasses
+ for scls in cls.__mro__:
+ if scls in dmap:
+ return dmap[scls]
+
+ # If the adapter is not found, look for its name as a string
+ fqn = scls.__module__ + "." + scls.__qualname__
+ if fqn in dmap:
+ # Replace the class name with the class itself
+ d = dmap[scls] = dmap.pop(fqn)
+ return d
+
+ raise e.ProgrammingError(
+ f"cannot adapt type {cls.__name__!r} using placeholder '%{format}'"
+ f" (format: {PyFormat(format).name})"
+ )
+
+ def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]:
+ """
+ Return the dumper class for the given oid and format.
+
+ Raise `~psycopg.ProgrammingError` if a class is not available.
+
+ :param oid: The oid of the type to dump to.
+ :param format: The format to dump to.
+ """
+ try:
+ dmap = self._dumpers_by_oid[format]
+ except KeyError:
+ raise ValueError(f"bad dumper format: {format}")
+
+ try:
+ return dmap[oid]
+ except KeyError:
+ info = self.types.get(oid)
+ if info:
+ msg = (
+ f"cannot find a dumper for type {info.name} (oid {oid})"
+ f" format {pq.Format(format).name}"
+ )
+ else:
+ msg = (
+ f"cannot find a dumper for unknown type with oid {oid}"
+ f" format {pq.Format(format).name}"
+ )
+ raise e.ProgrammingError(msg)
+
+ def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]:
+ """
+ Return the loader class for the given oid and format.
+
+ Return `!None` if not found.
+
+ :param oid: The oid of the type to load.
+ :param format: The format to load from.
+ """
+ return self._loaders[format].get(oid)
+
+ @classmethod
+ def _get_optimised(self, cls: Type[RV]) -> Type[RV]:
+ """Return the optimised version of a Dumper or Loader class.
+
+ Return the input class itself if there is no optimised version.
+ """
+ try:
+ return self._optimised[cls]
+ except KeyError:
+ pass
+
+ # Check if the class comes from psycopg.types and there is a class
+ # with the same name in psycopg_c._psycopg.
+ from psycopg import types
+
+ if cls.__module__.startswith(types.__name__):
+ new = cast(Type[RV], getattr(_psycopg, cls.__name__, None))
+ if new:
+ self._optimised[cls] = new
+ return new
+
+ self._optimised[cls] = cls
+ return cls
+
+
+# Micro-optimization: copying these objects is faster than creating new dicts
+_dumpers_owned = dict.fromkeys(PyFormat, True)
+_dumpers_shared = dict.fromkeys(PyFormat, False)
diff --git a/psycopg/psycopg/_cmodule.py b/psycopg/psycopg/_cmodule.py
new file mode 100644
index 0000000..288ef1b
--- /dev/null
+++ b/psycopg/psycopg/_cmodule.py
@@ -0,0 +1,24 @@
+"""
+Simplify access to the _psycopg module
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from typing import Optional
+
+from . import pq
+
+__version__: Optional[str] = None
+
+# Note: "c" must the first attempt so that mypy associates the variable the
+# right module interface. It will not result Optional, but hey.
+if pq.__impl__ == "c":
+ from psycopg_c import _psycopg as _psycopg
+ from psycopg_c import __version__ as __version__ # noqa: F401
+elif pq.__impl__ == "binary":
+ from psycopg_binary import _psycopg as _psycopg # type: ignore
+ from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401
+elif pq.__impl__ == "python":
+ _psycopg = None # type: ignore
+else:
+ raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")
diff --git a/psycopg/psycopg/_column.py b/psycopg/psycopg/_column.py
new file mode 100644
index 0000000..9e4e735
--- /dev/null
+++ b/psycopg/psycopg/_column.py
@@ -0,0 +1,143 @@
+"""
+The Column object in Cursor.description
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
+from operator import attrgetter
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor
+
+
+class ColumnData(NamedTuple):
+ ftype: int
+ fmod: int
+ fsize: int
+
+
+class Column(Sequence[Any]):
+
+ __module__ = "psycopg"
+
+ def __init__(self, cursor: "BaseCursor[Any, Any]", index: int):
+ res = cursor.pgresult
+ assert res
+
+ fname = res.fname(index)
+ if fname:
+ self._name = fname.decode(cursor._encoding)
+ else:
+ # COPY_OUT results have columns but no name
+ self._name = f"column_{index + 1}"
+
+ self._data = ColumnData(
+ ftype=res.ftype(index),
+ fmod=res.fmod(index),
+ fsize=res.fsize(index),
+ )
+ self._type = cursor.adapters.types.get(self._data.ftype)
+
+ _attrs = tuple(
+ attrgetter(attr)
+ for attr in """
+ name type_code display_size internal_size precision scale null_ok
+ """.split()
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"<Column {self.name!r},"
+ f" type: {self._type_display()} (oid: {self.type_code})>"
+ )
+
+ def __len__(self) -> int:
+ return 7
+
+ def _type_display(self) -> str:
+ parts = []
+ parts.append(self._type.name if self._type else str(self.type_code))
+
+ mod1 = self.precision
+ if mod1 is None:
+ mod1 = self.display_size
+ if mod1:
+ parts.append(f"({mod1}")
+ if self.scale:
+ parts.append(f", {self.scale}")
+ parts.append(")")
+
+ if self._type and self.type_code == self._type.array_oid:
+ parts.append("[]")
+
+ return "".join(parts)
+
+ def __getitem__(self, index: Any) -> Any:
+ if isinstance(index, slice):
+ return tuple(getter(self) for getter in self._attrs[index])
+ else:
+ return self._attrs[index](self)
+
+ @property
+ def name(self) -> str:
+ """The name of the column."""
+ return self._name
+
+ @property
+ def type_code(self) -> int:
+ """The numeric OID of the column."""
+ return self._data.ftype
+
+ @property
+ def display_size(self) -> Optional[int]:
+ """The field size, for :sql:`varchar(n)`, None otherwise."""
+ if not self._type:
+ return None
+
+ if self._type.name in ("varchar", "char"):
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod - 4
+
+ return None
+
+ @property
+ def internal_size(self) -> Optional[int]:
+ """The internal field size for fixed-size types, None otherwise."""
+ fsize = self._data.fsize
+ return fsize if fsize >= 0 else None
+
+ @property
+ def precision(self) -> Optional[int]:
+ """The number of digits for fixed precision types."""
+ if not self._type:
+ return None
+
+ dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval")
+ if self._type.name == "numeric":
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod >> 16
+
+ elif self._type.name in dttypes:
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod & 0xFFFF
+
+ return None
+
+ @property
+ def scale(self) -> Optional[int]:
+ """The number of digits after the decimal point if available."""
+ if self._type and self._type.name == "numeric":
+ fmod = self._data.fmod - 4
+ if fmod >= 0:
+ return fmod & 0xFFFF
+
+ return None
+
+ @property
+ def null_ok(self) -> Optional[bool]:
+ """Always `!None`"""
+ return None
diff --git a/psycopg/psycopg/_compat.py b/psycopg/psycopg/_compat.py
new file mode 100644
index 0000000..7dbae79
--- /dev/null
+++ b/psycopg/psycopg/_compat.py
@@ -0,0 +1,72 @@
+"""
+compatibility functions for different Python versions
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import sys
+import asyncio
+from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar
+
+# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it.
+# For this raisin it must be imported directly from typing_extension where used.
+# See https://github.com/microsoft/pyright/issues/4197
+from typing_extensions import TypeAlias
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
+T = TypeVar("T")
+FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
+
+if sys.version_info >= (3, 8):
+ create_task = asyncio.create_task
+ from math import prod
+
+else:
+
+ def create_task(
+ coro: FutureT[T], name: Optional[str] = None
+ ) -> "asyncio.Future[T]":
+ return asyncio.create_task(coro)
+
+ from functools import reduce
+
+ def prod(seq: Sequence[int]) -> int:
+ return reduce(int.__mul__, seq, 1)
+
+
+if sys.version_info >= (3, 9):
+ from zoneinfo import ZoneInfo
+ from functools import cache
+ from collections import Counter, deque as Deque
+else:
+ from typing import Counter, Deque
+ from functools import lru_cache
+ from backports.zoneinfo import ZoneInfo
+
+ cache = lru_cache(maxsize=None)
+
+if sys.version_info >= (3, 10):
+ from typing import TypeGuard
+else:
+ from typing_extensions import TypeGuard
+
+if sys.version_info >= (3, 11):
+ from typing import LiteralString
+else:
+ from typing_extensions import LiteralString
+
+__all__ = [
+ "Counter",
+ "Deque",
+ "LiteralString",
+ "Protocol",
+ "TypeGuard",
+ "ZoneInfo",
+ "cache",
+ "create_task",
+ "prod",
+]
diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py
new file mode 100644
index 0000000..1e146ba
--- /dev/null
+++ b/psycopg/psycopg/_dns.py
@@ -0,0 +1,223 @@
+# type: ignore # dnspython is currently optional and mypy fails if missing
+"""
+DNS query support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import os
+import re
+import warnings
+from random import randint
+from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
+from typing import TYPE_CHECKING
+from collections import defaultdict
+
+try:
+ from dns.resolver import Resolver, Cache
+ from dns.asyncresolver import Resolver as AsyncResolver
+ from dns.exception import DNSException
+except ImportError:
+ raise ImportError(
+ "the module psycopg._dns requires the package 'dnspython' installed"
+ )
+
+from . import errors as e
+from .conninfo import resolve_hostaddr_async as resolve_hostaddr_async_
+
+if TYPE_CHECKING:
+ from dns.rdtypes.IN.SRV import SRV
+
+resolver = Resolver()
+resolver.cache = Cache()
+
+async_resolver = AsyncResolver()
+async_resolver.cache = Cache()
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ .. deprecated:: 3.1
+ The use of this function is not necessary anymore, because
+ `psycopg.AsyncConnection.connect()` performs non-blocking name
+ resolution automatically.
+ """
+ warnings.warn(
+ "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
+ DeprecationWarning,
+ )
+ return await resolve_hostaddr_async_(params)
+
+
+def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply SRV DNS lookup as defined in :RFC:`2782`."""
+ return Rfc2782Resolver().resolve(params)
+
+
+async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Async equivalent of `resolve_srv()`."""
+ return await Rfc2782Resolver().resolve_async(params)
+
+
+class HostPort(NamedTuple):
+ host: str
+ port: str
+ totry: bool = False
+ target: Optional[str] = None
+
+
+class Rfc2782Resolver:
+ """Implement SRV RR Resolution as per RFC 2782
+
+ The class is organised to minimise code duplication between the sync and
+ the async paths.
+ """
+
+ re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")
+
+ def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the parameters host and port after SRV lookup."""
+ attempts = self._get_attempts(params)
+ if not attempts:
+ return params
+
+ hps = []
+ for hp in attempts:
+ if hp.totry:
+ hps.extend(self._resolve_srv(hp))
+ else:
+ hps.append(hp)
+
+ return self._return_params(params, hps)
+
+ async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the parameters host and port after SRV lookup."""
+ attempts = self._get_attempts(params)
+ if not attempts:
+ return params
+
+ hps = []
+ for hp in attempts:
+ if hp.totry:
+ hps.extend(await self._resolve_srv_async(hp))
+ else:
+ hps.append(hp)
+
+ return self._return_params(params, hps)
+
+ def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]:
+ """
+ Return the list of host, and for each host if SRV lookup must be tried.
+
+ Return an empty list if no lookup is requested.
+ """
+ # If hostaddr is defined don't do any resolution.
+ if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
+ return []
+
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ hosts_in = host_arg.split(",")
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+ ports_in = port_arg.split(",")
+
+ if len(ports_in) == 1:
+ # If only one port is specified, it applies to all the hosts.
+ ports_in *= len(hosts_in)
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what the raise if the libpq fails connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
+ )
+
+ out = []
+ srv_found = False
+ for host, port in zip(hosts_in, ports_in):
+ m = self.re_srv_rr.match(host)
+ if m or port.lower() == "srv":
+ srv_found = True
+ target = m.group("target") if m else None
+ hp = HostPort(host=host, port=port, totry=True, target=target)
+ else:
+ hp = HostPort(host=host, port=port)
+ out.append(hp)
+
+ return out if srv_found else []
+
+ def _resolve_srv(self, hp: HostPort) -> List[HostPort]:
+ try:
+ ans = resolver.resolve(hp.host, "SRV")
+ except DNSException:
+ ans = ()
+ return self._get_solved_entries(hp, ans)
+
+ async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]:
+ try:
+ ans = await async_resolver.resolve(hp.host, "SRV")
+ except DNSException:
+ ans = ()
+ return self._get_solved_entries(hp, ans)
+
+ def _get_solved_entries(
+ self, hp: HostPort, entries: "Sequence[SRV]"
+ ) -> List[HostPort]:
+ if not entries:
+ # No SRV entry found. Delegate the libpq a QNAME=target lookup
+ if hp.target and hp.port.lower() != "srv":
+ return [HostPort(host=hp.target, port=hp.port)]
+ else:
+ return []
+
+ # If there is precisely one SRV RR, and its Target is "." (the root
+ # domain), abort.
+ if len(entries) == 1 and str(entries[0].target) == ".":
+ return []
+
+ return [
+ HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
+ for entry in self.sort_rfc2782(entries)
+ ]
+
+ def _return_params(
+ self, params: Dict[str, Any], hps: List[HostPort]
+ ) -> Dict[str, Any]:
+ if not hps:
+ # Nothing found, we ended up with an empty list
+ raise e.OperationalError("no host found after SRV RR lookup")
+
+ out = params.copy()
+ out["host"] = ",".join(hp.host for hp in hps)
+ out["port"] = ",".join(str(hp.port) for hp in hps)
+ return out
+
+ def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]":
+ """
+ Implement the priority/weight ordering defined in RFC 2782.
+ """
+ # Divide the entries by priority:
+ priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list)
+ out: "List[SRV]" = []
+ for entry in ans:
+ priorities[entry.priority].append(entry)
+
+ for pri, entries in sorted(priorities.items()):
+ if len(entries) == 1:
+ out.append(entries[0])
+ continue
+
+ entries.sort(key=lambda ent: ent.weight)
+ total_weight = sum(ent.weight for ent in entries)
+ while entries:
+ r = randint(0, total_weight)
+ csum = 0
+ for i, ent in enumerate(entries):
+ csum += ent.weight
+ if csum >= r:
+ break
+ out.append(ent)
+ total_weight -= ent.weight
+ del entries[i]
+
+ return out
diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py
new file mode 100644
index 0000000..c584b26
--- /dev/null
+++ b/psycopg/psycopg/_encodings.py
@@ -0,0 +1,170 @@
+"""
+Mappings between PostgreSQL and Python encodings.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import string
+import codecs
+from typing import Any, Dict, Optional, TYPE_CHECKING
+
+from .pq._enums import ConnStatus
+from .errors import NotSupportedError
+from ._compat import cache
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn
+ from .connection import BaseConnection
+
+OK = ConnStatus.OK
+
+
+_py_codecs = {
+ "BIG5": "big5",
+ "EUC_CN": "gb2312",
+ "EUC_JIS_2004": "euc_jis_2004",
+ "EUC_JP": "euc_jp",
+ "EUC_KR": "euc_kr",
+ # "EUC_TW": not available in Python
+ "GB18030": "gb18030",
+ "GBK": "gbk",
+ "ISO_8859_5": "iso8859-5",
+ "ISO_8859_6": "iso8859-6",
+ "ISO_8859_7": "iso8859-7",
+ "ISO_8859_8": "iso8859-8",
+ "JOHAB": "johab",
+ "KOI8R": "koi8-r",
+ "KOI8U": "koi8-u",
+ "LATIN1": "iso8859-1",
+ "LATIN10": "iso8859-16",
+ "LATIN2": "iso8859-2",
+ "LATIN3": "iso8859-3",
+ "LATIN4": "iso8859-4",
+ "LATIN5": "iso8859-9",
+ "LATIN6": "iso8859-10",
+ "LATIN7": "iso8859-13",
+ "LATIN8": "iso8859-14",
+ "LATIN9": "iso8859-15",
+ # "MULE_INTERNAL": not available in Python
+ "SHIFT_JIS_2004": "shift_jis_2004",
+ "SJIS": "shift_jis",
+ # this actually means no encoding, see PostgreSQL docs
+ # it is special-cased by the text loader.
+ "SQL_ASCII": "ascii",
+ "UHC": "cp949",
+ "UTF8": "utf-8",
+ "WIN1250": "cp1250",
+ "WIN1251": "cp1251",
+ "WIN1252": "cp1252",
+ "WIN1253": "cp1253",
+ "WIN1254": "cp1254",
+ "WIN1255": "cp1255",
+ "WIN1256": "cp1256",
+ "WIN1257": "cp1257",
+ "WIN1258": "cp1258",
+ "WIN866": "cp866",
+ "WIN874": "cp874",
+}
+
+py_codecs: Dict[bytes, str] = {}
+py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())
+
+# Add an alias without underscore, for lenient lookups
+py_codecs.update(
+ (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k
+)
+
+pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
+
+
+def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str:
+ """
+ Return the Python encoding name of a psycopg connection.
+
+ Default to utf8 if the connection has no encoding info.
+ """
+ if not conn or conn.closed:
+ return "utf-8"
+
+ pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8"
+ return pg2pyenc(pgenc)
+
+
+def pgconn_encoding(pgconn: "PGconn") -> str:
+ """
+ Return the Python encoding name of a libpq connection.
+
+ Default to utf8 if the connection has no encoding info.
+ """
+ if pgconn.status != OK:
+ return "utf-8"
+
+ pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8"
+ return pg2pyenc(pgenc)
+
+
+def conninfo_encoding(conninfo: str) -> str:
+ """
+ Return the Python encoding name passed in a conninfo string. Default to utf8.
+
+ Because the input is likely to come from the user and not normalised by the
+ server, be somewhat lenient (non-case-sensitive lookup, ignore noise chars).
+ """
+ from .conninfo import conninfo_to_dict
+
+ params = conninfo_to_dict(conninfo)
+ pgenc = params.get("client_encoding")
+ if pgenc:
+ try:
+ return pg2pyenc(pgenc.encode())
+ except NotSupportedError:
+ pass
+
+ return "utf-8"
+
+
+@cache
+def py2pgenc(name: str) -> bytes:
+ """Convert a Python encoding name to PostgreSQL encoding name.
+
+ Raise LookupError if the Python encoding is unknown.
+ """
+ return pg_codecs[codecs.lookup(name).name]
+
+
+@cache
+def pg2pyenc(name: bytes) -> str:
+ """Convert a Python encoding name to PostgreSQL encoding name.
+
+ Raise NotSupportedError if the PostgreSQL encoding is not supported by
+ Python.
+ """
+ try:
+ return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()]
+ except KeyError:
+ sname = name.decode("utf8", "replace")
+ raise NotSupportedError(f"codec not available in Python: {sname!r}")
+
+
+def _as_python_identifier(s: str, prefix: str = "f") -> str:
+ """
+ Reduce a string to a valid Python identifier.
+
+ Replace all non-valid chars with '_' and prefix the value with `!prefix` if
+ the first letter is an '_'.
+ """
+ if not s.isidentifier():
+ if s[0] in "1234567890":
+ s = prefix + s
+ if not s.isidentifier():
+ s = _re_clean.sub("_", s)
+ # namedtuple fields cannot start with underscore. So...
+ if s[0] == "_":
+ s = prefix + s
+ return s
+
+
+_re_clean = re.compile(
+ f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]"
+)
diff --git a/psycopg/psycopg/_enums.py b/psycopg/psycopg/_enums.py
new file mode 100644
index 0000000..a7cb78d
--- /dev/null
+++ b/psycopg/psycopg/_enums.py
@@ -0,0 +1,79 @@
+"""
+Enum values for psycopg
+
+These values are defined by us and are not necessarily dependent on
+libpq-defined enums.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import Enum, IntEnum
+from selectors import EVENT_READ, EVENT_WRITE
+
+from . import pq
+
+
+class Wait(IntEnum):
+ R = EVENT_READ
+ W = EVENT_WRITE
+ RW = EVENT_READ | EVENT_WRITE
+
+
+class Ready(IntEnum):
+ R = EVENT_READ
+ W = EVENT_WRITE
+ RW = EVENT_READ | EVENT_WRITE
+
+
+class PyFormat(str, Enum):
+ """
+ Enum representing the format wanted for a query argument.
+
+ The value `AUTO` allows psycopg to choose the best format for a certain
+ parameter.
+ """
+
+ __module__ = "psycopg.adapt"
+
+ AUTO = "s"
+ """Automatically chosen (``%s`` placeholder)."""
+ TEXT = "t"
+ """Text parameter (``%t`` placeholder)."""
+ BINARY = "b"
+ """Binary parameter (``%b`` placeholder)."""
+
+ @classmethod
+ def from_pq(cls, fmt: pq.Format) -> "PyFormat":
+ return _pg2py[fmt]
+
+ @classmethod
+ def as_pq(cls, fmt: "PyFormat") -> pq.Format:
+ return _py2pg[fmt]
+
+
+class IsolationLevel(IntEnum):
+ """
+ Enum representing the isolation level for a transaction.
+ """
+
+ __module__ = "psycopg"
+
+ READ_UNCOMMITTED = 1
+ """:sql:`READ UNCOMMITTED` isolation level."""
+ READ_COMMITTED = 2
+ """:sql:`READ COMMITTED` isolation level."""
+ REPEATABLE_READ = 3
+ """:sql:`REPEATABLE READ` isolation level."""
+ SERIALIZABLE = 4
+ """:sql:`SERIALIZABLE` isolation level."""
+
+
+_py2pg = {
+ PyFormat.TEXT: pq.Format.TEXT,
+ PyFormat.BINARY: pq.Format.BINARY,
+}
+
+_pg2py = {
+ pq.Format.TEXT: PyFormat.TEXT,
+ pq.Format.BINARY: PyFormat.BINARY,
+}
diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py
new file mode 100644
index 0000000..c818d86
--- /dev/null
+++ b/psycopg/psycopg/_pipeline.py
@@ -0,0 +1,288 @@
+"""
+commands pipeline management
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import logging
+from types import TracebackType
+from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from .abc import PipelineCommand, PQGen
+from ._compat import Deque
+from ._encodings import pgconn_encoding
+from ._preparing import Key, Prepare
+from .generators import pipeline_communicate, fetch_many, send
+
+if TYPE_CHECKING:
+ from .pq.abc import PGresult
+ from .cursor import BaseCursor
+ from .connection import BaseConnection, Connection
+ from .connection_async import AsyncConnection
+
+
+PendingResult: TypeAlias = Union[
+ None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]]
+]
+
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+BAD = pq.ConnStatus.BAD
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+logger = logging.getLogger("psycopg")
+
+
+class BasePipeline:
+
+ command_queue: Deque[PipelineCommand]
+ result_queue: Deque[PendingResult]
+ _is_supported: Optional[bool] = None
+
+ def __init__(self, conn: "BaseConnection[Any]") -> None:
+ self._conn = conn
+ self.pgconn = conn.pgconn
+ self.command_queue = Deque[PipelineCommand]()
+ self.result_queue = Deque[PendingResult]()
+ self.level = 0
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._conn.pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @property
+ def status(self) -> pq.PipelineStatus:
+ return pq.PipelineStatus(self.pgconn.pipeline_status)
+
+ @classmethod
+ def is_supported(cls) -> bool:
+ """Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
+ if BasePipeline._is_supported is None:
+ BasePipeline._is_supported = not cls._not_supported_reason()
+ return BasePipeline._is_supported
+
+ @classmethod
+ def _not_supported_reason(cls) -> str:
+ """Return the reason why the pipeline mode is not supported.
+
+ Return an empty string if pipeline mode is supported.
+ """
+ # Support only depends on the libpq functions available in the pq
+ # wrapper, not on the database version.
+ if pq.version() < 140000:
+ return (
+ f"libpq too old {pq.version()};"
+ " v14 or greater required for pipeline mode"
+ )
+
+ if pq.__build_version__ < 140000:
+ return (
+ f"libpq too old: module built for {pq.__build_version__};"
+ " v14 or greater required for pipeline mode"
+ )
+
+ return ""
+
+ def _enter_gen(self) -> PQGen[None]:
+ if not self.is_supported():
+ raise e.NotSupportedError(
+ f"pipeline mode not supported: {self._not_supported_reason()}"
+ )
+ if self.level == 0:
+ self.pgconn.enter_pipeline_mode()
+ elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
+ # Nested pipeline case.
+ # Transaction might be ACTIVE when the pipeline uses an "implicit
+ # transaction", typically in autocommit mode. But when entering a
+ # Psycopg transaction(), we expect the IDLE state. By sync()-ing,
+ # we make sure all previous commands are completed and the
+ # transaction gets back to IDLE.
+ yield from self._sync_gen()
+ self.level += 1
+
+ def _exit(self, exc: Optional[BaseException]) -> None:
+ self.level -= 1
+ if self.level == 0 and self.pgconn.status != BAD:
+ try:
+ self.pgconn.exit_pipeline_mode()
+ except e.OperationalError as exc2:
+ # Notice that this error might be pretty irrecoverable. It
+ # happens on COPY, for instance: even if sync succeeds, exiting
+ # fails with "cannot exit pipeline mode with uncollected results"
+ if exc:
+ logger.warning("error ignored exiting %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+
+ def _sync_gen(self) -> PQGen[None]:
+ self._enqueue_sync()
+ yield from self._communicate_gen()
+ yield from self._fetch_gen(flush=False)
+
+ def _exit_gen(self) -> PQGen[None]:
+ """
+ Exit current pipeline by sending a Sync and fetch back all remaining results.
+ """
+ try:
+ self._enqueue_sync()
+ yield from self._communicate_gen()
+ finally:
+ # No need to force flush since we emitted a sync just before.
+ yield from self._fetch_gen(flush=False)
+
+ def _communicate_gen(self) -> PQGen[None]:
+ """Communicate with pipeline to send commands and possibly fetch
+ results, which are then processed.
+ """
+ fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
+ to_process = [(self.result_queue.popleft(), results) for results in fetched]
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
+ """Fetch available results from the connection and process them with
+ pipeline queued items.
+
+ If 'flush' is True, a PQsendFlushRequest() is issued in order to make
+ sure results can be fetched. Otherwise, the caller may emit a
+ PQpipelineSync() call to ensure the output buffer gets flushed before
+ fetching.
+ """
+ if not self.result_queue:
+ return
+
+ if flush:
+ self.pgconn.send_flush_request()
+ yield from send(self.pgconn)
+
+ to_process = []
+ while self.result_queue:
+ results = yield from fetch_many(self.pgconn)
+ if not results:
+ # No more results to fetch, but there may still be pending
+ # commands.
+ break
+ queued = self.result_queue.popleft()
+ to_process.append((queued, results))
+
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _process_results(
+ self, queued: PendingResult, results: List["PGresult"]
+ ) -> None:
+ """Process a results set fetched from the current pipeline.
+
+ This matches 'results' with its respective element in the pipeline
+ queue. For commands (None value in the pipeline queue), results are
+ checked directly. For prepare statement creation requests, update the
+ cache. Otherwise, results are attached to their respective cursor.
+ """
+ if queued is None:
+ (result,) = results
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+ elif result.status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ else:
+ cursor, prepinfo = queued
+ cursor._set_results_from_pipeline(results)
+ if prepinfo:
+ key, prep, name = prepinfo
+ # Update the prepare state of the query.
+ cursor._conn._prepared.validate(key, prep, name, results)
+
+ def _enqueue_sync(self) -> None:
+ """Enqueue a PQpipelineSync() command."""
+ self.command_queue.append(self.pgconn.pipeline_sync)
+ self.result_queue.append(None)
+
+
+class Pipeline(BasePipeline):
+ """Handler for connection in pipeline mode."""
+
+ __module__ = "psycopg"
+ _conn: "Connection[Any]"
+ _Self = TypeVar("_Self", bound="Pipeline")
+
+ def __init__(self, conn: "Connection[Any]") -> None:
+ super().__init__(conn)
+
+ def sync(self) -> None:
+ """Sync the pipeline, send any pending command and receive and process
+ all available results.
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._sync_gen())
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def __enter__(self: _Self) -> _Self:
+ with self._conn.lock:
+ self._conn.wait(self._enter_gen())
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._exit_gen())
+ except Exception as exc2:
+ # Don't clobber an exception raised in the block with this one
+ if exc_val:
+ logger.warning("error ignored terminating %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+ finally:
+ self._exit(exc_val)
+
+
+class AsyncPipeline(BasePipeline):
+ """Handler for async connection in pipeline mode."""
+
+ __module__ = "psycopg"
+ _conn: "AsyncConnection[Any]"
+ _Self = TypeVar("_Self", bound="AsyncPipeline")
+
+ def __init__(self, conn: "AsyncConnection[Any]") -> None:
+ super().__init__(conn)
+
+ async def sync(self) -> None:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._sync_gen())
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def __aenter__(self: _Self) -> _Self:
+ async with self._conn.lock:
+ await self._conn.wait(self._enter_gen())
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._exit_gen())
+ except Exception as exc2:
+ # Don't clobber an exception raised in the block with this one
+ if exc_val:
+ logger.warning("error ignored terminating %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+ finally:
+ self._exit(exc_val)
diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py
new file mode 100644
index 0000000..f60c0cb
--- /dev/null
+++ b/psycopg/psycopg/_preparing.py
@@ -0,0 +1,198 @@
+"""
+Support for prepared statements
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import IntEnum, auto
+from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
+from collections import OrderedDict
+from typing_extensions import TypeAlias
+
+from . import pq
+from ._compat import Deque
+from ._queries import PostgresQuery
+
+if TYPE_CHECKING:
+ from .pq.abc import PGresult
+
+Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+
+
+class Prepare(IntEnum):
+ NO = auto()
+ YES = auto()
+ SHOULD = auto()
+
+
+class PrepareManager:
+ # Number of times a query is executed before it is prepared.
+ prepare_threshold: Optional[int] = 5
+
+ # Maximum number of prepared statements on the connection.
+ prepared_max: int = 100
+
+ def __init__(self) -> None:
+ # Map (query, types) to the number of times the query was seen.
+ self._counts: OrderedDict[Key, int] = OrderedDict()
+
+ # Map (query, types) to the name of the statement if prepared.
+ self._names: OrderedDict[Key, bytes] = OrderedDict()
+
+ # Counter to generate prepared statements names
+ self._prepared_idx = 0
+
+ self._maint_commands = Deque[bytes]()
+
+ @staticmethod
+ def key(query: PostgresQuery) -> Key:
+ return (query.query, query.types)
+
+ def get(
+ self, query: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ """
+ Check if a query is prepared, tell back whether to prepare it.
+ """
+ if prepare is False or self.prepare_threshold is None:
+ # The user doesn't want this query to be prepared
+ return Prepare.NO, b""
+
+ key = self.key(query)
+ name = self._names.get(key)
+ if name:
+ # The query was already prepared in this session
+ return Prepare.YES, name
+
+ count = self._counts.get(key, 0)
+ if count >= self.prepare_threshold or prepare:
+ # The query has been executed enough times and needs to be prepared
+ name = f"_pg3_{self._prepared_idx}".encode()
+ self._prepared_idx += 1
+ return Prepare.SHOULD, name
+ else:
+ # The query is not to be prepared yet
+ return Prepare.NO, b""
+
+ def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool:
+ """Check if we need to discard our entire state: it should happen on
+ rollback or on dropping objects, because the same object may get
+ recreated and postgres would fail internal lookups.
+ """
+ if self._names or prep == Prepare.SHOULD:
+ for result in results:
+ if result.status != COMMAND_OK:
+ continue
+ cmdstat = result.command_status
+ if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
+ return self.clear()
+ return False
+
+ @staticmethod
+ def _check_results(results: Sequence["PGresult"]) -> bool:
+ """Return False if 'results' are invalid for prepared statement cache."""
+ if len(results) != 1:
+ # We cannot prepare a multiple statement
+ return False
+
+ status = results[0].status
+ if COMMAND_OK != status != TUPLES_OK:
+ # We don't prepare failed queries or other weird results
+ return False
+
+ return True
+
+ def _rotate(self) -> None:
+ """Evict an old value from the cache.
+
+ If it was prepared, deallocate it. Do it only once: if the cache was
+ resized, deallocate gradually.
+ """
+ if len(self._counts) > self.prepared_max:
+ self._counts.popitem(last=False)
+
+ if len(self._names) > self.prepared_max:
+ name = self._names.popitem(last=False)[1]
+ self._maint_commands.append(b"DEALLOCATE " + name)
+
+ def maybe_add_to_cache(
+ self, query: PostgresQuery, prep: Prepare, name: bytes
+ ) -> Optional[Key]:
+ """Handle 'query' for possible addition to the cache.
+
+ If a new entry has been added, return its key. Return None otherwise
+ (meaning the query is already in cache or cache is not enabled).
+
+ Note: This method is only called in pipeline mode.
+ """
+ # don't do anything if prepared statements are disabled
+ if self.prepare_threshold is None:
+ return None
+
+ key = self.key(query)
+ if key in self._counts:
+ if prep is Prepare.SHOULD:
+ del self._counts[key]
+ self._names[key] = name
+ else:
+ self._counts[key] += 1
+ self._counts.move_to_end(key)
+ return None
+
+ elif key in self._names:
+ self._names.move_to_end(key)
+ return None
+
+ else:
+ if prep is Prepare.SHOULD:
+ self._names[key] = name
+ else:
+ self._counts[key] = 1
+ return key
+
+ def validate(
+ self,
+ key: Key,
+ prep: Prepare,
+ name: bytes,
+ results: Sequence["PGresult"],
+ ) -> None:
+ """Validate cached entry with 'key' by checking query 'results'.
+
+ Possibly return a command to perform maintenance on database side.
+
+ Note: this method is only called in pipeline mode.
+ """
+ if self._should_discard(prep, results):
+ return
+
+ if not self._check_results(results):
+ self._names.pop(key, None)
+ self._counts.pop(key, None)
+ else:
+ self._rotate()
+
+ def clear(self) -> bool:
+ """Clear the cache of the maintenance commands.
+
+ Clear the internal state and prepare a command to clear the state of
+ the server.
+ """
+ self._counts.clear()
+ if self._names:
+ self._names.clear()
+ self._maint_commands.clear()
+ self._maint_commands.append(b"DEALLOCATE ALL")
+ return True
+ else:
+ return False
+
+ def get_maintenance_commands(self) -> Iterator[bytes]:
+ """
+ Iterate over the commands needed to align the server state to our state
+ """
+ while self._maint_commands:
+ yield self._maint_commands.popleft()
diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py
new file mode 100644
index 0000000..2a7554c
--- /dev/null
+++ b/psycopg/psycopg/_queries.py
@@ -0,0 +1,375 @@
+"""
+Utility module to manipulate queries
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
+from typing import Sequence, Tuple, Union, TYPE_CHECKING
+from functools import lru_cache
+
+from . import pq
+from . import errors as e
+from .sql import Composable
+from .abc import Buffer, Query, Params
+from ._enums import PyFormat
+from ._encodings import conn_encoding
+
+if TYPE_CHECKING:
+ from .abc import Transformer
+
+
+class QueryPart(NamedTuple):
+ pre: bytes
+ item: Union[int, str]
+ format: PyFormat
+
+
+class PostgresQuery:
+ """
+ Helper to convert a Python query and parameters into Postgres format.
+ """
+
+ __slots__ = """
+ query params types formats
+ _tx _want_formats _parts _encoding _order
+ """.split()
+
+ def __init__(self, transformer: "Transformer"):
+ self._tx = transformer
+
+ self.params: Optional[Sequence[Optional[Buffer]]] = None
+ # these are tuples so they can be used as keys e.g. in prepared stmts
+ self.types: Tuple[int, ...] = ()
+
+ # The format requested by the user and the ones to really pass Postgres
+ self._want_formats: Optional[List[PyFormat]] = None
+ self.formats: Optional[Sequence[pq.Format]] = None
+
+ self._encoding = conn_encoding(transformer.connection)
+ self._parts: List[QueryPart]
+ self.query = b""
+ self._order: Optional[List[str]] = None
+
+ def convert(self, query: Query, vars: Optional[Params]) -> None:
+ """
+ Set up the query and parameters to convert.
+
+ The results of this function can be obtained accessing the object
+ attributes (`query`, `params`, `types`, `formats`).
+ """
+ if isinstance(query, str):
+ bquery = query.encode(self._encoding)
+ elif isinstance(query, Composable):
+ bquery = query.as_bytes(self._tx)
+ else:
+ bquery = query
+
+ if vars is not None:
+ (
+ self.query,
+ self._want_formats,
+ self._order,
+ self._parts,
+ ) = _query2pg(bquery, self._encoding)
+ else:
+ self.query = bquery
+ self._want_formats = self._order = None
+
+ self.dump(vars)
+
+ def dump(self, vars: Optional[Params]) -> None:
+ """
+ Process a new set of variables on the query processed by `convert()`.
+
+ This method updates `params` and `types`.
+ """
+ if vars is not None:
+ params = _validate_and_reorder_params(self._parts, vars, self._order)
+ assert self._want_formats is not None
+ self.params = self._tx.dump_sequence(params, self._want_formats)
+ self.types = self._tx.types or ()
+ self.formats = self._tx.formats
+ else:
+ self.params = None
+ self.types = ()
+ self.formats = None
+
+
+class PostgresClientQuery(PostgresQuery):
+ """
+ PostgresQuery subclass merging query and arguments client-side.
+ """
+
+ __slots__ = ("template",)
+
+ def convert(self, query: Query, vars: Optional[Params]) -> None:
+ """
+ Set up the query and parameters to convert.
+
+ The results of this function can be obtained accessing the object
+ attributes (`query`, `params`, `types`, `formats`).
+ """
+ if isinstance(query, str):
+ bquery = query.encode(self._encoding)
+ elif isinstance(query, Composable):
+ bquery = query.as_bytes(self._tx)
+ else:
+ bquery = query
+
+ if vars is not None:
+ (self.template, self._order, self._parts) = _query2pg_client(
+ bquery, self._encoding
+ )
+ else:
+ self.query = bquery
+ self._order = None
+
+ self.dump(vars)
+
+ def dump(self, vars: Optional[Params]) -> None:
+ """
+ Process a new set of variables on the query processed by `convert()`.
+
+ This method updates `params` and `types`.
+ """
+ if vars is not None:
+ params = _validate_and_reorder_params(self._parts, vars, self._order)
+ self.params = tuple(
+ self._tx.as_literal(p) if p is not None else b"NULL" for p in params
+ )
+ self.query = self.template % self.params
+ else:
+ self.params = None
+
+
+@lru_cache()
+def _query2pg(
+ query: bytes, encoding: str
+) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
+ """
+ Convert Python query and params into something Postgres understands.
+
+ - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
+ format (``$1``, ``$2``)
+ - placeholders can be %s, %t, or %b (auto, text or binary)
+ - return ``query`` (bytes), ``formats`` (list of formats) ``order``
+ (sequence of names used in the query, in the position they appear)
+ ``parts`` (splits of queries and placeholders).
+ """
+ parts = _split_query(query, encoding)
+ order: Optional[List[str]] = None
+ chunks: List[bytes] = []
+ formats = []
+
+ if isinstance(parts[0].item, int):
+ for part in parts[:-1]:
+ assert isinstance(part.item, int)
+ chunks.append(part.pre)
+ chunks.append(b"$%d" % (part.item + 1))
+ formats.append(part.format)
+
+ elif isinstance(parts[0].item, str):
+ seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+ order = []
+ for part in parts[:-1]:
+ assert isinstance(part.item, str)
+ chunks.append(part.pre)
+ if part.item not in seen:
+ ph = b"$%d" % (len(seen) + 1)
+ seen[part.item] = (ph, part.format)
+ order.append(part.item)
+ chunks.append(ph)
+ formats.append(part.format)
+ else:
+ if seen[part.item][1] != part.format:
+ raise e.ProgrammingError(
+ f"placeholder '{part.item}' cannot have different formats"
+ )
+ chunks.append(seen[part.item][0])
+
+ # last part
+ chunks.append(parts[-1].pre)
+
+ return b"".join(chunks), formats, order, parts
+
+
+@lru_cache()
+def _query2pg_client(
+ query: bytes, encoding: str
+) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
+ """
+ Convert Python query and params into a template to perform client-side binding
+ """
+ parts = _split_query(query, encoding, collapse_double_percent=False)
+ order: Optional[List[str]] = None
+ chunks: List[bytes] = []
+
+ if isinstance(parts[0].item, int):
+ for part in parts[:-1]:
+ assert isinstance(part.item, int)
+ chunks.append(part.pre)
+ chunks.append(b"%s")
+
+ elif isinstance(parts[0].item, str):
+ seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+ order = []
+ for part in parts[:-1]:
+ assert isinstance(part.item, str)
+ chunks.append(part.pre)
+ if part.item not in seen:
+ ph = b"%s"
+ seen[part.item] = (ph, part.format)
+ order.append(part.item)
+ chunks.append(ph)
+ else:
+ chunks.append(seen[part.item][0])
+ order.append(part.item)
+
+ # last part
+ chunks.append(parts[-1].pre)
+
+ return b"".join(chunks), order, parts
+
+
+def _validate_and_reorder_params(
+ parts: List[QueryPart], vars: Params, order: Optional[List[str]]
+) -> Sequence[Any]:
+ """
+ Verify the compatibility between a query and a set of params.
+ """
+ # Try concrete types, then abstract types
+ t = type(vars)
+ if t is list or t is tuple:
+ sequence = True
+ elif t is dict:
+ sequence = False
+ elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
+ sequence = True
+ elif isinstance(vars, Mapping):
+ sequence = False
+ else:
+ raise TypeError(
+ "query parameters should be a sequence or a mapping,"
+ f" got {type(vars).__name__}"
+ )
+
+ if sequence:
+ if len(vars) != len(parts) - 1:
+ raise e.ProgrammingError(
+ f"the query has {len(parts) - 1} placeholders but"
+ f" {len(vars)} parameters were passed"
+ )
+ if vars and not isinstance(parts[0].item, int):
+ raise TypeError("named placeholders require a mapping of parameters")
+ return vars # type: ignore[return-value]
+
+ else:
+ if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
+ raise TypeError(
+ "positional placeholders (%s) require a sequence of parameters"
+ )
+ try:
+ return [vars[item] for item in order or ()] # type: ignore[call-overload]
+ except KeyError:
+ raise e.ProgrammingError(
+ "query parameter missing:"
+ f" {', '.join(sorted(i for i in order or () if i not in vars))}"
+ )
+
+
+_re_placeholder = re.compile(
+ rb"""(?x)
+ % # a literal %
+ (?:
+ (?:
+ \( ([^)]+) \) # or a name in (braces)
+ . # followed by a format
+ )
+ |
+ (?:.) # or any char, really
+ )
+ """
+)
+
+
+def _split_query(
+ query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
+) -> List[QueryPart]:
+ parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
+ cur = 0
+
+ # pairs [(fragment, match], with the last match None
+ m = None
+ for m in _re_placeholder.finditer(query):
+ pre = query[cur : m.span(0)[0]]
+ parts.append((pre, m))
+ cur = m.span(0)[1]
+ if m:
+ parts.append((query[cur:], None))
+ else:
+ parts.append((query, None))
+
+ rv = []
+
+ # drop the "%%", validate
+ i = 0
+ phtype = None
+ while i < len(parts):
+ pre, m = parts[i]
+ if m is None:
+ # last part
+ rv.append(QueryPart(pre, 0, PyFormat.AUTO))
+ break
+
+ ph = m.group(0)
+ if ph == b"%%":
+ # unescape '%%' to '%' if necessary, then merge the parts
+ if collapse_double_percent:
+ ph = b"%"
+ pre1, m1 = parts[i + 1]
+ parts[i + 1] = (pre + ph + pre1, m1)
+ del parts[i]
+ continue
+
+ if ph == b"%(":
+ raise e.ProgrammingError(
+ "incomplete placeholder:"
+ f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'"
+ )
+ elif ph == b"% ":
+ # explicit messasge for a typical error
+ raise e.ProgrammingError(
+ "incomplete placeholder: '%'; if you want to use '%' as an"
+ " operator you can double it up, i.e. use '%%'"
+ )
+ elif ph[-1:] not in b"sbt":
+ raise e.ProgrammingError(
+ "only '%s', '%b', '%t' are allowed as placeholders, got"
+ f" '{m.group(0).decode(encoding)}'"
+ )
+
+ # Index or name
+ item: Union[int, str]
+ item = m.group(1).decode(encoding) if m.group(1) else i
+
+ if not phtype:
+ phtype = type(item)
+ elif phtype is not type(item):
+ raise e.ProgrammingError(
+ "positional and named placeholders cannot be mixed"
+ )
+
+ format = _ph_to_fmt[ph[-1:]]
+ rv.append(QueryPart(pre, item, format))
+ i += 1
+
+ return rv
+
+
+_ph_to_fmt = {
+ b"s": PyFormat.AUTO,
+ b"t": PyFormat.TEXT,
+ b"b": PyFormat.BINARY,
+}
diff --git a/psycopg/psycopg/_struct.py b/psycopg/psycopg/_struct.py
new file mode 100644
index 0000000..28a6084
--- /dev/null
+++ b/psycopg/psycopg/_struct.py
@@ -0,0 +1,57 @@
+"""
+Utility functions to deal with binary structs.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import struct
+from typing import Callable, cast, Optional, Tuple
+from typing_extensions import TypeAlias
+
+from .abc import Buffer
+from . import errors as e
+from ._compat import Protocol
+
+PackInt: TypeAlias = Callable[[int], bytes]
+UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
+PackFloat: TypeAlias = Callable[[float], bytes]
+UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]
+
+
+class UnpackLen(Protocol):
+ def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
+ ...
+
+
+pack_int2 = cast(PackInt, struct.Struct("!h").pack)
+pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
+pack_int4 = cast(PackInt, struct.Struct("!i").pack)
+pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
+pack_int8 = cast(PackInt, struct.Struct("!q").pack)
+pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
+pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
+
+unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
+unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack)
+unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack)
+unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack)
+unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack)
+unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack)
+unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack)
+
+_struct_len = struct.Struct("!i")
+pack_len = cast(Callable[[int], bytes], _struct_len.pack)
+unpack_len = cast(UnpackLen, _struct_len.unpack_from)
+
+
+def pack_float4_bug_304(x: float) -> bytes:
+ raise e.InterfaceError(
+ "cannot dump Float4: Python affected by bug #304. Note that the psycopg-c"
+ " and psycopg-binary packages are not affected by this issue."
+ " See https://github.com/psycopg/psycopg/issues/304"
+ )
+
+
+# If issue #304 is detected, raise an error instead of dumping wrong data.
+if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"):
+ pack_float4 = pack_float4_bug_304
diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py
new file mode 100644
index 0000000..3528188
--- /dev/null
+++ b/psycopg/psycopg/_tpc.py
@@ -0,0 +1,116 @@
+"""
+psycopg two-phase commit support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+import datetime as dt
+from base64 import b64encode, b64decode
+from typing import Optional, Union
+from dataclasses import dataclass, replace
+
+_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
+
+
+@dataclass(frozen=True)
+class Xid:
+ """A two-phase commit transaction identifier.
+
+ The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`,
+ `bqual`).
+
+ """
+
+ format_id: Optional[int]
+ gtrid: str
+ bqual: Optional[str]
+ prepared: Optional[dt.datetime] = None
+ owner: Optional[str] = None
+ database: Optional[str] = None
+
+ @classmethod
+ def from_string(cls, s: str) -> "Xid":
+ """Try to parse an XA triple from the string.
+
+ This may fail for several reasons. In such case return an unparsed Xid.
+ """
+ try:
+ return cls._parse_string(s)
+ except Exception:
+ return Xid(None, s, None)
+
+ def __str__(self) -> str:
+ return self._as_tid()
+
+ def __len__(self) -> int:
+ return 3
+
+ def __getitem__(self, index: int) -> Union[int, str, None]:
+ return (self.format_id, self.gtrid, self.bqual)[index]
+
+ @classmethod
+ def _parse_string(cls, s: str) -> "Xid":
+ m = _re_xid.match(s)
+ if not m:
+ raise ValueError("bad Xid format")
+
+ format_id = int(m.group(1))
+ gtrid = b64decode(m.group(2)).decode()
+ bqual = b64decode(m.group(3)).decode()
+ return cls.from_parts(format_id, gtrid, bqual)
+
+ @classmethod
+ def from_parts(
+ cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
+ ) -> "Xid":
+ if format_id is not None:
+ if bqual is None:
+ raise TypeError("if format_id is specified, bqual must be too")
+ if not 0 <= format_id < 0x80000000:
+ raise ValueError("format_id must be a non-negative 32-bit integer")
+ if len(bqual) > 64:
+ raise ValueError("bqual must be not longer than 64 chars")
+ if len(gtrid) > 64:
+ raise ValueError("gtrid must be not longer than 64 chars")
+
+ elif bqual is None:
+ raise TypeError("if format_id is None, bqual must be None too")
+
+ return Xid(format_id, gtrid, bqual)
+
+ def _as_tid(self) -> str:
+ """
+ Return the PostgreSQL transaction_id for this XA xid.
+
+ PostgreSQL wants just a string, while the DBAPI supports the XA
+ standard and thus a triple. We use the same conversion algorithm
+ implemented by JDBC in order to allow some form of interoperation.
+
+ see also: the pgjdbc implementation
+ http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
+ postgresql/xa/RecoveredXid.java?rev=1.2
+ """
+ if self.format_id is None or self.bqual is None:
+ # Unparsed xid: return the gtrid.
+ return self.gtrid
+
+ # XA xid: mash together the components.
+ egtrid = b64encode(self.gtrid.encode()).decode()
+ ebqual = b64encode(self.bqual.encode()).decode()
+
+ return f"{self.format_id}_{egtrid}_{ebqual}"
+
+ @classmethod
+ def _get_recover_query(cls) -> str:
+ return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"
+
+ @classmethod
+ def _from_record(
+ cls, gid: str, prepared: dt.datetime, owner: str, database: str
+ ) -> "Xid":
+ xid = Xid.from_string(gid)
+ return replace(xid, prepared=prepared, owner=owner, database=database)
+
+
+Xid.__module__ = "psycopg"
diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py
new file mode 100644
index 0000000..19bd6ae
--- /dev/null
+++ b/psycopg/psycopg/_transform.py
@@ -0,0 +1,350 @@
+"""
+Helper object to transform values between Python and PostgreSQL
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import DefaultDict, TYPE_CHECKING
+from collections import defaultdict
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import postgres
+from . import errors as e
+from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
+from .rows import Row, RowMaker
+from .postgres import INVALID_OID, TEXT_OID
+from ._encodings import pgconn_encoding
+
+if TYPE_CHECKING:
+ from .abc import Dumper, Loader
+ from .adapt import AdaptersMap
+ from .pq.abc import PGresult
+ from .connection import BaseConnection
+
+DumperCache: TypeAlias = Dict[DumperKey, "Dumper"]
+OidDumperCache: TypeAlias = Dict[int, "Dumper"]
+LoaderCache: TypeAlias = Dict[int, "Loader"]
+
+TEXT = pq.Format.TEXT
+PY_TEXT = PyFormat.TEXT
+
+
+class Transformer(AdaptContext):
+ """
+ An object that can adapt efficiently between Python and PostgreSQL.
+
+ The life cycle of the object is the query, so it is assumed that attributes
+ such as the server version or the connection encoding will not change. The
+ object have its state so adapting several values of the same type can be
+ optimised.
+
+ """
+
+ __module__ = "psycopg.adapt"
+
+ __slots__ = """
+ types formats
+ _conn _adapters _pgresult _dumpers _loaders _encoding _none_oid
+ _oid_dumpers _oid_types _row_dumpers _row_loaders
+ """.split()
+
+ types: Optional[Tuple[int, ...]]
+ formats: Optional[List[pq.Format]]
+
+ _adapters: "AdaptersMap"
+ _pgresult: Optional["PGresult"]
+ _none_oid: int
+
+ def __init__(self, context: Optional[AdaptContext] = None):
+ self._pgresult = self.types = self.formats = None
+
+ # WARNING: don't store context, or you'll create a loop with the Cursor
+ if context:
+ self._adapters = context.adapters
+ self._conn = context.connection
+ else:
+ self._adapters = postgres.adapters
+ self._conn = None
+
+ # mapping fmt, class -> Dumper instance
+ self._dumpers: DefaultDict[PyFormat, DumperCache]
+ self._dumpers = defaultdict(dict)
+
+ # mapping fmt, oid -> Dumper instance
+ # Not often used, so create it only if needed.
+ self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]]
+ self._oid_dumpers = None
+
+ # mapping fmt, oid -> Loader instance
+ self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {})
+
+ self._row_dumpers: Optional[List["Dumper"]] = None
+
+ # sequence of load functions from value to python
+ # the length of the result columns
+ self._row_loaders: List[LoadFunc] = []
+
+ # mapping oid -> type sql representation
+ self._oid_types: Dict[int, bytes] = {}
+
+ self._encoding = ""
+
+ @classmethod
+ def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
+ """
+ Return a Transformer from an AdaptContext.
+
+ If the context is a Transformer instance, just return it.
+ """
+ if isinstance(context, Transformer):
+ return context
+ else:
+ return cls(context)
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ return self._conn
+
+ @property
+ def encoding(self) -> str:
+ if not self._encoding:
+ conn = self.connection
+ self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
+ return self._encoding
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ return self._adapters
+
+ @property
+ def pgresult(self) -> Optional["PGresult"]:
+ return self._pgresult
+
+ def set_pgresult(
+ self,
+ result: Optional["PGresult"],
+ *,
+ set_loaders: bool = True,
+ format: Optional[pq.Format] = None,
+ ) -> None:
+ self._pgresult = result
+
+ if not result:
+ self._nfields = self._ntuples = 0
+ if set_loaders:
+ self._row_loaders = []
+ return
+
+ self._ntuples = result.ntuples
+ nf = self._nfields = result.nfields
+
+ if not set_loaders:
+ return
+
+ if not nf:
+ self._row_loaders = []
+ return
+
+ fmt: pq.Format
+ fmt = result.fformat(0) if format is None else format # type: ignore
+ self._row_loaders = [
+ self.get_loader(result.ftype(i), fmt).load for i in range(nf)
+ ]
+
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
+ self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
+ self.types = tuple(types)
+ self.formats = [format] * len(types)
+
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
+ self._row_loaders = [self.get_loader(oid, format).load for oid in types]
+
+ def dump_sequence(
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
+ ) -> Sequence[Optional[Buffer]]:
+ nparams = len(params)
+ out: List[Optional[Buffer]] = [None] * nparams
+
+ # If we have dumpers, it means set_dumper_types had been called, in
+ # which case self.types and self.formats are set to sequences of the
+ # right size.
+ if self._row_dumpers:
+ for i in range(nparams):
+ param = params[i]
+ if param is not None:
+ out[i] = self._row_dumpers[i].dump(param)
+ return out
+
+ types = [self._get_none_oid()] * nparams
+ pqformats = [TEXT] * nparams
+
+ for i in range(nparams):
+ param = params[i]
+ if param is None:
+ continue
+ dumper = self.get_dumper(param, formats[i])
+ out[i] = dumper.dump(param)
+ types[i] = dumper.oid
+ pqformats[i] = dumper.format
+
+ self.types = tuple(types)
+ self.formats = pqformats
+
+ return out
+
+ def as_literal(self, obj: Any) -> bytes:
+ dumper = self.get_dumper(obj, PY_TEXT)
+ rv = dumper.quote(obj)
+ # If the result is quoted, and the oid not unknown or text,
+ # add an explicit type cast.
+ # Check the last char because the first one might be 'E'.
+ oid = dumper.oid
+ if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID:
+ try:
+ type_sql = self._oid_types[oid]
+ except KeyError:
+ ti = self.adapters.types.get(oid)
+ if ti:
+ if oid < 8192:
+ # builtin: prefer "timestamptz" to "timestamp with time zone"
+ type_sql = ti.name.encode(self.encoding)
+ else:
+ type_sql = ti.regtype.encode(self.encoding)
+ if oid == ti.array_oid:
+ type_sql += b"[]"
+ else:
+ type_sql = b""
+ self._oid_types[oid] = type_sql
+
+ if type_sql:
+ rv = b"%s::%s" % (rv, type_sql)
+
+ if not isinstance(rv, bytes):
+ rv = bytes(rv)
+ return rv
+
+ def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
+ """
+ Return a Dumper instance to dump `!obj`.
+ """
+ # Normally, the type of the object dictates how to dump it
+ key = type(obj)
+
+ # Reuse an existing Dumper class for objects of the same type
+ cache = self._dumpers[format]
+ try:
+ dumper = cache[key]
+ except KeyError:
+ # If it's the first time we see this type, look for a dumper
+ # configured for it.
+ dcls = self.adapters.get_dumper(key, format)
+ cache[key] = dumper = dcls(key, self)
+
+ # Check if the dumper requires an upgrade to handle this specific value
+ key1 = dumper.get_key(obj, format)
+ if key1 is key:
+ return dumper
+
+ # If it does, ask the dumper to create its own upgraded version
+ try:
+ return cache[key1]
+ except KeyError:
+ dumper = cache[key1] = dumper.upgrade(obj, format)
+ return dumper
+
+ def _get_none_oid(self) -> int:
+ try:
+ return self._none_oid
+ except AttributeError:
+ pass
+
+ try:
+ rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
+ except KeyError:
+ raise e.InterfaceError("None dumper not found")
+
+ return rv
+
+ def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper":
+ """
+ Return a Dumper to dump an object to the type with given oid.
+ """
+ if not self._oid_dumpers:
+ self._oid_dumpers = ({}, {})
+
+ # Reuse an existing Dumper class for objects of the same type
+ cache = self._oid_dumpers[format]
+ try:
+ return cache[oid]
+ except KeyError:
+ # If it's the first time we see this type, look for a dumper
+ # configured for it.
+ dcls = self.adapters.get_dumper_by_oid(oid, format)
+ cache[oid] = dumper = dcls(NoneType, self)
+
+ return dumper
+
+ def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
+ res = self._pgresult
+ if not res:
+ raise e.InterfaceError("result not set")
+
+ if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
+ raise e.InterfaceError(
+ f"rows must be included between 0 and {self._ntuples}"
+ )
+
+ records = []
+ for row in range(row0, row1):
+ record: List[Any] = [None] * self._nfields
+ for col in range(self._nfields):
+ val = res.get_value(row, col)
+ if val is not None:
+ record[col] = self._row_loaders[col](val)
+ records.append(make_row(record))
+
+ return records
+
+ def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
+ res = self._pgresult
+ if not res:
+ return None
+
+ if not 0 <= row < self._ntuples:
+ return None
+
+ record: List[Any] = [None] * self._nfields
+ for col in range(self._nfields):
+ val = res.get_value(row, col)
+ if val is not None:
+ record[col] = self._row_loaders[col](val)
+
+ return make_row(record)
+
+ def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
+ if len(self._row_loaders) != len(record):
+ raise e.ProgrammingError(
+ f"cannot load sequence of {len(record)} items:"
+ f" {len(self._row_loaders)} loaders registered"
+ )
+
+ return tuple(
+ (self._row_loaders[i](val) if val is not None else None)
+ for i, val in enumerate(record)
+ )
+
+ def get_loader(self, oid: int, format: pq.Format) -> "Loader":
+ try:
+ return self._loaders[format][oid]
+ except KeyError:
+ pass
+
+ loader_cls = self._adapters.get_loader(oid, format)
+ if not loader_cls:
+ loader_cls = self._adapters.get_loader(INVALID_OID, format)
+ if not loader_cls:
+ raise e.InterfaceError("unknown oid loader not found")
+ loader = self._loaders[format][oid] = loader_cls(oid, self)
+ return loader
diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py
new file mode 100644
index 0000000..2f1a24d
--- /dev/null
+++ b/psycopg/psycopg/_typeinfo.py
@@ -0,0 +1,461 @@
+"""
+Information about PostgreSQL types
+
+These types allow to read information from the system catalog and provide
+information to the adapters if needed.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+from enum import Enum
+from typing import Any, Dict, Iterator, Optional, overload
+from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import errors as e
+from .abc import AdaptContext
+from .rows import dict_row
+
+if TYPE_CHECKING:
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+ from .sql import Identifier
+
+T = TypeVar("T", bound="TypeInfo")
+RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
+
+
+class TypeInfo:
+ """
+ Hold information about a PostgreSQL base type.
+ """
+
+ __module__ = "psycopg.types"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ delimiter: str = ",",
+ ):
+ self.name = name
+ self.oid = oid
+ self.array_oid = array_oid
+ self.regtype = regtype or name
+ self.delimiter = delimiter
+
+ def __repr__(self) -> str:
+ return (
+ f"<{self.__class__.__qualname__}:"
+ f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
+ )
+
+ @overload
+ @classmethod
+ def fetch(
+ cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"]
+ ) -> Optional[T]:
+ ...
+
+ @overload
+ @classmethod
+ async def fetch(
+ cls: Type[T],
+ conn: "AsyncConnection[Any]",
+ name: Union[str, "Identifier"],
+ ) -> Optional[T]:
+ ...
+
+ @classmethod
+ def fetch(
+ cls: Type[T],
+ conn: "Union[Connection[Any], AsyncConnection[Any]]",
+ name: Union[str, "Identifier"],
+ ) -> Any:
+ """Query a system catalog to read information about a type."""
+ from .sql import Composable
+ from .connection_async import AsyncConnection
+
+ if isinstance(name, Composable):
+ name = name.as_string(conn)
+
+ if isinstance(conn, AsyncConnection):
+ return cls._fetch_async(conn, name)
+
+ # This might result in a nested transaction. What we want is to leave
+ # the function with the connection in the state we found (either idle
+ # or intrans)
+ try:
+ with conn.transaction():
+ with conn.cursor(binary=True, row_factory=dict_row) as cur:
+ cur.execute(cls._get_info_query(conn), {"name": name})
+ recs = cur.fetchall()
+ except e.UndefinedObject:
+ return None
+
+ return cls._from_records(name, recs)
+
+ @classmethod
+ async def _fetch_async(
+ cls: Type[T], conn: "AsyncConnection[Any]", name: str
+ ) -> Optional[T]:
+ """
+ Query a system catalog to read information about a type.
+
+ Similar to `fetch()` but can use an asynchronous connection.
+ """
+ try:
+ async with conn.transaction():
+ async with conn.cursor(binary=True, row_factory=dict_row) as cur:
+ await cur.execute(cls._get_info_query(conn), {"name": name})
+ recs = await cur.fetchall()
+ except e.UndefinedObject:
+ return None
+
+ return cls._from_records(name, recs)
+
+ @classmethod
+ def _from_records(
+ cls: Type[T], name: str, recs: Sequence[Dict[str, Any]]
+ ) -> Optional[T]:
+ if len(recs) == 1:
+ return cls(**recs[0])
+ elif not recs:
+ return None
+ else:
+ raise e.ProgrammingError(f"found {len(recs)} different types named {name}")
+
+ def register(self, context: Optional[AdaptContext] = None) -> None:
+ """
+ Register the type information, globally or in the specified `!context`.
+ """
+ if context:
+ types = context.adapters.types
+ else:
+ from . import postgres
+
+ types = postgres.types
+
+ types.add(self)
+
+ if self.array_oid:
+ from .types.array import register_array
+
+ register_array(self, context)
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT
+ typname AS name, oid, typarray AS array_oid,
+ oid::regtype::text AS regtype, typdelim AS delimiter
+FROM pg_type t
+WHERE t.oid = %(name)s::regtype
+ORDER BY t.oid
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ """Method called by the `!registry` when the object is added there."""
+ pass
+
+
+class RangeInfo(TypeInfo):
+ """Manage information about a range type."""
+
+ __module__ = "psycopg.types.range"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ subtype_oid: int,
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.subtype_oid = subtype_oid
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ r.rngsubtype AS subtype_oid
+FROM pg_type t
+JOIN pg_range r ON t.oid = r.rngtypid
+WHERE t.oid = %(name)s::regtype
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ # Map ranges subtypes to info
+ registry._registry[RangeInfo, self.subtype_oid] = self
+
+
+class MultirangeInfo(TypeInfo):
+ """Manage information about a multirange type."""
+
+ # TODO: expose to multirange module once added
+ # __module__ = "psycopg.types.multirange"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ range_oid: int,
+ subtype_oid: int,
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.range_oid = range_oid
+ self.subtype_oid = subtype_oid
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ if conn.info.server_version < 140000:
+ raise e.NotSupportedError(
+ "multirange types are only available from PostgreSQL 14"
+ )
+ return """\
+SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
+FROM pg_type t
+JOIN pg_range r ON t.oid = r.rngmultitypid
+WHERE t.oid = %(name)s::regtype
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ # Map multiranges ranges and subtypes to info
+ registry._registry[MultirangeInfo, self.range_oid] = self
+ registry._registry[MultirangeInfo, self.subtype_oid] = self
+
+
+class CompositeInfo(TypeInfo):
+ """Manage information about a composite type."""
+
+ __module__ = "psycopg.types.composite"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ field_names: Sequence[str],
+ field_types: Sequence[int],
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.field_names = field_names
+ self.field_types = field_types
+ # Will be set by register() if the `factory` is a type
+ self.python_type: Optional[type] = None
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT
+ t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ coalesce(a.fnames, '{}') AS field_names,
+ coalesce(a.ftypes, '{}') AS field_types
+FROM pg_type t
+LEFT JOIN (
+ SELECT
+ attrelid,
+ array_agg(attname) AS fnames,
+ array_agg(atttypid) AS ftypes
+ FROM (
+ SELECT a.attrelid, a.attname, a.atttypid
+ FROM pg_attribute a
+ JOIN pg_type t ON t.typrelid = a.attrelid
+ WHERE t.oid = %(name)s::regtype
+ AND a.attnum > 0
+ AND NOT a.attisdropped
+ ORDER BY a.attnum
+ ) x
+ GROUP BY attrelid
+) a ON a.attrelid = t.typrelid
+WHERE t.oid = %(name)s::regtype
+"""
+
+
+class EnumInfo(TypeInfo):
+ """Manage information about an enum type."""
+
+ __module__ = "psycopg.types.enum"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ labels: Sequence[str],
+ ):
+ super().__init__(name, oid, array_oid)
+ self.labels = labels
+ # Will be set by register_enum()
+ self.enum: Optional[Type[Enum]] = None
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT name, oid, array_oid, array_agg(label) AS labels
+FROM (
+ SELECT
+ t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ e.enumlabel AS label
+ FROM pg_type t
+ LEFT JOIN pg_enum e
+ ON e.enumtypid = t.oid
+ WHERE t.oid = %(name)s::regtype
+ ORDER BY e.enumsortorder
+) x
+GROUP BY name, oid, array_oid
+"""
+
+
+class TypesRegistry:
+ """
+ Container for the information about types in a database.
+ """
+
+ __module__ = "psycopg.types"
+
+ def __init__(self, template: Optional["TypesRegistry"] = None):
+ self._registry: Dict[RegistryKey, TypeInfo]
+
+ # Make a shallow copy: it will become a proper copy if the registry
+ # is edited.
+ if template:
+ self._registry = template._registry
+ self._own_state = False
+ template._own_state = False
+ else:
+ self.clear()
+
+ def clear(self) -> None:
+ self._registry = {}
+ self._own_state = True
+
+ def add(self, info: TypeInfo) -> None:
+ self._ensure_own_state()
+ if info.oid:
+ self._registry[info.oid] = info
+ if info.array_oid:
+ self._registry[info.array_oid] = info
+ self._registry[info.name] = info
+
+ if info.regtype and info.regtype not in self._registry:
+ self._registry[info.regtype] = info
+
+ # Allow info to customise further their relation with the registry
+ info._added(self)
+
+ def __iter__(self) -> Iterator[TypeInfo]:
+ seen = set()
+ for t in self._registry.values():
+ if id(t) not in seen:
+ seen.add(id(t))
+ yield t
+
+ @overload
+ def __getitem__(self, key: Union[str, int]) -> TypeInfo:
+ ...
+
+ @overload
+ def __getitem__(self, key: Tuple[Type[T], int]) -> T:
+ ...
+
+ def __getitem__(self, key: RegistryKey) -> TypeInfo:
+ """
+ Return info about a type, specified by name or oid
+
+ :param key: the name or oid of the type to look for.
+
+ Raise KeyError if not found.
+ """
+ if isinstance(key, str):
+ if key.endswith("[]"):
+ key = key[:-2]
+ elif not isinstance(key, (int, tuple)):
+ raise TypeError(f"the key must be an oid or a name, got {type(key)}")
+ try:
+ return self._registry[key]
+ except KeyError:
+ raise KeyError(f"couldn't find the type {key!r} in the types registry")
+
+ @overload
+ def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
+ ...
+
+ @overload
+ def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
+ ...
+
+ def get(self, key: RegistryKey) -> Optional[TypeInfo]:
+ """
+ Return info about a type, specified by name or oid
+
+ :param key: the name or oid of the type to look for.
+
+ Unlike `__getitem__`, return None if not found.
+ """
+ try:
+ return self[key]
+ except KeyError:
+ return None
+
+ def get_oid(self, name: str) -> int:
+ """
+ Return the oid of a PostgreSQL type by name.
+
+ :param key: the name of the type to look for.
+
+ Return the array oid if the type ends with "``[]``"
+
+ Raise KeyError if the name is unknown.
+ """
+ t = self[name]
+ if name.endswith("[]"):
+ return t.array_oid
+ else:
+ return t.oid
+
+ def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
+ """
+ Return info about a `TypeInfo` subclass by its element name or oid.
+
+ :param cls: the subtype of `!TypeInfo` to look for. Currently
+ supported are `~psycopg.types.range.RangeInfo` and
+ `~psycopg.types.multirange.MultirangeInfo`.
+ :param subtype: The name or OID of the subtype of the element to look for.
+ :return: The `!TypeInfo` object of class `!cls` whose subtype is
+ `!subtype`. `!None` if the element or its range are not found.
+ """
+ try:
+ info = self[subtype]
+ except KeyError:
+ return None
+ return self.get((cls, info.oid))
+
+ def _ensure_own_state(self) -> None:
+ # Time to write! so, copy.
+ if not self._own_state:
+ self._registry = self._registry.copy()
+ self._own_state = True
diff --git a/psycopg/psycopg/_tz.py b/psycopg/psycopg/_tz.py
new file mode 100644
index 0000000..813ed62
--- /dev/null
+++ b/psycopg/psycopg/_tz.py
@@ -0,0 +1,44 @@
+"""
+Timezone utility functions.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import Dict, Optional, Union
+from datetime import timezone, tzinfo
+
+from .pq.abc import PGconn
+from ._compat import ZoneInfo
+
+logger = logging.getLogger("psycopg")
+
+_timezones: Dict[Union[None, bytes], tzinfo] = {
+ None: timezone.utc,
+ b"UTC": timezone.utc,
+}
+
+
+def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo:
+ """Return the Python timezone info of the connection's timezone."""
+ tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None
+ try:
+ return _timezones[tzname]
+ except KeyError:
+ sname = tzname.decode() if tzname else "UTC"
+ try:
+ zi: tzinfo = ZoneInfo(sname)
+ except (KeyError, OSError):
+ logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
+ zi = timezone.utc
+ except Exception as ex:
+ logger.warning(
+ "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
+ sname,
+ type(ex).__name__,
+ ex,
+ )
+ zi = timezone.utc
+
+ _timezones[tzname] = zi
+ return zi
diff --git a/psycopg/psycopg/_wrappers.py b/psycopg/psycopg/_wrappers.py
new file mode 100644
index 0000000..f861741
--- /dev/null
+++ b/psycopg/psycopg/_wrappers.py
@@ -0,0 +1,137 @@
+"""
+Wrappers for numeric types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Wrappers to force numbers to be cast as specific PostgreSQL types
+
+# These types are implemented here but exposed by `psycopg.types.numeric`.
+# They are defined here to avoid a circular import.
+_MODULE = "psycopg.types.numeric"
+
+
+class Int2(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int2":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Int4(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int4":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Int8(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int8":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class IntNumeric(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "IntNumeric":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float4(float):
+ """
+ Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: float) -> "Float4":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float8(float):
+ """
+ Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: float) -> "Float8":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Oid(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`oid`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Oid":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py
new file mode 100644
index 0000000..80c8fbf
--- /dev/null
+++ b/psycopg/psycopg/abc.py
@@ -0,0 +1,266 @@
+"""
+Protocol objects representing different implementations of the same classes.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, Generator, Mapping
+from typing import List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import pq
+from ._enums import PyFormat as PyFormat
+from ._compat import Protocol, LiteralString
+
+if TYPE_CHECKING:
+ from . import sql
+ from .rows import Row, RowMaker
+ from .pq.abc import PGresult
+ from .waiting import Wait, Ready
+ from .connection import BaseConnection
+ from ._adapters_map import AdaptersMap
+
+NoneType: type = type(None)
+
+# An object implementing the buffer protocol
+Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
+
+Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"]
+Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
+ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
+PipelineCommand: TypeAlias = Callable[[], None]
+DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
+
+# Waiting protocol types
+
+RV = TypeVar("RV")
+
+PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
+"""Generator for processes where the connection file number can change.
+
+This can happen in connection and reset, but not in normal querying.
+"""
+
+PQGen: TypeAlias = Generator["Wait", "Ready", RV]
+"""Generator for processes where the connection file number won't change.
+"""
+
+
+class WaitFunc(Protocol):
+ """
+ Wait on the connection which generated `PQgen` and return its final result.
+ """
+
+ def __call__(
+ self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
+ ) -> RV:
+ ...
+
+
+# Adaptation types
+
+DumpFunc: TypeAlias = Callable[[Any], Buffer]
+LoadFunc: TypeAlias = Callable[[Buffer], Any]
+
+
+class AdaptContext(Protocol):
+ """
+ A context describing how types are adapted.
+
+ Example of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`,
+ `~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`.
+
+ Note that this is a `~typing.Protocol`, so objects implementing
+ `!AdaptContext` don't need to explicitly inherit from this class.
+
+ """
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ """The adapters configuration that this object uses."""
+ ...
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ """The connection used by this object, if available.
+
+ :rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None`
+ """
+ ...
+
+
+class Dumper(Protocol):
+ """
+ Convert Python objects of type `!cls` to PostgreSQL representation.
+ """
+
+ format: pq.Format
+ """
+ The format that this class `dump()` method produces,
+ `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
+
+ This is a class attribute.
+ """
+
+ oid: int
+ """The oid to pass to the server, if known; 0 otherwise (class attribute)."""
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ ...
+
+ def dump(self, obj: Any) -> Buffer:
+ """Convert the object `!obj` to PostgreSQL representation.
+
+ :param obj: the object to convert.
+ """
+ ...
+
+ def quote(self, obj: Any) -> Buffer:
+ """Convert the object `!obj` to escaped representation.
+
+ :param obj: the object to convert.
+ """
+ ...
+
+ def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
+ """Return an alternative key to upgrade the dumper to represent `!obj`.
+
+ :param obj: The object to convert
+ :param format: The format to convert to
+
+ Normally the type of the object is all it takes to define how to dump
+ the object to the database. For instance, a Python `~datetime.date` can
+ be simply converted into a PostgreSQL :sql:`date`.
+
+ In a few cases, just the type is not enough. For example:
+
+ - A Python `~datetime.datetime` could be represented as a
+ :sql:`timestamptz` or a :sql:`timestamp`, according to whether it
+ specifies a `!tzinfo` or not.
+
+ - A Python int could be stored as several Postgres types: int2, int4,
+ int8, numeric. If a type too small is used, it may result in an
+ overflow. If a type too large is used, PostgreSQL may not want to
+ cast it to a smaller type.
+
+ - Python lists should be dumped according to the type they contain to
+ convert them to e.g. array of strings, array of ints (and which
+ size of int?...)
+
+ In these cases, a dumper can implement `!get_key()` and return a new
+ class, or sequence of classes, that can be used to identify the same
+ dumper again. If the mechanism is not needed, the method should return
+ the same `!cls` object passed in the constructor.
+
+ If a dumper implements `get_key()` it should also implement
+ `upgrade()`.
+
+ """
+ ...
+
+ def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+ """Return a new dumper to manage `!obj`.
+
+ :param obj: The object to convert
+ :param format: The format to convert to
+
+ Once `Transformer.get_dumper()` has been notified by `get_key()` that
+ this Dumper class cannot handle `!obj` itself, it will invoke
+ `!upgrade()`, which should return a new `Dumper` instance, which will
+ be reused for every objects for which `!get_key()` returns the same
+ result.
+ """
+ ...
+
+
+class Loader(Protocol):
+ """
+ Convert PostgreSQL values with type OID `!oid` to Python objects.
+ """
+
+ format: pq.Format
+ """
+ The format that this class `load()` method can convert,
+ `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
+
+ This is a class attribute.
+ """
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ ...
+
+ def load(self, data: Buffer) -> Any:
+ """
+ Convert the data returned by the database into a Python object.
+
+ :param data: the data to convert.
+ """
+ ...
+
+
+class Transformer(Protocol):
+
+ types: Optional[Tuple[int, ...]]
+ formats: Optional[List[pq.Format]]
+
+ def __init__(self, context: Optional[AdaptContext] = None):
+ ...
+
+ @classmethod
+ def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
+ ...
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ ...
+
+ @property
+ def encoding(self) -> str:
+ ...
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ ...
+
+ @property
+ def pgresult(self) -> Optional["PGresult"]:
+ ...
+
+ def set_pgresult(
+ self,
+ result: Optional["PGresult"],
+ *,
+ set_loaders: bool = True,
+ format: Optional[pq.Format] = None
+ ) -> None:
+ ...
+
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
+ ...
+
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
+ ...
+
+ def dump_sequence(
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
+ ) -> Sequence[Optional[Buffer]]:
+ ...
+
+ def as_literal(self, obj: Any) -> bytes:
+ ...
+
+ def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
+ ...
+
+ def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
+ ...
+
+ def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
+ ...
+
+ def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
+ ...
+
+ def get_loader(self, oid: int, format: pq.Format) -> Loader:
+ ...
diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py
new file mode 100644
index 0000000..7ec4a55
--- /dev/null
+++ b/psycopg/psycopg/adapt.py
@@ -0,0 +1,162 @@
+"""
+Entry point into the adaptation system.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type, TYPE_CHECKING
+
+from . import pq, abc
+from . import _adapters_map
+from ._enums import PyFormat as PyFormat
+from ._cmodule import _psycopg
+
+if TYPE_CHECKING:
+ from .connection import BaseConnection
+
+AdaptersMap = _adapters_map.AdaptersMap
+Buffer = abc.Buffer
+
+ORD_BS = ord("\\")
+
+
+class Dumper(abc.Dumper, ABC):
+ """
+ Convert Python object of the type `!cls` to PostgreSQL representation.
+ """
+
+ oid: int = 0
+ """The oid to pass to the server, if known."""
+
+ format: pq.Format = pq.Format.TEXT
+ """The format of the data dumped."""
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ self.cls = cls
+ self.connection: Optional["BaseConnection[Any]"] = (
+ context.connection if context else None
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"<{type(self).__module__}.{type(self).__qualname__}"
+ f" (oid={self.oid}) at 0x{id(self):x}>"
+ )
+
+ @abstractmethod
+ def dump(self, obj: Any) -> Buffer:
+ ...
+
+ def quote(self, obj: Any) -> Buffer:
+ """
+ By default return the `dump()` value quoted and sanitised, so
+ that the result can be used to build a SQL string. This works well
+ for most types and you won't likely have to implement this method in a
+ subclass.
+ """
+ value = self.dump(obj)
+
+ if self.connection:
+ esc = pq.Escaping(self.connection.pgconn)
+ # escaping and quoting
+ return esc.escape_literal(value)
+
+ # This path is taken when quote is asked without a connection,
+ # usually it means by psycopg.sql.quote() or by
+ # 'Composible.as_string(None)'. Most often than not this is done by
+ # someone generating a SQL file to consume elsewhere.
+
+ # No quoting, only quote escaping, random bs escaping. See further.
+ esc = pq.Escaping()
+ out = esc.escape_string(value)
+
+ # b"\\" in memoryview doesn't work so search for the ascii value
+ if ORD_BS not in out:
+ # If the string has no backslash, the result is correct and we
+ # don't need to bother with standard_conforming_strings.
+ return b"'" + out + b"'"
+
+ # The libpq has a crazy behaviour: PQescapeString uses the last
+ # standard_conforming_strings setting seen on a connection. This
+ # means that backslashes might be escaped or might not.
+ #
+ # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+ # if scs is off, '\\' raises a warning and '\' is an error.
+ #
+ # Check what the libpq does, and if it doesn't escape the backslash
+ # let's do it on our own. Never mind the race condition.
+ rv: bytes = b" E'" + out + b"'"
+ if esc.escape_string(b"\\") == b"\\":
+ rv = rv.replace(b"\\", b"\\\\")
+ return rv
+
+ def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey:
+ """
+ Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
+ `~psycopg.abc.Dumper` protocol. Look at its definition for details.
+
+ This implementation returns the `!cls` passed in the constructor.
+ Subclasses needing to specialise the PostgreSQL type according to the
+ *value* of the object dumped (not only according to to its type)
+ should override this class.
+
+ """
+ return self.cls
+
+ def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+ """
+ Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
+ `~psycopg.abc.Dumper` protocol. Look at its definition for details.
+
+ This implementation just returns `!self`. If a subclass implements
+ `get_key()` it should probably override `!upgrade()` too.
+ """
+ return self
+
+
+class Loader(abc.Loader, ABC):
+ """
+ Convert PostgreSQL values with type OID `!oid` to Python objects.
+ """
+
+ format: pq.Format = pq.Format.TEXT
+ """The format of the data loaded."""
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ self.oid = oid
+ self.connection: Optional["BaseConnection[Any]"] = (
+ context.connection if context else None
+ )
+
+ @abstractmethod
+ def load(self, data: Buffer) -> Any:
+ """Convert a PostgreSQL value to a Python object."""
+ ...
+
+
+Transformer: Type["abc.Transformer"]
+
+# Override it with fast object if available
+if _psycopg:
+ Transformer = _psycopg.Transformer
+else:
+ from . import _transform
+
+ Transformer = _transform.Transformer
+
+
+class RecursiveDumper(Dumper):
+ """Dumper with a transformer to help dumping recursive types."""
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ super().__init__(cls, context)
+ self._tx = Transformer.from_context(context)
+
+
+class RecursiveLoader(Loader):
+ """Loader with a transformer to help loading recursive types."""
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer.from_context(context)
diff --git a/psycopg/psycopg/client_cursor.py b/psycopg/psycopg/client_cursor.py
new file mode 100644
index 0000000..6271ec5
--- /dev/null
+++ b/psycopg/psycopg/client_cursor.py
@@ -0,0 +1,95 @@
+"""
+psycopg client-side binding cursors
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from typing import Optional, Tuple, TYPE_CHECKING
+from functools import partial
+
+from ._queries import PostgresQuery, PostgresClientQuery
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import ConnectionType, Query, Params
+from .rows import Row
+from .cursor import BaseCursor, Cursor
+from ._preparing import Prepare
+from .cursor_async import AsyncCursor
+
+if TYPE_CHECKING:
+ from typing import Any # noqa: F401
+ from .connection import Connection # noqa: F401
+ from .connection_async import AsyncConnection # noqa: F401
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+
+class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
+ def mogrify(self, query: Query, params: Optional[Params] = None) -> str:
+ """
+ Return the query and parameters merged.
+
+ Parameters are adapted and merged to the query the same way that
+ `!execute()` would do.
+
+ """
+ self._tx = adapt.Transformer(self)
+ pgq = self._convert_query(query, params)
+ return pgq.query.decode(self._tx.encoding)
+
+ def _execute_send(
+ self,
+ query: PostgresQuery,
+ *,
+ force_extended: bool = False,
+ binary: Optional[bool] = None,
+ ) -> None:
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ if fmt == BINARY:
+ raise e.NotSupportedError(
+ "client-side cursors don't support binary results"
+ )
+
+ self._query = query
+
+ if self._conn._pipeline:
+ # In pipeline mode always use PQsendQueryParams - see #314
+ # Multiple statements in the same query are not allowed anyway.
+ self._conn._pipeline.command_queue.append(
+ partial(self._pgconn.send_query_params, query.query, None)
+ )
+ elif force_extended:
+ self._pgconn.send_query_params(query.query, None)
+ else:
+ # If we can, let's use simple query protocol,
+ # as it can execute more than one statement in a single query.
+ self._pgconn.send_query(query.query)
+
+ def _convert_query(
+ self, query: Query, params: Optional[Params] = None
+ ) -> PostgresQuery:
+ pgq = PostgresClientQuery(self._tx)
+ pgq.convert(query, params)
+ return pgq
+
+ def _get_prepared(
+ self, pgq: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ return (Prepare.NO, b"")
+
+
+class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]):
+ __module__ = "psycopg"
+
+
+class AsyncClientCursor(
+ ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
+):
+ __module__ = "psycopg"
diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
new file mode 100644
index 0000000..78ad577
--- /dev/null
+++ b/psycopg/psycopg/connection.py
@@ -0,0 +1,1031 @@
+"""
+psycopg connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+import threading
+from types import TracebackType
+from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
+from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
+from typing import overload, TYPE_CHECKING
+from weakref import ref, ReferenceType
+from warnings import warn
+from functools import partial
+from contextlib import contextmanager
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from . import waiting
+from . import postgres
+from .abc import AdaptContext, ConnectionType, Params, Query, RV
+from .abc import PQGen, PQGenConn
+from .sql import Composable, SQL
+from ._tpc import Xid
+from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
+from .adapt import AdaptersMap
+from ._enums import IsolationLevel
+from .cursor import Cursor
+from ._compat import LiteralString
+from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from ._pipeline import BasePipeline, Pipeline
+from .generators import notifies, connect, execute
+from ._encodings import pgconn_encoding
+from ._preparing import PrepareManager
+from .transaction import Transaction
+from .server_cursor import ServerCursor
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn, PGresult
+ from psycopg_pool.base import BasePool
+
+
+# Row Type variable for Cursor (when it needs to be distinguished from the
+# connection's one)
+CursorRow = TypeVar("CursorRow")
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+OK = pq.ConnStatus.OK
+BAD = pq.ConnStatus.BAD
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+logger = logging.getLogger("psycopg")
+
+
+class Notify(NamedTuple):
+ """An asynchronous notification received from the database."""
+
+ channel: str
+ """The name of the channel on which the notification was received."""
+
+ payload: str
+ """The message attached to the notification."""
+
+ pid: int
+ """The PID of the backend process which sent the notification."""
+
+
+Notify.__module__ = "psycopg"
+
+NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
+NotifyHandler: TypeAlias = Callable[[Notify], None]
+
+
+class BaseConnection(Generic[Row]):
+ """
+ Base class for different types of connections.
+
+ Share common functionalities such as access to the wrapped PGconn, but
+ allow different interfaces (sync/async).
+ """
+
+ # DBAPI2 exposed exceptions
+ Warning = e.Warning
+ Error = e.Error
+ InterfaceError = e.InterfaceError
+ DatabaseError = e.DatabaseError
+ DataError = e.DataError
+ OperationalError = e.OperationalError
+ IntegrityError = e.IntegrityError
+ InternalError = e.InternalError
+ ProgrammingError = e.ProgrammingError
+ NotSupportedError = e.NotSupportedError
+
+ # Enums useful for the connection
+ ConnStatus = pq.ConnStatus
+ TransactionStatus = pq.TransactionStatus
+
+ def __init__(self, pgconn: "PGconn"):
+ self.pgconn = pgconn
+ self._autocommit = False
+
+ # None, but set to a copy of the global adapters map as soon as requested.
+ self._adapters: Optional[AdaptersMap] = None
+
+ self._notice_handlers: List[NoticeHandler] = []
+ self._notify_handlers: List[NotifyHandler] = []
+
+ # Number of transaction blocks currently entered
+ self._num_transactions = 0
+
+ self._closed = False # closed by an explicit close()
+ self._prepared: PrepareManager = PrepareManager()
+ self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared
+
+ wself = ref(self)
+ pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
+ pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
+
+ # Attribute is only set if the connection is from a pool so we can tell
+ # apart a connection in the pool too (when _pool = None)
+ self._pool: Optional["BasePool[Any]"]
+
+ self._pipeline: Optional[BasePipeline] = None
+
+ # Time after which the connection should be closed
+ self._expire_at: float
+
+ self._isolation_level: Optional[IsolationLevel] = None
+ self._read_only: Optional[bool] = None
+ self._deferrable: Optional[bool] = None
+ self._begin_statement = b""
+
+ def __del__(self) -> None:
+ # If fails on connection we might not have this attribute yet
+ if not hasattr(self, "pgconn"):
+ return
+
+ # Connection correctly closed
+ if self.closed:
+ return
+
+ # Connection in a pool so terminating with the program is normal
+ if hasattr(self, "_pool"):
+ return
+
+ warn(
+ f"connection {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the connection",
+ ResourceWarning,
+ )
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self.pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @property
+ def closed(self) -> bool:
+ """`!True` if the connection is closed."""
+ return self.pgconn.status == BAD
+
+ @property
+ def broken(self) -> bool:
+ """
+ `!True` if the connection was interrupted.
+
+ A broken connection is always `closed`, but wasn't closed in a clean
+ way, such as using `close()` or a `!with` block.
+ """
+ return self.pgconn.status == BAD and not self._closed
+
+ @property
+ def autocommit(self) -> bool:
+ """The autocommit state of the connection."""
+ return self._autocommit
+
+ @autocommit.setter
+ def autocommit(self, value: bool) -> None:
+ self._set_autocommit(value)
+
+ def _set_autocommit(self, value: bool) -> None:
+ raise NotImplementedError
+
+ def _set_autocommit_gen(self, value: bool) -> PQGen[None]:
+ yield from self._check_intrans_gen("autocommit")
+ self._autocommit = bool(value)
+
+ @property
+ def isolation_level(self) -> Optional[IsolationLevel]:
+ """
+ The isolation level of the new transactions started on the connection.
+ """
+ return self._isolation_level
+
+ @isolation_level.setter
+ def isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ self._set_isolation_level(value)
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ raise NotImplementedError
+
+ def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]:
+ yield from self._check_intrans_gen("isolation_level")
+ self._isolation_level = IsolationLevel(value) if value is not None else None
+ self._begin_statement = b""
+
+ @property
+ def read_only(self) -> Optional[bool]:
+ """
+ The read-only state of the new transactions started on the connection.
+ """
+ return self._read_only
+
+ @read_only.setter
+ def read_only(self, value: Optional[bool]) -> None:
+ self._set_read_only(value)
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ raise NotImplementedError
+
+ def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]:
+ yield from self._check_intrans_gen("read_only")
+ self._read_only = bool(value)
+ self._begin_statement = b""
+
+ @property
+ def deferrable(self) -> Optional[bool]:
+ """
+ The deferrable state of the new transactions started on the connection.
+ """
+ return self._deferrable
+
+ @deferrable.setter
+ def deferrable(self, value: Optional[bool]) -> None:
+ self._set_deferrable(value)
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ raise NotImplementedError
+
+ def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]:
+ yield from self._check_intrans_gen("deferrable")
+ self._deferrable = bool(value)
+ self._begin_statement = b""
+
+ def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
+ # Raise an exception if we are in a transaction
+ status = self.pgconn.transaction_status
+ if status == IDLE and self._pipeline:
+ yield from self._pipeline._sync_gen()
+ status = self.pgconn.transaction_status
+ if status != IDLE:
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ f"can't change {attribute!r} now: "
+ "connection.transaction() context in progress"
+ )
+ else:
+ raise e.ProgrammingError(
+ f"can't change {attribute!r} now: "
+ "connection in transaction status "
+ f"{pq.TransactionStatus(status).name}"
+ )
+
+ @property
+ def info(self) -> ConnectionInfo:
+ """A `ConnectionInfo` attribute to inspect connection properties."""
+ return ConnectionInfo(self.pgconn)
+
+ @property
+ def adapters(self) -> AdaptersMap:
+ if not self._adapters:
+ self._adapters = AdaptersMap(postgres.adapters)
+
+ return self._adapters
+
+ @property
+ def connection(self) -> "BaseConnection[Row]":
+ # implement the AdaptContext protocol
+ return self
+
+ def fileno(self) -> int:
+ """Return the file descriptor of the connection.
+
+ This function allows to use the connection as file-like object in
+ functions waiting for readiness, such as the ones defined in the
+ `selectors` module.
+ """
+ return self.pgconn.socket
+
+ def cancel(self) -> None:
+ """Cancel the current operation on the connection."""
+ # No-op if the connection is closed
+ # this allows to use the method as callback handler without caring
+ # about its life.
+ if self.closed:
+ return
+
+ if self._tpc and self._tpc[1]:
+ raise e.ProgrammingError(
+ "cancel() cannot be used with a prepared two-phase transaction"
+ )
+
+ c = self.pgconn.get_cancel()
+ c.cancel()
+
+ def add_notice_handler(self, callback: NoticeHandler) -> None:
+ """
+ Register a callable to be invoked when a notice message is received.
+
+ :param callback: the callback to call upon message received.
+ :type callback: Callable[[~psycopg.errors.Diagnostic], None]
+ """
+ self._notice_handlers.append(callback)
+
+ def remove_notice_handler(self, callback: NoticeHandler) -> None:
+ """
+ Unregister a notice message callable previously registered.
+
+ :param callback: the callback to remove.
+ :type callback: Callable[[~psycopg.errors.Diagnostic], None]
+ """
+ self._notice_handlers.remove(callback)
+
+ @staticmethod
+ def _notice_handler(
+ wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult"
+ ) -> None:
+ self = wself()
+ if not (self and self._notice_handlers):
+ return
+
+ diag = e.Diagnostic(res, pgconn_encoding(self.pgconn))
+ for cb in self._notice_handlers:
+ try:
+ cb(diag)
+ except Exception as ex:
+ logger.exception("error processing notice callback '%s': %s", cb, ex)
+
+ def add_notify_handler(self, callback: NotifyHandler) -> None:
+ """
+ Register a callable to be invoked whenever a notification is received.
+
+ :param callback: the callback to call upon notification received.
+ :type callback: Callable[[~psycopg.Notify], None]
+ """
+ self._notify_handlers.append(callback)
+
+ def remove_notify_handler(self, callback: NotifyHandler) -> None:
+ """
+ Unregister a notification callable previously registered.
+
+ :param callback: the callback to remove.
+ :type callback: Callable[[~psycopg.Notify], None]
+ """
+ self._notify_handlers.remove(callback)
+
+ @staticmethod
+ def _notify_handler(
+ wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
+ ) -> None:
+ self = wself()
+ if not (self and self._notify_handlers):
+ return
+
+ enc = pgconn_encoding(self.pgconn)
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ for cb in self._notify_handlers:
+ cb(n)
+
+ @property
+ def prepare_threshold(self) -> Optional[int]:
+ """
+ Number of times a query is executed before it is prepared.
+
+ - If it is set to 0, every query is prepared the first time it is
+ executed.
+ - If it is set to `!None`, prepared statements are disabled on the
+ connection.
+
+ Default value: 5
+ """
+ return self._prepared.prepare_threshold
+
+ @prepare_threshold.setter
+ def prepare_threshold(self, value: Optional[int]) -> None:
+ self._prepared.prepare_threshold = value
+
+ @property
+ def prepared_max(self) -> int:
+ """
+ Maximum number of prepared statements on the connection.
+
+ Default value: 100
+ """
+ return self._prepared.prepared_max
+
+ @prepared_max.setter
+ def prepared_max(self, value: int) -> None:
+ self._prepared.prepared_max = value
+
+ # Generators to perform high-level operations on the connection
+ #
+ # These operations are expressed in terms of non-blocking generators
+ # and the task of waiting when needed (when the generators yield) is left
+ # to the connections subclass, which might wait either in blocking mode
+ # or through asyncio.
+ #
+ # All these generators assume exclusive access to the connection: subclasses
+ # should have a lock and hold it before calling and consuming them.
+
+ @classmethod
+ def _connect_gen(
+ cls: Type[ConnectionType],
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ ) -> PQGenConn[ConnectionType]:
+ """Generator to connect to the database and create a new instance."""
+ pgconn = yield from connect(conninfo)
+ conn = cls(pgconn)
+ conn._autocommit = bool(autocommit)
+ return conn
+
+ def _exec_command(
+ self, command: Query, result_format: pq.Format = TEXT
+ ) -> PQGen[Optional["PGresult"]]:
+ """
+ Generator to send a command and receive the result to the backend.
+
+ Only used to implement internal commands such as "commit", with eventual
+ arguments bound client-side. The cursor can do more complex stuff.
+ """
+ self._check_connection_ok()
+
+ if isinstance(command, str):
+ command = command.encode(pgconn_encoding(self.pgconn))
+ elif isinstance(command, Composable):
+ command = command.as_bytes(self)
+
+ if self._pipeline:
+ cmd = partial(
+ self.pgconn.send_query_params,
+ command,
+ None,
+ result_format=result_format,
+ )
+ self._pipeline.command_queue.append(cmd)
+ self._pipeline.result_queue.append(None)
+ return None
+
+ self.pgconn.send_query_params(command, None, result_format=result_format)
+
+ result = (yield from execute(self.pgconn))[-1]
+ if result.status != COMMAND_OK and result.status != TUPLES_OK:
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+ else:
+ raise e.InterfaceError(
+ f"unexpected result {pq.ExecStatus(result.status).name}"
+ f" from command {command.decode()!r}"
+ )
+ return result
+
+ def _check_connection_ok(self) -> None:
+ if self.pgconn.status == OK:
+ return
+
+ if self.pgconn.status == BAD:
+ raise e.OperationalError("the connection is closed")
+ raise e.InterfaceError(
+ "cannot execute operations: the connection is"
+ f" in status {self.pgconn.status}"
+ )
+
+ def _start_query(self) -> PQGen[None]:
+ """Generator to start a transaction if necessary."""
+ if self._autocommit:
+ return
+
+ if self.pgconn.transaction_status != IDLE:
+ return
+
+ yield from self._exec_command(self._get_tx_start_command())
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _get_tx_start_command(self) -> bytes:
+ if self._begin_statement:
+ return self._begin_statement
+
+ parts = [b"BEGIN"]
+
+ if self.isolation_level is not None:
+ val = IsolationLevel(self.isolation_level)
+ parts.append(b"ISOLATION LEVEL")
+ parts.append(val.name.replace("_", " ").encode())
+
+ if self.read_only is not None:
+ parts.append(b"READ ONLY" if self.read_only else b"READ WRITE")
+
+ if self.deferrable is not None:
+ parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE")
+
+ self._begin_statement = b" ".join(parts)
+ return self._begin_statement
+
+ def _commit_gen(self) -> PQGen[None]:
+ """Generator implementing `Connection.commit()`."""
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ "Explicit commit() forbidden within a Transaction "
+ "context. (Transaction will be automatically committed "
+ "on successful exit from context.)"
+ )
+ if self._tpc:
+ raise e.ProgrammingError(
+ "commit() cannot be used during a two-phase transaction"
+ )
+ if self.pgconn.transaction_status == IDLE:
+ return
+
+ yield from self._exec_command(b"COMMIT")
+
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _rollback_gen(self) -> PQGen[None]:
+ """Generator implementing `Connection.rollback()`."""
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ "Explicit rollback() forbidden within a Transaction "
+ "context. (Either raise Rollback() or allow "
+ "an exception to propagate out of the context.)"
+ )
+ if self._tpc:
+ raise e.ProgrammingError(
+ "rollback() cannot be used during a two-phase transaction"
+ )
+
+ # Get out of a "pipeline aborted" state
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ if self.pgconn.transaction_status == IDLE:
+ return
+
+ yield from self._exec_command(b"ROLLBACK")
+ self._prepared.clear()
+ for cmd in self._prepared.get_maintenance_commands():
+ yield from self._exec_command(cmd)
+
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
+ """
+ Returns a `Xid` to pass to the `!tpc_*()` methods of this connection.
+
+ The argument types and constraints are explained in
+ :ref:`two-phase-commit`.
+
+ The values passed to the method will be available on the returned
+ object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`.
+ """
+ self._check_tpc()
+ return Xid.from_parts(format_id, gtrid, bqual)
+
+ def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
+ self._check_tpc()
+
+ if not isinstance(xid, Xid):
+ xid = Xid.from_string(xid)
+
+ if self.pgconn.transaction_status != IDLE:
+ raise e.ProgrammingError(
+ "can't start two-phase transaction: connection in status"
+ f" {pq.TransactionStatus(self.pgconn.transaction_status).name}"
+ )
+
+ if self._autocommit:
+ raise e.ProgrammingError(
+ "can't use two-phase transactions in autocommit mode"
+ )
+
+ self._tpc = (xid, False)
+ yield from self._exec_command(self._get_tx_start_command())
+
+ def _tpc_prepare_gen(self) -> PQGen[None]:
+ if not self._tpc:
+ raise e.ProgrammingError(
+ "'tpc_prepare()' must be called inside a two-phase transaction"
+ )
+ if self._tpc[1]:
+ raise e.ProgrammingError(
+ "'tpc_prepare()' cannot be used during a prepared two-phase transaction"
+ )
+ xid = self._tpc[0]
+ self._tpc = (xid, True)
+ yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid)))
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _tpc_finish_gen(
+ self, action: LiteralString, xid: Union[Xid, str, None]
+ ) -> PQGen[None]:
+ fname = f"tpc_{action.lower()}()"
+ if xid is None:
+ if not self._tpc:
+ raise e.ProgrammingError(
+ f"{fname} without xid must must be"
+ " called inside a two-phase transaction"
+ )
+ xid = self._tpc[0]
+ else:
+ if self._tpc:
+ raise e.ProgrammingError(
+ f"{fname} with xid must must be called"
+ " outside a two-phase transaction"
+ )
+ if not isinstance(xid, Xid):
+ xid = Xid.from_string(xid)
+
+ if self._tpc and not self._tpc[1]:
+ meth: Callable[[], PQGen[None]]
+ meth = getattr(self, f"_{action.lower()}_gen")
+ self._tpc = None
+ yield from meth()
+ else:
+ yield from self._exec_command(
+ SQL("{} PREPARED {}").format(SQL(action), str(xid))
+ )
+ self._tpc = None
+
+ def _check_tpc(self) -> None:
+ """Raise NotSupportedError if TPC is not supported."""
+ # TPC supported on every supported PostgreSQL version.
+ pass
+
+
+class Connection(BaseConnection[Row]):
+ """
+ Wrapper for a connection to the database.
+ """
+
+ __module__ = "psycopg"
+
+ cursor_factory: Type[Cursor[Row]]
+ server_cursor_factory: Type[ServerCursor[Row]]
+ row_factory: RowFactory[Row]
+ _pipeline: Optional[Pipeline]
+ _Self = TypeVar("_Self", bound="Connection[Any]")
+
+ def __init__(
+ self,
+ pgconn: "PGconn",
+ row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row),
+ ):
+ super().__init__(pgconn)
+ self.row_factory = row_factory
+ self.lock = threading.Lock()
+ self.cursor_factory = Cursor
+ self.server_cursor_factory = ServerCursor
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ row_factory: RowFactory[Row],
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[Cursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "Connection[Row]":
+ # TODO: returned type should be _Self. See #308.
+ ...
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[Cursor[Any]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "Connection[TupleRow]":
+ ...
+
+ @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: Optional[RowFactory[Row]] = None,
+ cursor_factory: Optional[Type[Cursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Any,
+ ) -> "Connection[Any]":
+ """
+ Connect to a database server and return a new `Connection` instance.
+ """
+ params = cls._get_connection_params(conninfo, **kwargs)
+ conninfo = make_conninfo(**params)
+
+ try:
+ rv = cls._wait_conn(
+ cls._connect_gen(conninfo, autocommit=autocommit),
+ timeout=params["connect_timeout"],
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ if row_factory:
+ rv.row_factory = row_factory
+ if cursor_factory:
+ rv.cursor_factory = cursor_factory
+ if context:
+ rv._adapters = AdaptersMap(context.adapters)
+ rv.prepare_threshold = prepare_threshold
+ return rv
+
+ def __enter__(self: _Self) -> _Self:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if self.closed:
+ return
+
+ if exc_type:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ self.rollback()
+ except Exception as exc2:
+ logger.warning(
+ "error ignored in rollback on %s: %s",
+ self,
+ exc2,
+ )
+ else:
+ self.commit()
+
+ # Close the connection only if it doesn't belong to a pool.
+ if not getattr(self, "_pool", None):
+ self.close()
+
+ @classmethod
+ def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
+ """Manipulate connection parameters before connecting.
+
+ :param conninfo: Connection string as received by `~Connection.connect()`.
+ :param kwargs: Overriding connection arguments as received by `!connect()`.
+ :return: Connection arguments merged and eventually modified, in a
+ format similar to `~conninfo.conninfo_to_dict()`.
+ """
+ params = conninfo_to_dict(conninfo, **kwargs)
+
+ # Make sure there is an usable connect_timeout
+ if "connect_timeout" in params:
+ params["connect_timeout"] = int(params["connect_timeout"])
+ else:
+ params["connect_timeout"] = None
+
+ return params
+
+ def close(self) -> None:
+ """Close the database connection."""
+ if self.closed:
+ return
+ self._closed = True
+ self.pgconn.finish()
+
+ @overload
+ def cursor(self, *, binary: bool = False) -> Cursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
+ ) -> Cursor[CursorRow]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> ServerCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ row_factory: RowFactory[CursorRow],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> ServerCursor[CursorRow]:
+ ...
+
+ def cursor(
+ self,
+ name: str = "",
+ *,
+ binary: bool = False,
+ row_factory: Optional[RowFactory[Any]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> Union[Cursor[Any], ServerCursor[Any]]:
+ """
+ Return a new cursor to send commands and queries to the connection.
+ """
+ self._check_connection_ok()
+
+ if not row_factory:
+ row_factory = self.row_factory
+
+ cur: Union[Cursor[Any], ServerCursor[Any]]
+ if name:
+ cur = self.server_cursor_factory(
+ self,
+ name=name,
+ row_factory=row_factory,
+ scrollable=scrollable,
+ withhold=withhold,
+ )
+ else:
+ cur = self.cursor_factory(self, row_factory=row_factory)
+
+ if binary:
+ cur.format = BINARY
+
+ return cur
+
+ def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: bool = False,
+ ) -> Cursor[Row]:
+ """Execute a query and return a cursor to read its results."""
+ try:
+ cur = self.cursor()
+ if binary:
+ cur.format = BINARY
+
+ return cur.execute(query, params, prepare=prepare)
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def commit(self) -> None:
+ """Commit any pending transaction to the database."""
+ with self.lock:
+ self.wait(self._commit_gen())
+
+ def rollback(self) -> None:
+ """Roll back to the start of any pending transaction."""
+ with self.lock:
+ self.wait(self._rollback_gen())
+
+ @contextmanager
+ def transaction(
+ self,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ) -> Iterator[Transaction]:
+ """
+ Start a context block with a new transaction or nested transaction.
+
+ :param savepoint_name: Name of the savepoint used to manage a nested
+ transaction. If `!None`, one will be chosen automatically.
+ :param force_rollback: Roll back the transaction at the end of the
+ block even if there were no error (e.g. to try a no-op process).
+ :rtype: Transaction
+ """
+ tx = Transaction(self, savepoint_name, force_rollback)
+ if self._pipeline:
+ with self.pipeline(), tx, self.pipeline():
+ yield tx
+ else:
+ with tx:
+ yield tx
+
+ def notifies(self) -> Generator[Notify, None, None]:
+ """
+ Yield `Notify` objects as soon as they are received from the database.
+ """
+ while True:
+ with self.lock:
+ try:
+ ns = self.wait(notifies(self.pgconn))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ enc = pgconn_encoding(self.pgconn)
+ for pgn in ns:
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ yield n
+
+ @contextmanager
+ def pipeline(self) -> Iterator[Pipeline]:
+ """Switch the connection into pipeline mode."""
+ with self.lock:
+ self._check_connection_ok()
+
+ pipeline = self._pipeline
+ if pipeline is None:
+ # WARNING: reference loop, broken ahead.
+ pipeline = self._pipeline = Pipeline(self)
+
+ try:
+ with pipeline:
+ yield pipeline
+ finally:
+ if pipeline.level == 0:
+ with self.lock:
+ assert pipeline is self._pipeline
+ self._pipeline = None
+
+ def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+ """
+ Consume a generator operating on the connection.
+
+ The function must be used on generators that don't change connection
+ fd (i.e. not on connect and reset).
+ """
+ try:
+ return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+ except KeyboardInterrupt:
+ # On Ctrl-C, try to cancel the query in the server, otherwise
+ # the connection will remain stuck in ACTIVE state.
+ c = self.pgconn.get_cancel()
+ c.cancel()
+ try:
+ waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+ except e.QueryCanceled:
+ pass # as expected
+ raise
+
+ @classmethod
+ def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
+ """Consume a connection generator."""
+ return waiting.wait_conn(gen, timeout=timeout)
+
+ def _set_autocommit(self, value: bool) -> None:
+ with self.lock:
+ self.wait(self._set_autocommit_gen(value))
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ with self.lock:
+ self.wait(self._set_isolation_level_gen(value))
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ with self.lock:
+ self.wait(self._set_read_only_gen(value))
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ with self.lock:
+ self.wait(self._set_deferrable_gen(value))
+
+ def tpc_begin(self, xid: Union[Xid, str]) -> None:
+ """
+ Begin a TPC transaction with the given transaction ID `!xid`.
+ """
+ with self.lock:
+ self.wait(self._tpc_begin_gen(xid))
+
+ def tpc_prepare(self) -> None:
+ """
+ Perform the first phase of a transaction started with `tpc_begin()`.
+ """
+ try:
+ with self.lock:
+ self.wait(self._tpc_prepare_gen())
+ except e.ObjectNotInPrerequisiteState as ex:
+ raise e.NotSupportedError(str(ex)) from None
+
+ def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+ """
+ Commit a prepared two-phase transaction.
+ """
+ with self.lock:
+ self.wait(self._tpc_finish_gen("COMMIT", xid))
+
+ def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+ """
+ Roll back a prepared two-phase transaction.
+ """
+ with self.lock:
+ self.wait(self._tpc_finish_gen("ROLLBACK", xid))
+
+ def tpc_recover(self) -> List[Xid]:
+ self._check_tpc()
+ status = self.info.transaction_status
+ with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+ cur.execute(Xid._get_recover_query())
+ res = cur.fetchall()
+
+ if status == IDLE and self.info.transaction_status == INTRANS:
+ self.rollback()
+
+ return res
diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py
new file mode 100644
index 0000000..aa02dc0
--- /dev/null
+++ b/psycopg/psycopg/connection_async.py
@@ -0,0 +1,436 @@
+"""
+psycopg async connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import asyncio
+import logging
+from types import TracebackType
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
+from contextlib import asynccontextmanager
+
+from . import pq
+from . import errors as e
+from . import waiting
+from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from ._tpc import Xid
+from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
+from .adapt import AdaptersMap
+from ._enums import IsolationLevel
+from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
+from ._pipeline import AsyncPipeline
+from ._encodings import pgconn_encoding
+from .connection import BaseConnection, CursorRow, Notify
+from .generators import notifies
+from .transaction import AsyncTransaction
+from .cursor_async import AsyncCursor
+from .server_cursor import AsyncServerCursor
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+logger = logging.getLogger("psycopg")
+
+
+class AsyncConnection(BaseConnection[Row]):
+ """
+ Asynchronous wrapper for a connection to the database.
+ """
+
+ __module__ = "psycopg"
+
+ cursor_factory: Type[AsyncCursor[Row]]
+ server_cursor_factory: Type[AsyncServerCursor[Row]]
+ row_factory: AsyncRowFactory[Row]
+ _pipeline: Optional[AsyncPipeline]
+ _Self = TypeVar("_Self", bound="AsyncConnection[Any]")
+
+ def __init__(
+ self,
+ pgconn: "PGconn",
+ row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
+ ):
+ super().__init__(pgconn)
+ self.row_factory = row_factory
+ self.lock = asyncio.Lock()
+ self.cursor_factory = AsyncCursor
+ self.server_cursor_factory = AsyncServerCursor
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: AsyncRowFactory[Row],
+ cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncConnection[Row]":
+ # TODO: returned type should be _Self. See #308.
+ ...
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncConnection[TupleRow]":
+ ...
+
+ @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ context: Optional[AdaptContext] = None,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
+ **kwargs: Any,
+ ) -> "AsyncConnection[Any]":
+
+ if sys.platform == "win32":
+ loop = asyncio.get_running_loop()
+ if isinstance(loop, asyncio.ProactorEventLoop):
+ raise e.InterfaceError(
+ "Psycopg cannot use the 'ProactorEventLoop' to run in async"
+ " mode. Please use a compatible event loop, for instance by"
+ " setting 'asyncio.set_event_loop_policy"
+ "(WindowsSelectorEventLoopPolicy())'"
+ )
+
+ params = await cls._get_connection_params(conninfo, **kwargs)
+ conninfo = make_conninfo(**params)
+
+ try:
+ rv = await cls._wait_conn(
+ cls._connect_gen(conninfo, autocommit=autocommit),
+ timeout=params["connect_timeout"],
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ if row_factory:
+ rv.row_factory = row_factory
+ if cursor_factory:
+ rv.cursor_factory = cursor_factory
+ if context:
+ rv._adapters = AdaptersMap(context.adapters)
+ rv.prepare_threshold = prepare_threshold
+ return rv
+
+ async def __aenter__(self: _Self) -> _Self:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if self.closed:
+ return
+
+ if exc_type:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ await self.rollback()
+ except Exception as exc2:
+ logger.warning(
+ "error ignored in rollback on %s: %s",
+ self,
+ exc2,
+ )
+ else:
+ await self.commit()
+
+ # Close the connection only if it doesn't belong to a pool.
+ if not getattr(self, "_pool", None):
+ await self.close()
+
+ @classmethod
+ async def _get_connection_params(
+ cls, conninfo: str, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Manipulate connection parameters before connecting.
+
+ .. versionchanged:: 3.1
+ Unlike the sync counterpart, perform non-blocking address
+ resolution and populate the ``hostaddr`` connection parameter,
+ unless the user has provided one themselves. See
+ `~psycopg._dns.resolve_hostaddr_async()` for details.
+
+ """
+ params = conninfo_to_dict(conninfo, **kwargs)
+
+ # Make sure there is an usable connect_timeout
+ if "connect_timeout" in params:
+ params["connect_timeout"] = int(params["connect_timeout"])
+ else:
+ params["connect_timeout"] = None
+
+ # Resolve host addresses in non-blocking way
+ params = await resolve_hostaddr_async(params)
+
+ return params
+
+ async def close(self) -> None:
+ if self.closed:
+ return
+ self._closed = True
+ self.pgconn.finish()
+
+ @overload
+ def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
+ ) -> AsyncCursor[CursorRow]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> AsyncServerCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ row_factory: AsyncRowFactory[CursorRow],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> AsyncServerCursor[CursorRow]:
+ ...
+
+ def cursor(
+ self,
+ name: str = "",
+ *,
+ binary: bool = False,
+ row_factory: Optional[AsyncRowFactory[Any]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
+ """
+ Return a new `AsyncCursor` to send commands and queries to the connection.
+ """
+ self._check_connection_ok()
+
+ if not row_factory:
+ row_factory = self.row_factory
+
+ cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
+ if name:
+ cur = self.server_cursor_factory(
+ self,
+ name=name,
+ row_factory=row_factory,
+ scrollable=scrollable,
+ withhold=withhold,
+ )
+ else:
+ cur = self.cursor_factory(self, row_factory=row_factory)
+
+ if binary:
+ cur.format = BINARY
+
+ return cur
+
+ async def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: bool = False,
+ ) -> AsyncCursor[Row]:
+ try:
+ cur = self.cursor()
+ if binary:
+ cur.format = BINARY
+
+ return await cur.execute(query, params, prepare=prepare)
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def commit(self) -> None:
+ async with self.lock:
+ await self.wait(self._commit_gen())
+
+ async def rollback(self) -> None:
+ async with self.lock:
+ await self.wait(self._rollback_gen())
+
+ @asynccontextmanager
+ async def transaction(
+ self,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ) -> AsyncIterator[AsyncTransaction]:
+ """
+ Start a context block with a new transaction or nested transaction.
+
+ :rtype: AsyncTransaction
+ """
+ tx = AsyncTransaction(self, savepoint_name, force_rollback)
+ if self._pipeline:
+ async with self.pipeline(), tx, self.pipeline():
+ yield tx
+ else:
+ async with tx:
+ yield tx
+
+ async def notifies(self) -> AsyncGenerator[Notify, None]:
+ while True:
+ async with self.lock:
+ try:
+ ns = await self.wait(notifies(self.pgconn))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ enc = pgconn_encoding(self.pgconn)
+ for pgn in ns:
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ yield n
+
+ @asynccontextmanager
+ async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
+ """Context manager to switch the connection into pipeline mode."""
+ async with self.lock:
+ self._check_connection_ok()
+
+ pipeline = self._pipeline
+ if pipeline is None:
+ # WARNING: reference loop, broken ahead.
+ pipeline = self._pipeline = AsyncPipeline(self)
+
+ try:
+ async with pipeline:
+ yield pipeline
+ finally:
+ if pipeline.level == 0:
+ async with self.lock:
+ assert pipeline is self._pipeline
+ self._pipeline = None
+
+ async def wait(self, gen: PQGen[RV]) -> RV:
+ try:
+ return await waiting.wait_async(gen, self.pgconn.socket)
+ except KeyboardInterrupt:
+ # TODO: this doesn't seem to work as it does for sync connections
+ # see tests/test_concurrency_async.py::test_ctrl_c
+ # In the test, the code doesn't reach this branch.
+
+ # On Ctrl-C, try to cancel the query in the server, otherwise
+ # otherwise the connection will be stuck in ACTIVE state
+ c = self.pgconn.get_cancel()
+ c.cancel()
+ try:
+ await waiting.wait_async(gen, self.pgconn.socket)
+ except e.QueryCanceled:
+ pass # as expected
+ raise
+
+ @classmethod
+ async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
+ return await waiting.wait_conn_async(gen, timeout)
+
+ def _set_autocommit(self, value: bool) -> None:
+ self._no_set_async("autocommit")
+
+ async def set_autocommit(self, value: bool) -> None:
+ """Async version of the `~Connection.autocommit` setter."""
+ async with self.lock:
+ await self.wait(self._set_autocommit_gen(value))
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ self._no_set_async("isolation_level")
+
+ async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ """Async version of the `~Connection.isolation_level` setter."""
+ async with self.lock:
+ await self.wait(self._set_isolation_level_gen(value))
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ self._no_set_async("read_only")
+
+ async def set_read_only(self, value: Optional[bool]) -> None:
+ """Async version of the `~Connection.read_only` setter."""
+ async with self.lock:
+ await self.wait(self._set_read_only_gen(value))
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ self._no_set_async("deferrable")
+
+ async def set_deferrable(self, value: Optional[bool]) -> None:
+ """Async version of the `~Connection.deferrable` setter."""
+ async with self.lock:
+ await self.wait(self._set_deferrable_gen(value))
+
+ def _no_set_async(self, attribute: str) -> None:
+ raise AttributeError(
+ f"'the {attribute!r} property is read-only on async connections:"
+ f" please use 'await .set_{attribute}()' instead."
+ )
+
+ async def tpc_begin(self, xid: Union[Xid, str]) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_begin_gen(xid))
+
+ async def tpc_prepare(self) -> None:
+ try:
+ async with self.lock:
+ await self.wait(self._tpc_prepare_gen())
+ except e.ObjectNotInPrerequisiteState as ex:
+ raise e.NotSupportedError(str(ex)) from None
+
+ async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_finish_gen("commit", xid))
+
+ async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_finish_gen("rollback", xid))
+
+ async def tpc_recover(self) -> List[Xid]:
+ self._check_tpc()
+ status = self.info.transaction_status
+ async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+ await cur.execute(Xid._get_recover_query())
+ res = await cur.fetchall()
+
+ if status == IDLE and self.info.transaction_status == INTRANS:
+ await self.rollback()
+
+ return res
diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py
new file mode 100644
index 0000000..3b21f83
--- /dev/null
+++ b/psycopg/psycopg/conninfo.py
@@ -0,0 +1,378 @@
+"""
+Functions to manipulate conninfo strings
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+import socket
+import asyncio
+from typing import Any, Dict, List, Optional
+from pathlib import Path
+from datetime import tzinfo
+from functools import lru_cache
+from ipaddress import ip_address
+
+from . import pq
+from . import errors as e
+from ._tz import get_tzinfo
+from ._encodings import pgconn_encoding
+
+
+def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
+ """
+ Merge a string and keyword params into a single conninfo string.
+
+ :param conninfo: A `connection string`__ as accepted by PostgreSQL.
+ :param kwargs: Parameters overriding the ones specified in `!conninfo`.
+ :return: A connection string valid for PostgreSQL, with the `!kwargs`
+ parameters merged.
+
+ Raise `~psycopg.ProgrammingError` if the input doesn't make a valid
+ conninfo string.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-CONNSTRING
+ """
+ if not conninfo and not kwargs:
+ return ""
+
+ # If no kwarg specified don't mung the conninfo but check if it's correct.
+ # Make sure to return a string, not a subtype, to avoid making Liskov sad.
+ if not kwargs:
+ _parse_conninfo(conninfo)
+ return str(conninfo)
+
+ # Override the conninfo with the parameters
+ # Drop the None arguments
+ kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
+
+ if conninfo:
+ tmp = conninfo_to_dict(conninfo)
+ tmp.update(kwargs)
+ kwargs = tmp
+
+ conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items())
+
+ # Verify the result is valid
+ _parse_conninfo(conninfo)
+
+ return conninfo
+
+
+def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
+ """
+ Convert the `!conninfo` string into a dictionary of parameters.
+
+ :param conninfo: A `connection string`__ as accepted by PostgreSQL.
+ :param kwargs: Parameters overriding the ones specified in `!conninfo`.
+ :return: Dictionary with the parameters parsed from `!conninfo` and
+ `!kwargs`.
+
+ Raise `~psycopg.ProgrammingError` if `!conninfo` is not a a valid connection
+ string.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-CONNSTRING
+ """
+ opts = _parse_conninfo(conninfo)
+ rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
+ for k, v in kwargs.items():
+ if v is not None:
+ rv[k] = v
+ return rv
+
+
+def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
+ """
+ Verify that `!conninfo` is a valid connection string.
+
+ Raise ProgrammingError if the string is not valid.
+
+ Return the result of pq.Conninfo.parse() on success.
+ """
+ try:
+ return pq.Conninfo.parse(conninfo.encode())
+ except e.OperationalError as ex:
+ raise e.ProgrammingError(str(ex))
+
+
+re_escape = re.compile(r"([\\'])")
+re_space = re.compile(r"\s")
+
+
+def _param_escape(s: str) -> str:
+ """
+ Apply the escaping rule required by PQconnectdb
+ """
+ if not s:
+ return "''"
+
+ s = re_escape.sub(r"\\\1", s)
+ if re_space.search(s):
+ s = "'" + s + "'"
+
+ return s
+
+
+class ConnectionInfo:
+ """Allow access to information about the connection."""
+
+ __module__ = "psycopg"
+
+ def __init__(self, pgconn: pq.abc.PGconn):
+ self.pgconn = pgconn
+
+ @property
+ def vendor(self) -> str:
+ """A string representing the database vendor connected to."""
+ return "PostgreSQL"
+
+ @property
+ def host(self) -> str:
+ """The server host name of the active connection. See :pq:`PQhost()`."""
+ return self._get_pgconn_attr("host")
+
+ @property
+ def hostaddr(self) -> str:
+ """The server IP address of the connection. See :pq:`PQhostaddr()`."""
+ return self._get_pgconn_attr("hostaddr")
+
+ @property
+ def port(self) -> int:
+ """The port of the active connection. See :pq:`PQport()`."""
+ return int(self._get_pgconn_attr("port"))
+
+ @property
+ def dbname(self) -> str:
+ """The database name of the connection. See :pq:`PQdb()`."""
+ return self._get_pgconn_attr("db")
+
+ @property
+ def user(self) -> str:
+ """The user name of the connection. See :pq:`PQuser()`."""
+ return self._get_pgconn_attr("user")
+
+ @property
+ def password(self) -> str:
+ """The password of the connection. See :pq:`PQpass()`."""
+ return self._get_pgconn_attr("password")
+
+ @property
+ def options(self) -> str:
+ """
+ The command-line options passed in the connection request.
+ See :pq:`PQoptions`.
+ """
+ return self._get_pgconn_attr("options")
+
+ def get_parameters(self) -> Dict[str, str]:
+ """Return the connection parameters values.
+
+ Return all the parameters set to a non-default value, which might come
+ either from the connection string and parameters passed to
+ `~Connection.connect()` or from environment variables. The password
+ is never returned (you can read it using the `password` attribute).
+ """
+ pyenc = self.encoding
+
+ # Get the known defaults to avoid reporting them
+ defaults = {
+ i.keyword: i.compiled
+ for i in pq.Conninfo.get_defaults()
+ if i.compiled is not None
+ }
+ # Not returned by the libq. Bug? Bet we're using SSH.
+ defaults.setdefault(b"channel_binding", b"prefer")
+ defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()
+
+ return {
+ i.keyword.decode(pyenc): i.val.decode(pyenc)
+ for i in self.pgconn.info
+ if i.val is not None
+ and i.keyword != b"password"
+ and i.val != defaults.get(i.keyword)
+ }
+
+ @property
+ def dsn(self) -> str:
+ """Return the connection string to connect to the database.
+
+ The string contains all the parameters set to a non-default value,
+ which might come either from the connection string and parameters
+ passed to `~Connection.connect()` or from environment variables. The
+ password is never returned (you can read it using the `password`
+ attribute).
+ """
+ return make_conninfo(**self.get_parameters())
+
+ @property
+ def status(self) -> pq.ConnStatus:
+ """The status of the connection. See :pq:`PQstatus()`."""
+ return pq.ConnStatus(self.pgconn.status)
+
+ @property
+ def transaction_status(self) -> pq.TransactionStatus:
+ """
+ The current in-transaction status of the session.
+ See :pq:`PQtransactionStatus()`.
+ """
+ return pq.TransactionStatus(self.pgconn.transaction_status)
+
+ @property
+ def pipeline_status(self) -> pq.PipelineStatus:
+ """
+ The current pipeline status of the client.
+ See :pq:`PQpipelineStatus()`.
+ """
+ return pq.PipelineStatus(self.pgconn.pipeline_status)
+
+ def parameter_status(self, param_name: str) -> Optional[str]:
+ """
+ Return a parameter setting of the connection.
+
+ Return `None` is the parameter is unknown.
+ """
+ res = self.pgconn.parameter_status(param_name.encode(self.encoding))
+ return res.decode(self.encoding) if res is not None else None
+
+ @property
+ def server_version(self) -> int:
+ """
+ An integer representing the server version. See :pq:`PQserverVersion()`.
+ """
+ return self.pgconn.server_version
+
+ @property
+ def backend_pid(self) -> int:
+ """
+ The process ID (PID) of the backend process handling this connection.
+ See :pq:`PQbackendPID()`.
+ """
+ return self.pgconn.backend_pid
+
+ @property
+ def error_message(self) -> str:
+ """
+ The error message most recently generated by an operation on the connection.
+ See :pq:`PQerrorMessage()`.
+ """
+ return self._get_pgconn_attr("error_message")
+
+ @property
+ def timezone(self) -> tzinfo:
+ """The Python timezone info of the connection's timezone."""
+ return get_tzinfo(self.pgconn)
+
+ @property
+ def encoding(self) -> str:
+ """The Python codec name of the connection's client encoding."""
+ return pgconn_encoding(self.pgconn)
+
+ def _get_pgconn_attr(self, name: str) -> str:
+ value: bytes = getattr(self.pgconn, name)
+ return value.decode(self.encoding)
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`.
+
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses dynamically.
+
+ The function may change the input ``host``, ``hostname``, ``port`` to allow
+ connecting without further DNS lookups, eventually removing hosts that are
+ not resolved, keeping the lists of hosts and ports consistent.
+
+ Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
+ host resolve, inconsistent lists length).
+ """
+ hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
+ if hostaddr_arg:
+ # Already resolved
+ return params
+
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ if not host_arg:
+ # Nothing to resolve
+ return params
+
+ hosts_in = host_arg.split(",")
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+ ports_in = port_arg.split(",") if port_arg else []
+ default_port = "5432"
+
+ if len(ports_in) == 1:
+ # If only one port is specified, the libpq will apply it to all
+ # the hosts, so don't mangle it.
+ default_port = ports_in.pop()
+
+ elif len(ports_in) > 1:
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what the raise if the libpq fails connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
+ )
+ ports_out = []
+
+ hosts_out = []
+ hostaddr_out = []
+ loop = asyncio.get_running_loop()
+ for i, host in enumerate(hosts_in):
+ if not host or host.startswith("/") or host[1:2] == ":":
+ # Local path
+ hosts_out.append(host)
+ hostaddr_out.append("")
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ # If the host is already an ip address don't try to resolve it
+ if is_ip_address(host):
+ hosts_out.append(host)
+ hostaddr_out.append(host)
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ try:
+ port = ports_in[i] if ports_in else default_port
+ ans = await loop.getaddrinfo(
+ host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+ )
+ except OSError as ex:
+ last_exc = ex
+ else:
+ for item in ans:
+ hosts_out.append(host)
+ hostaddr_out.append(item[4][0])
+ if ports_in:
+ ports_out.append(ports_in[i])
+
+ # Throw an exception if no host could be resolved
+ if not hosts_out:
+ raise e.OperationalError(str(last_exc))
+
+ out = params.copy()
+ out["host"] = ",".join(hosts_out)
+ out["hostaddr"] = ",".join(hostaddr_out)
+ if ports_in:
+ out["port"] = ",".join(ports_out)
+
+ return out
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+ """Return True if the string represent a valid ip address."""
+ try:
+ ip_address(s)
+ except ValueError:
+ return False
+ return True
diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py
new file mode 100644
index 0000000..7514306
--- /dev/null
+++ b/psycopg/psycopg/copy.py
@@ -0,0 +1,904 @@
+"""
+psycopg copy support
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import queue
+import struct
+import asyncio
+import threading
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO
+from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import Buffer, ConnectionType, PQGen, Transformer
+from ._compat import create_task
+from ._cmodule import _psycopg
+from ._encodings import pgconn_encoding
+from .generators import copy_from, copy_to, copy_end
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor, Cursor
+ from .cursor_async import AsyncCursor
+ from .connection import Connection # noqa: F401
+ from .connection_async import AsyncConnection # noqa: F401
+
+PY_TEXT = adapt.PyFormat.TEXT
+PY_BINARY = adapt.PyFormat.BINARY
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+# Size of data to accumulate before sending it down the network. We fill a
+# buffer this size field by field, and when it passes the threshold size
+# we ship it, so it may end up being bigger than this.
+BUFFER_SIZE = 32 * 1024
+
+# Maximum data size we want to queue to send to the libpq copy. Sending a
+# buffer too big to be handled can cause an infinite loop in the libpq
+# (#255) so we want to split it in more digestable chunks.
+MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
+# Note: making this buffer too large, e.g.
+# MAX_BUFFER_SIZE = 1024 * 1024
+# makes operations *way* slower! Probably triggering some quadraticity
+# in the libpq memory management and data sending.
+
+# Max size of the write queue of buffers. More than that copy will block
+# Each buffer should be around BUFFER_SIZE size.
+QUEUE_SIZE = 1024
+
+
+class BaseCopy(Generic[ConnectionType]):
+ """
+ Base implementation for the copy user interface.
+
+ Two subclasses expose real methods with the sync/async differences.
+
+ The difference between the text and binary format is managed by two
+ different `Formatter` subclasses.
+
+ Writing (the I/O part) is implemented in the subclasses by a `Writer` or
+ `AsyncWriter` instance. Normally writing implies sending copy data to a
+ database, but a different writer might be chosen, e.g. to stream data into
+ a file for later use.
+ """
+
+ _Self = TypeVar("_Self", bound="BaseCopy[Any]")
+
+ formatter: "Formatter"
+
+ def __init__(
+ self,
+ cursor: "BaseCursor[ConnectionType, Any]",
+ *,
+ binary: Optional[bool] = None,
+ ):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ result = cursor.pgresult
+ if result:
+ self._direction = result.status
+ if self._direction != COPY_IN and self._direction != COPY_OUT:
+ raise e.ProgrammingError(
+ "the cursor should have performed a COPY operation;"
+ f" its status is {pq.ExecStatus(self._direction).name} instead"
+ )
+ else:
+ self._direction = COPY_IN
+
+ if binary is None:
+ binary = bool(result and result.binary_tuples)
+
+ tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
+ if binary:
+ self.formatter = BinaryFormatter(tx)
+ else:
+ self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
+
+ self._finished = False
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ def _enter(self) -> None:
+ if self._finished:
+ raise TypeError("copy blocks can be used only once")
+
+ def set_types(self, types: Sequence[Union[int, str]]) -> None:
+ """
+ Set the types expected in a COPY operation.
+
+ The types must be specified as a sequence of oid or PostgreSQL type
+ names (e.g. ``int4``, ``timestamptz[]``).
+
+ This operation overcomes the lack of metadata returned by PostgreSQL
+ when a COPY operation begins:
+
+ - On :sql:`COPY TO`, `!set_types()` allows to specify what types the
+ operation returns. If `!set_types()` is not used, the data will be
+ returned as unparsed strings or bytes instead of Python objects.
+
+ - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the
+ database expects. This is especially useful in binary copy, because
+ PostgreSQL will apply no cast rule.
+
+ """
+ registry = self.cursor.adapters.types
+ oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
+
+ if self._direction == COPY_IN:
+ self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
+ else:
+ self.formatter.transformer.set_loader_types(oids, self.formatter.format)
+
+ # High level copy protocol generators (state change of the Copy object)
+
+ def _read_gen(self) -> PQGen[Buffer]:
+ if self._finished:
+ return memoryview(b"")
+
+ res = yield from copy_from(self._pgconn)
+ if isinstance(res, memoryview):
+ return res
+
+ # res is the final PGresult
+ self._finished = True
+
+ # This result is a COMMAND_OK which has info about the number of rows
+ # returned, but not about the columns, which is instead an information
+ # that was received on the COPY_OUT result at the beginning of COPY.
+ # So, don't replace the results in the cursor, just update the rowcount.
+ nrows = res.command_tuples
+ self.cursor._rowcount = nrows if nrows is not None else -1
+ return memoryview(b"")
+
+ def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
+ data = yield from self._read_gen()
+ if not data:
+ return None
+
+ row = self.formatter.parse_row(data)
+ if row is None:
+ # Get the final result to finish the copy operation
+ yield from self._read_gen()
+ self._finished = True
+ return None
+
+ return row
+
+ def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+ if not exc:
+ return
+
+ if self._pgconn.transaction_status != ACTIVE:
+ # The server has already finished to send copy data. The connection
+ # is already in a good state.
+ return
+
+ # Throw a cancel to the server, then consume the rest of the copy data
+ # (which might or might not have been already transferred entirely to
+ # the client, so we won't necessary see the exception associated with
+ # canceling).
+ self.connection.cancel()
+ try:
+ while (yield from self._read_gen()):
+ pass
+ except e.QueryCanceled:
+ pass
+
+
+class Copy(BaseCopy["Connection[Any]"]):
+ """Manage a :sql:`COPY` operation.
+
+ :param cursor: the cursor where the operation is performed.
+ :param binary: if `!True`, write binary format.
+ :param writer: the object to write to destination. If not specified, write
+ to the `!cursor` connection.
+
+ Choosing `!binary` is not necessary if the cursor has executed a
+ :sql:`COPY` operation, because the operation result describes the format
+ too. The parameter is useful when a `!Copy` object is created manually and
+ no operation is performed on the cursor, such as when using ``writer=``\\
+ `~psycopg.copy.FileWriter`.
+
+ """
+
+ __module__ = "psycopg"
+
+ writer: "Writer"
+
+ def __init__(
+ self,
+ cursor: "Cursor[Any]",
+ *,
+ binary: Optional[bool] = None,
+ writer: Optional["Writer"] = None,
+ ):
+ super().__init__(cursor, binary=binary)
+ if not writer:
+ writer = LibpqWriter(cursor)
+
+ self.writer = writer
+ self._write = writer.write
+
+ def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
+ self._enter()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.finish(exc_val)
+
+ # End user sync interface
+
+ def __iter__(self) -> Iterator[Buffer]:
+ """Implement block-by-block iteration on :sql:`COPY TO`."""
+ while True:
+ data = self.read()
+ if not data:
+ break
+ yield data
+
+ def read(self) -> Buffer:
+ """
+ Read an unparsed row after a :sql:`COPY TO` operation.
+
+ Return an empty string when the data is finished.
+ """
+ return self.connection.wait(self._read_gen())
+
+ def rows(self) -> Iterator[Tuple[Any, ...]]:
+ """
+ Iterate on the result of a :sql:`COPY TO` operation record by record.
+
+ Note that the records returned will be tuples of unparsed strings or
+ bytes, unless data types are specified using `set_types()`.
+ """
+ while True:
+ record = self.read_row()
+ if record is None:
+ break
+ yield record
+
+ def read_row(self) -> Optional[Tuple[Any, ...]]:
+ """
+ Read a parsed row of data from a table after a :sql:`COPY TO` operation.
+
+ Return `!None` when the data is finished.
+
+ Note that the records returned will be tuples of unparsed strings or
+ bytes, unless data types are specified using `set_types()`.
+ """
+ return self.connection.wait(self._read_row_gen())
+
+ def write(self, buffer: Union[Buffer, str]) -> None:
+ """
+ Write a block of data to a table after a :sql:`COPY FROM` operation.
+
+ If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
+ text mode it can be either `!bytes` or `!str`.
+ """
+ data = self.formatter.write(buffer)
+ if data:
+ self._write(data)
+
+ def write_row(self, row: Sequence[Any]) -> None:
+ """Write a record to a table after a :sql:`COPY FROM` operation."""
+ data = self.formatter.write_row(row)
+ if data:
+ self._write(data)
+
+ def finish(self, exc: Optional[BaseException]) -> None:
+ """Terminate the copy operation and free the resources allocated.
+
+ You shouldn't need to call this function yourself: it is usually called
+ by exit. It is available if, despite what is documented, you end up
+ using the `Copy` object outside a block.
+ """
+ if self._direction == COPY_IN:
+ data = self.formatter.end()
+ if data:
+ self._write(data)
+ self.writer.finish(exc)
+ self._finished = True
+ else:
+ self.connection.wait(self._end_copy_out_gen(exc))
+
+
+class Writer(ABC):
+ """
+ A class to write copy data somewhere.
+ """
+
+ @abstractmethod
+ def write(self, data: Buffer) -> None:
+ """
+ Write some data to destination.
+ """
+ ...
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ """
+ Called when write operations are finished.
+
+ If operations finished with an error, it will be passed to ``exc``.
+ """
+ pass
+
+
+class LibpqWriter(Writer):
+ """
+ A `Writer` to write copy data to a Postgres database.
+ """
+
+ def __init__(self, cursor: "Cursor[Any]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ def write(self, data: Buffer) -> None:
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ self.connection.wait(copy_to(self._pgconn, data))
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ self.connection.wait(
+ copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+ )
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ bmsg: Optional[bytes]
+ if exc:
+ msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+ bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+ else:
+ bmsg = None
+
+ res = self.connection.wait(copy_end(self._pgconn, bmsg))
+ self.cursor._results = [res]
+
+
+class QueuedLibpqDriver(LibpqWriter):
+ """
+ A writer using a buffer to queue data to write to a Postgres database.
+
+ `write()` returns immediately, so that the main thread can be CPU-bound
+ formatting messages, while a worker thread can be IO-bound waiting to write
+ on the connection.
+ """
+
+ def __init__(self, cursor: "Cursor[Any]"):
+ super().__init__(cursor)
+
+ self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
+ self._worker: Optional[threading.Thread] = None
+ self._worker_error: Optional[BaseException] = None
+
+ def worker(self) -> None:
+ """Push data to the server when available from the copy queue.
+
+ Terminate reading when the queue receives a false-y value, or in case
+ of error.
+
+ The function is designed to be run in a separate thread.
+ """
+ try:
+ while True:
+ data = self._queue.get(block=True, timeout=24 * 60 * 60)
+ if not data:
+ break
+ self.connection.wait(copy_to(self._pgconn, data))
+ except BaseException as ex:
+ # Propagate the error to the main thread.
+ self._worker_error = ex
+
+ def write(self, data: Buffer) -> None:
+ if not self._worker:
+ # warning: reference loop, broken by _write_end
+ self._worker = threading.Thread(target=self.worker)
+ self._worker.daemon = True
+ self._worker.start()
+
+ # If the worker thread raies an exception, re-raise it to the caller.
+ if self._worker_error:
+ raise self._worker_error
+
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ self._queue.put(data)
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ self._queue.put(b"")
+
+ if self._worker:
+ self._worker.join()
+ self._worker = None # break the loop
+
+ # Check if the worker thread raised any exception before terminating.
+ if self._worker_error:
+ raise self._worker_error
+
+ super().finish(exc)
+
+
+class FileWriter(Writer):
+ """
+ A `Writer` to write copy data to a file-like object.
+
+ :param file: the file where to write copy data. It must be open for writing
+ in binary mode.
+ """
+
+ def __init__(self, file: IO[bytes]):
+ self.file = file
+
+ def write(self, data: Buffer) -> None:
+ self.file.write(data) # type: ignore[arg-type]
+
+
+class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
+ """Manage an asynchronous :sql:`COPY` operation."""
+
+ __module__ = "psycopg"
+
+ writer: "AsyncWriter"
+
+ def __init__(
+ self,
+ cursor: "AsyncCursor[Any]",
+ *,
+ binary: Optional[bool] = None,
+ writer: Optional["AsyncWriter"] = None,
+ ):
+ super().__init__(cursor, binary=binary)
+
+ if not writer:
+ writer = AsyncLibpqWriter(cursor)
+
+ self.writer = writer
+ self._write = writer.write
+
+ async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
+ self._enter()
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.finish(exc_val)
+
+ async def __aiter__(self) -> AsyncIterator[Buffer]:
+ while True:
+ data = await self.read()
+ if not data:
+ break
+ yield data
+
+ async def read(self) -> Buffer:
+ return await self.connection.wait(self._read_gen())
+
+ async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
+ while True:
+ record = await self.read_row()
+ if record is None:
+ break
+ yield record
+
+ async def read_row(self) -> Optional[Tuple[Any, ...]]:
+ return await self.connection.wait(self._read_row_gen())
+
+ async def write(self, buffer: Union[Buffer, str]) -> None:
+ data = self.formatter.write(buffer)
+ if data:
+ await self._write(data)
+
+ async def write_row(self, row: Sequence[Any]) -> None:
+ data = self.formatter.write_row(row)
+ if data:
+ await self._write(data)
+
+ async def finish(self, exc: Optional[BaseException]) -> None:
+ if self._direction == COPY_IN:
+ data = self.formatter.end()
+ if data:
+ await self._write(data)
+ await self.writer.finish(exc)
+ self._finished = True
+ else:
+ await self.connection.wait(self._end_copy_out_gen(exc))
+
+
+class AsyncWriter(ABC):
+ """
+ A class to write copy data somewhere (for async connections).
+ """
+
+ @abstractmethod
+ async def write(self, data: Buffer) -> None:
+ ...
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ pass
+
+
+class AsyncLibpqWriter(AsyncWriter):
+ """
+ An `AsyncWriter` to write copy data to a Postgres database.
+ """
+
+ def __init__(self, cursor: "AsyncCursor[Any]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ async def write(self, data: Buffer) -> None:
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ await self.connection.wait(copy_to(self._pgconn, data))
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ await self.connection.wait(
+ copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+ )
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ bmsg: Optional[bytes]
+ if exc:
+ msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+ bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+ else:
+ bmsg = None
+
+ res = await self.connection.wait(copy_end(self._pgconn, bmsg))
+ self.cursor._results = [res]
+
+
+class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
+ """
+ An `AsyncWriter` using a buffer to queue data to write.
+
+ `write()` returns immediately, so that the main thread can be CPU-bound
+ formatting messages, while a worker thread can be IO-bound waiting to write
+ on the connection.
+ """
+
+ def __init__(self, cursor: "AsyncCursor[Any]"):
+ super().__init__(cursor)
+
+ self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
+ self._worker: Optional[asyncio.Future[None]] = None
+
+ async def worker(self) -> None:
+ """Push data to the server when available from the copy queue.
+
+ Terminate reading when the queue receives a false-y value.
+
+ The function is designed to be run in a separate task.
+ """
+ while True:
+ data = await self._queue.get()
+ if not data:
+ break
+ await self.connection.wait(copy_to(self._pgconn, data))
+
+ async def write(self, data: Buffer) -> None:
+ if not self._worker:
+ self._worker = create_task(self.worker())
+
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ await self._queue.put(data)
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ await self._queue.put(b"")
+
+ if self._worker:
+ await asyncio.gather(self._worker)
+ self._worker = None # break reference loops if any
+
+ await super().finish(exc)
+
+
+class Formatter(ABC):
+ """
+ A class which understand a copy format (text, binary).
+ """
+
+ format: pq.Format
+
+ def __init__(self, transformer: Transformer):
+ self.transformer = transformer
+ self._write_buffer = bytearray()
+ self._row_mode = False # true if the user is using write_row()
+
+ @abstractmethod
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ ...
+
+ @abstractmethod
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ ...
+
+ @abstractmethod
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ ...
+
+ @abstractmethod
+ def end(self) -> Buffer:
+ ...
+
+
+class TextFormatter(Formatter):
+
+ format = TEXT
+
+ def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
+ super().__init__(transformer)
+ self._encoding = encoding
+
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ if data:
+ return parse_row_text(data, self.transformer)
+ else:
+ return None
+
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ data = self._ensure_bytes(buffer)
+ self._signature_sent = True
+ return data
+
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
+
+ format_row_text(row, self.transformer, self._write_buffer)
+ if len(self._write_buffer) > BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def end(self) -> Buffer:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+ if isinstance(data, str):
+ return data.encode(self._encoding)
+ else:
+ # Assume, for simplicity, that the user is not passing stupid
+ # things to the write function. If that's the case, things
+ # will fail downstream.
+ return data
+
+
+class BinaryFormatter(Formatter):
+
+ format = BINARY
+
+ def __init__(self, transformer: Transformer):
+ super().__init__(transformer)
+ self._signature_sent = False
+
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ if not self._signature_sent:
+ if data[: len(_binary_signature)] != _binary_signature:
+ raise e.DataError(
+ "binary copy doesn't start with the expected signature"
+ )
+ self._signature_sent = True
+ data = data[len(_binary_signature) :]
+
+ elif data == _binary_trailer:
+ return None
+
+ return parse_row_binary(data, self.transformer)
+
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ data = self._ensure_bytes(buffer)
+ self._signature_sent = True
+ return data
+
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
+
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._signature_sent = True
+
+ format_row_binary(row, self.transformer, self._write_buffer)
+ if len(self._write_buffer) > BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def end(self) -> Buffer:
+ # If we have sent no data we need to send the signature
+ # and the trailer
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._write_buffer += _binary_trailer
+
+ elif self._row_mode:
+ # if we have sent data already, we have sent the signature
+ # too (either with the first row, or we assume that in
+ # block mode the signature is included).
+ # Write the trailer only if we are sending rows (with the
+ # assumption that who is copying binary data is sending the
+ # whole format).
+ self._write_buffer += _binary_trailer
+
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+ if isinstance(data, str):
+ raise TypeError("cannot copy str data in binary mode: use bytes instead")
+ else:
+ # Assume, for simplicity, that the user is not passing stupid
+ # things to the write function. If that's the case, things
+ # will fail downstream.
+ return data
+
+
+def _format_row_text(
+ row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+ """Convert a row of objects to the data to send for copy."""
+ if out is None:
+ out = bytearray()
+
+ if not row:
+ out += b"\n"
+ return out
+
+ for item in row:
+ if item is not None:
+ dumper = tx.get_dumper(item, PY_TEXT)
+ b = dumper.dump(item)
+ out += _dump_re.sub(_dump_sub, b)
+ else:
+ out += rb"\N"
+ out += b"\t"
+
+ out[-1:] = b"\n"
+ return out
+
+
+def _format_row_binary(
+ row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+ """Convert a row of objects to the data to send for binary copy."""
+ if out is None:
+ out = bytearray()
+
+ out += _pack_int2(len(row))
+ adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
+ for b in adapted:
+ if b is not None:
+ out += _pack_int4(len(b))
+ out += b
+ else:
+ out += _binary_null
+
+ return out
+
+
+def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ fields = data.split(b"\t")
+ fields[-1] = fields[-1][:-1] # drop \n
+ row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
+ return tx.load_sequence(row)
+
+
+def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+ row: List[Optional[Buffer]] = []
+ nfields = _unpack_int2(data, 0)[0]
+ pos = 2
+ for i in range(nfields):
+ length = _unpack_int4(data, pos)[0]
+ pos += 4
+ if length >= 0:
+ row.append(data[pos : pos + length])
+ pos += length
+ else:
+ row.append(None)
+
+ return tx.load_sequence(row)
+
+
+_pack_int2 = struct.Struct("!h").pack
+_pack_int4 = struct.Struct("!i").pack
+_unpack_int2 = struct.Struct("!h").unpack_from
+_unpack_int4 = struct.Struct("!i").unpack_from
+
+_binary_signature = (
+ b"PGCOPY\n\xff\r\n\0" # Signature
+ b"\x00\x00\x00\x00" # flags
+ b"\x00\x00\x00\x00" # extra length
+)
+_binary_trailer = b"\xff\xff"
+_binary_null = b"\xff\xff\xff\xff"
+
+_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
+_dump_repl = {
+ b"\b": b"\\b",
+ b"\t": b"\\t",
+ b"\n": b"\\n",
+ b"\v": b"\\v",
+ b"\f": b"\\f",
+ b"\r": b"\\r",
+ b"\\": b"\\\\",
+}
+
+
+def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
+ return __map[m.group(0)]
+
+
+_load_re = re.compile(b"\\\\[btnvfr\\\\]")
+_load_repl = {v: k for k, v in _dump_repl.items()}
+
+
+def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
+ return __map[m.group(0)]
+
+
+# Override functions with fast versions if available
+if _psycopg:
+ format_row_text = _psycopg.format_row_text
+ format_row_binary = _psycopg.format_row_binary
+ parse_row_text = _psycopg.parse_row_text
+ parse_row_binary = _psycopg.parse_row_binary
+
+else:
+ format_row_text = _format_row_text
+ format_row_binary = _format_row_binary
+ parse_row_text = _parse_row_text
+ parse_row_binary = _parse_row_binary
diff --git a/psycopg/psycopg/crdb/__init__.py b/psycopg/psycopg/crdb/__init__.py
new file mode 100644
index 0000000..323903a
--- /dev/null
+++ b/psycopg/psycopg/crdb/__init__.py
@@ -0,0 +1,19 @@
+"""
+CockroachDB support package.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from . import _types
+from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
+
+adapters = _types.adapters # exposed by the package
+connect = CrdbConnection.connect
+
+_types.register_crdb_adapters(adapters)
+
+__all__ = [
+ "AsyncCrdbConnection",
+ "CrdbConnection",
+ "CrdbConnectionInfo",
+]
diff --git a/psycopg/psycopg/crdb/_types.py b/psycopg/psycopg/crdb/_types.py
new file mode 100644
index 0000000..5311e05
--- /dev/null
+++ b/psycopg/psycopg/crdb/_types.py
@@ -0,0 +1,163 @@
+"""
+Types configuration specific for CockroachDB.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from enum import Enum
+from .._typeinfo import TypeInfo, TypesRegistry
+
+from ..abc import AdaptContext, NoneType
+from ..postgres import TEXT_OID
+from .._adapters_map import AdaptersMap
+from ..types.enum import EnumDumper, EnumBinaryDumper
+from ..types.none import NoneDumper
+
+types = TypesRegistry()
+
+# Global adapter maps with PostgreSQL types configuration
+adapters = AdaptersMap(types=types)
+
+
+class CrdbEnumDumper(EnumDumper):
+ oid = TEXT_OID
+
+
+class CrdbEnumBinaryDumper(EnumBinaryDumper):
+ oid = TEXT_OID
+
+
+class CrdbNoneDumper(NoneDumper):
+ oid = TEXT_OID
+
+
+def register_postgres_adapters(context: AdaptContext) -> None:
+ # Same adapters used by PostgreSQL, or a good starting point for customization
+
+ from ..types import array, bool, composite, datetime
+ from ..types import numeric, string, uuid
+
+ array.register_default_adapters(context)
+ bool.register_default_adapters(context)
+ composite.register_default_adapters(context)
+ datetime.register_default_adapters(context)
+ numeric.register_default_adapters(context)
+ string.register_default_adapters(context)
+ uuid.register_default_adapters(context)
+
+
+def register_crdb_adapters(context: AdaptContext) -> None:
+ from .. import dbapi20
+ from ..types import array
+
+ register_postgres_adapters(context)
+
+ # String must come after enum to map text oid -> string dumper
+ register_crdb_enum_adapters(context)
+ register_crdb_string_adapters(context)
+ register_crdb_json_adapters(context)
+ register_crdb_net_adapters(context)
+ register_crdb_none_adapters(context)
+
+ dbapi20.register_dbapi20_adapters(adapters)
+
+ array.register_all_arrays(adapters)
+
+
+def register_crdb_string_adapters(context: AdaptContext) -> None:
+ from ..types import string
+
+ # Dump strings with text oid instead of unknown.
+ # Unlike PostgreSQL, CRDB seems able to cast text to most types.
+ context.adapters.register_dumper(str, string.StrDumper)
+ context.adapters.register_dumper(str, string.StrBinaryDumper)
+
+
+def register_crdb_enum_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
+ context.adapters.register_dumper(Enum, CrdbEnumDumper)
+
+
+def register_crdb_json_adapters(context: AdaptContext) -> None:
+ from ..types import json
+
+ adapters = context.adapters
+
+ # CRDB doesn't have json/jsonb: both names map to the jsonb oid
+ adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
+ adapters.register_dumper(json.Json, json.JsonbDumper)
+
+ adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper)
+ adapters.register_dumper(json.Jsonb, json.JsonbDumper)
+
+ adapters.register_loader("json", json.JsonLoader)
+ adapters.register_loader("jsonb", json.JsonbLoader)
+ adapters.register_loader("json", json.JsonBinaryLoader)
+ adapters.register_loader("jsonb", json.JsonbBinaryLoader)
+
+
+def register_crdb_net_adapters(context: AdaptContext) -> None:
+ from ..types import net
+
+ adapters = context.adapters
+
+ adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper)
+ adapters.register_dumper(None, net.InetBinaryDumper)
+ adapters.register_loader("inet", net.InetLoader)
+ adapters.register_loader("inet", net.InetBinaryLoader)
+
+
+def register_crdb_none_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(NoneType, CrdbNoneDumper)
+
+
+for t in [
+ TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb.
+ TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8
+ TypeInfo('"char"', 18, 1002), # special case, not generated
+ # autogenerated: start
+ # Generated from CockroachDB 22.1.0
+ TypeInfo("bit", 1560, 1561),
+ TypeInfo("bool", 16, 1000, regtype="boolean"),
+ TypeInfo("bpchar", 1042, 1014, regtype="character"),
+ TypeInfo("bytea", 17, 1001),
+ TypeInfo("date", 1082, 1182),
+ TypeInfo("float4", 700, 1021, regtype="real"),
+ TypeInfo("float8", 701, 1022, regtype="double precision"),
+ TypeInfo("inet", 869, 1041),
+ TypeInfo("int2", 21, 1005, regtype="smallint"),
+ TypeInfo("int2vector", 22, 1006),
+ TypeInfo("int4", 23, 1007),
+ TypeInfo("int8", 20, 1016, regtype="bigint"),
+ TypeInfo("interval", 1186, 1187),
+ TypeInfo("jsonb", 3802, 3807),
+ TypeInfo("name", 19, 1003),
+ TypeInfo("numeric", 1700, 1231),
+ TypeInfo("oid", 26, 1028),
+ TypeInfo("oidvector", 30, 1013),
+ TypeInfo("record", 2249, 2287),
+ TypeInfo("regclass", 2205, 2210),
+ TypeInfo("regnamespace", 4089, 4090),
+ TypeInfo("regproc", 24, 1008),
+ TypeInfo("regprocedure", 2202, 2207),
+ TypeInfo("regrole", 4096, 4097),
+ TypeInfo("regtype", 2206, 2211),
+ TypeInfo("text", 25, 1009),
+ TypeInfo("time", 1083, 1183, regtype="time without time zone"),
+ TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
+ TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
+ TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
+ TypeInfo("unknown", 705, 0),
+ TypeInfo("uuid", 2950, 2951),
+ TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
+ TypeInfo("varchar", 1043, 1015, regtype="character varying"),
+ # autogenerated: end
+]:
+ types.add(t)
diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py
new file mode 100644
index 0000000..6e79ed1
--- /dev/null
+++ b/psycopg/psycopg/crdb/connection.py
@@ -0,0 +1,186 @@
+"""
+CockroachDB-specific connections.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import re
+from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
+
+from .. import errors as e
+from ..abc import AdaptContext
+from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
+from ..conninfo import ConnectionInfo
+from ..connection import Connection
+from .._adapters_map import AdaptersMap
+from ..connection_async import AsyncConnection
+from ._types import adapters
+
+if TYPE_CHECKING:
+ from ..pq.abc import PGconn
+ from ..cursor import Cursor
+ from ..cursor_async import AsyncCursor
+
+
+class _CrdbConnectionMixin:
+
+ _adapters: Optional[AdaptersMap]
+ pgconn: "PGconn"
+
+ @classmethod
+ def is_crdb(
+ cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
+ ) -> bool:
+ """
+ Return `!True` if the server connected to `!conn` is CockroachDB.
+ """
+ if isinstance(conn, (Connection, AsyncConnection)):
+ conn = conn.pgconn
+
+ return bool(conn.parameter_status(b"crdb_version"))
+
+ @property
+ def adapters(self) -> AdaptersMap:
+ if not self._adapters:
+ # By default, use CockroachDB adapters map
+ self._adapters = AdaptersMap(adapters)
+
+ return self._adapters
+
+ @property
+ def info(self) -> "CrdbConnectionInfo":
+ return CrdbConnectionInfo(self.pgconn)
+
+ def _check_tpc(self) -> None:
+ if self.is_crdb(self.pgconn):
+ raise e.NotSupportedError("CockroachDB doesn't support prepared statements")
+
+
+class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
+ """
+ Wrapper for a connection to a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ # TODO: this method shouldn't require re-definition if the base class
+ # implements a generic self.
+ # https://github.com/psycopg/psycopg/issues/308
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ row_factory: RowFactory[Row],
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "CrdbConnection[Row]":
+ ...
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "CrdbConnection[TupleRow]":
+ ...
+
+ @classmethod
+ def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
+ """
+ Connect to a database server and return a new `CrdbConnection` instance.
+ """
+ return super().connect(conninfo, **kwargs) # type: ignore[return-value]
+
+
+class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
+ """
+ Wrapper for an async connection to a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ # TODO: this method shouldn't require re-definition if the base class
+ # implements a generic self.
+ # https://github.com/psycopg/psycopg/issues/308
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: AsyncRowFactory[Row],
+ cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncCrdbConnection[Row]":
+ ...
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncCrdbConnection[TupleRow]":
+ ...
+
+ @classmethod
+ async def connect(
+ cls, conninfo: str = "", **kwargs: Any
+ ) -> "AsyncCrdbConnection[Any]":
+ return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return]
+
+
+class CrdbConnectionInfo(ConnectionInfo):
+ """
+ `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ @property
+ def vendor(self) -> str:
+ return "CockroachDB"
+
+ @property
+ def server_version(self) -> int:
+ """
+ Return the CockroachDB server version connected.
+
+ Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).
+ """
+ sver = self.parameter_status("crdb_version")
+ if not sver:
+ raise e.InternalError("'crdb_version' parameter status not set")
+
+ ver = self.parse_crdb_version(sver)
+ if ver is None:
+ raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
+
+ return ver
+
+ @classmethod
+ def parse_crdb_version(self, sver: str) -> Optional[int]:
+ m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
+ if not m:
+ return None
+
+ return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py
new file mode 100644
index 0000000..42c3804
--- /dev/null
+++ b/psycopg/psycopg/cursor.py
@@ -0,0 +1,921 @@
+"""
+psycopg cursor objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from functools import partial
+from types import TracebackType
+from typing import Any, Generic, Iterable, Iterator, List
+from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar
+from typing import overload, TYPE_CHECKING
+from contextlib import contextmanager
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import ConnectionType, Query, Params, PQGen
+from .copy import Copy, Writer as CopyWriter
+from .rows import Row, RowMaker, RowFactory
+from ._column import Column
+from ._queries import PostgresQuery, PostgresClientQuery
+from ._pipeline import Pipeline
+from ._encodings import pgconn_encoding
+from ._preparing import Prepare
+from .generators import execute, fetch, send
+
+if TYPE_CHECKING:
+ from .abc import Transformer
+ from .pq.abc import PGconn, PGresult
+ from .connection import Connection
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+
+class BaseCursor(Generic[ConnectionType, Row]):
+ __slots__ = """
+ _conn format _adapters arraysize _closed _results pgresult _pos
+ _iresult _rowcount _query _tx _last_query _row_factory _make_row
+ _pgconn _execmany_returning
+ __weakref__
+ """.split()
+
+ ExecStatus = pq.ExecStatus
+
+ _tx: "Transformer"
+ _make_row: RowMaker[Row]
+ _pgconn: "PGconn"
+
+ def __init__(self, connection: ConnectionType):
+ self._conn = connection
+ self.format = TEXT
+ self._pgconn = connection.pgconn
+ self._adapters = adapt.AdaptersMap(connection.adapters)
+ self.arraysize = 1
+ self._closed = False
+ self._last_query: Optional[Query] = None
+ self._reset()
+
+ def _reset(self, reset_query: bool = True) -> None:
+ self._results: List["PGresult"] = []
+ self.pgresult: Optional["PGresult"] = None
+ self._pos = 0
+ self._iresult = 0
+ self._rowcount = -1
+ self._query: Optional[PostgresQuery]
+ # None if executemany() not executing, True/False according to returning state
+ self._execmany_returning: Optional[bool] = None
+ if reset_query:
+ self._query = None
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._pgconn)
+ if self._closed:
+ status = "closed"
+ elif self.pgresult:
+ status = pq.ExecStatus(self.pgresult.status).name
+ else:
+ status = "no result"
+ return f"<{cls} [{status}] {info} at 0x{id(self):x}>"
+
+ @property
+ def connection(self) -> ConnectionType:
+ """The connection this cursor is using."""
+ return self._conn
+
+ @property
+ def adapters(self) -> adapt.AdaptersMap:
+ return self._adapters
+
+ @property
+ def closed(self) -> bool:
+ """`True` if the cursor is closed."""
+ return self._closed
+
+ @property
+ def description(self) -> Optional[List[Column]]:
+ """
+ A list of `Column` objects describing the current resultset.
+
+ `!None` if the current resultset didn't return tuples.
+ """
+ res = self.pgresult
+
+ # We return columns if we have nfields, but also if we don't but
+ # the query said we got tuples (mostly to handle the super useful
+ # query "SELECT ;"
+ if res and (
+ res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
+ ):
+ return [Column(self, i) for i in range(res.nfields)]
+ else:
+ return None
+
+ @property
+ def rowcount(self) -> int:
+ """Number of records affected by the precedent operation."""
+ return self._rowcount
+
+ @property
+ def rownumber(self) -> Optional[int]:
+ """Index of the next row to fetch in the current result.
+
+ `!None` if there is no result to fetch.
+ """
+ tuples = self.pgresult and self.pgresult.status == TUPLES_OK
+ return self._pos if tuples else None
+
+ def setinputsizes(self, sizes: Sequence[Any]) -> None:
+ # no-op
+ pass
+
+ def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
+ # no-op
+ pass
+
+ def nextset(self) -> Optional[bool]:
+ """
+ Move to the result set of the next query executed through `executemany()`
+ or to the next result set if `execute()` returned more than one.
+
+ Return `!True` if a new result is available, which will be the one
+ methods `!fetch*()` will operate on.
+ """
+ if self._iresult < len(self._results) - 1:
+ self._select_current_result(self._iresult + 1)
+ return True
+ else:
+ return None
+
+ @property
+ def statusmessage(self) -> Optional[str]:
+ """
+ The command status tag from the last SQL command executed.
+
+ `!None` if the cursor doesn't have a result available.
+ """
+ msg = self.pgresult.command_status if self.pgresult else None
+ return msg.decode() if msg else None
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ raise NotImplementedError
+
+ #
+ # Generators for the high level operations on the cursor
+ #
+ # Like for sync/async connections, these are implemented as generators
+ # so that different concurrency strategies (threads,asyncio) can use their
+ # own way of waiting (or better, `connection.wait()`).
+ #
+
+ def _execute_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator implementing `Cursor.execute()`."""
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ results = yield from self._maybe_prepare_gen(
+ pgq, prepare=prepare, binary=binary
+ )
+ if self._conn._pipeline:
+ yield from self._conn._pipeline._communicate_gen()
+ else:
+ assert results is not None
+ self._check_results(results)
+ self._results = results
+ self._select_current_result(0)
+
+ self._last_query = query
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _executemany_gen_pipeline(
+ self, query: Query, params_seq: Iterable[Params], returning: bool
+ ) -> PQGen[None]:
+ """
+ Generator implementing `Cursor.executemany()` with pipelines available.
+ """
+ pipeline = self._conn._pipeline
+ assert pipeline
+
+ yield from self._start_query(query)
+ self._rowcount = 0
+
+ assert self._execmany_returning is None
+ self._execmany_returning = returning
+
+ first = True
+ for params in params_seq:
+ if first:
+ pgq = self._convert_query(query, params)
+ self._query = pgq
+ first = False
+ else:
+ pgq.dump(params)
+
+ yield from self._maybe_prepare_gen(pgq, prepare=True)
+ yield from pipeline._communicate_gen()
+
+ self._last_query = query
+
+ if returning:
+ yield from pipeline._fetch_gen(flush=True)
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _executemany_gen_no_pipeline(
+ self, query: Query, params_seq: Iterable[Params], returning: bool
+ ) -> PQGen[None]:
+ """
+ Generator implementing `Cursor.executemany()` with pipelines not available.
+ """
+ yield from self._start_query(query)
+ first = True
+ nrows = 0
+ for params in params_seq:
+ if first:
+ pgq = self._convert_query(query, params)
+ self._query = pgq
+ first = False
+ else:
+ pgq.dump(params)
+
+ results = yield from self._maybe_prepare_gen(pgq, prepare=True)
+ assert results is not None
+ self._check_results(results)
+ if returning:
+ self._results.extend(results)
+
+ for res in results:
+ nrows += res.command_tuples or 0
+
+ if self._results:
+ self._select_current_result(0)
+
+ # Override rowcount for the first result. Calls to nextset() will change
+ # it to the value of that result only, but we hope nobody will notice.
+ # You haven't read this comment.
+ self._rowcount = nrows
+ self._last_query = query
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _maybe_prepare_gen(
+ self,
+ pgq: PostgresQuery,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[Optional[List["PGresult"]]]:
+ # Check if the query is prepared or needs preparing
+ prep, name = self._get_prepared(pgq, prepare)
+ if prep is Prepare.NO:
+ # The query must be executed without preparing
+ self._execute_send(pgq, binary=binary)
+ else:
+ # If the query is not already prepared, prepare it.
+ if prep is Prepare.SHOULD:
+ self._send_prepare(name, pgq)
+ if not self._conn._pipeline:
+ (result,) = yield from execute(self._pgconn)
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ # Then execute it.
+ self._send_query_prepared(name, pgq, binary=binary)
+
+ # Update the prepare state of the query.
+ # If an operation requires to flush our prepared statements cache,
+ # it will be added to the maintenance commands to execute later.
+ key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
+
+ if self._conn._pipeline:
+ queued = None
+ if key is not None:
+ queued = (key, prep, name)
+ self._conn._pipeline.result_queue.append((self, queued))
+ return None
+
+ # run the query
+ results = yield from execute(self._pgconn)
+
+ if key is not None:
+ self._conn._prepared.validate(key, prep, name, results)
+
+ return results
+
+ def _get_prepared(
+ self, pgq: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ return self._conn._prepared.get(pgq, prepare)
+
+ def _stream_send_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator to send the query for `Cursor.stream()`."""
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ self._execute_send(pgq, binary=binary, force_extended=True)
+ self._pgconn.set_single_row_mode()
+ self._last_query = query
+ yield from send(self._pgconn)
+
+ def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]:
+ res = yield from fetch(self._pgconn)
+ if res is None:
+ return None
+
+ status = res.status
+ if status == SINGLE_TUPLE:
+ self.pgresult = res
+ self._tx.set_pgresult(res, set_loaders=first)
+ if first:
+ self._make_row = self._make_row_maker()
+ return res
+
+ elif status == TUPLES_OK or status == COMMAND_OK:
+ # End of single row results
+ while res:
+ res = yield from fetch(self._pgconn)
+ if status != TUPLES_OK:
+ raise e.ProgrammingError(
+ "the operation in stream() didn't produce a result"
+ )
+ return None
+
+ else:
+ # Errors, unexpected values
+ return self._raise_for_result(res)
+
+ def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
+ """Generator to start the processing of a query.
+
+ It is implemented as generator because it may send additional queries,
+ such as `begin`.
+ """
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+
+ self._reset()
+ if not self._last_query or (self._last_query is not query):
+ self._last_query = None
+ self._tx = adapt.Transformer(self)
+ yield from self._conn._start_query()
+
+ def _start_copy_gen(
+ self, statement: Query, params: Optional[Params] = None
+ ) -> PQGen[None]:
+ """Generator implementing sending a command for `Cursor.copy()."""
+
+ # The connection gets in an unrecoverable state if we attempt COPY in
+ # pipeline mode. Forbid it explicitly.
+ if self._conn._pipeline:
+ raise e.NotSupportedError("COPY cannot be used in pipeline mode")
+
+ yield from self._start_query()
+
+ # Merge the params client-side
+ if params:
+ pgq = PostgresClientQuery(self._tx)
+ pgq.convert(statement, params)
+ statement = pgq.query
+
+ query = self._convert_query(statement)
+
+ self._execute_send(query, binary=False)
+ results = yield from execute(self._pgconn)
+ if len(results) != 1:
+ raise e.ProgrammingError("COPY cannot be mixed with other operations")
+
+ self._check_copy_result(results[0])
+ self._results = results
+ self._select_current_result(0)
+
+ def _execute_send(
+ self,
+ query: PostgresQuery,
+ *,
+ force_extended: bool = False,
+ binary: Optional[bool] = None,
+ ) -> None:
+ """
+ Implement part of execute() before waiting common to sync and async.
+
+ This is not a generator, but a normal non-blocking function.
+ """
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ self._query = query
+
+ if self._conn._pipeline:
+ # In pipeline mode always use PQsendQueryParams - see #314
+ # Multiple statements in the same query are not allowed anyway.
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_query_params,
+ query.query,
+ query.params,
+ param_formats=query.formats,
+ param_types=query.types,
+ result_format=fmt,
+ )
+ )
+ elif force_extended or query.params or fmt == BINARY:
+ self._pgconn.send_query_params(
+ query.query,
+ query.params,
+ param_formats=query.formats,
+ param_types=query.types,
+ result_format=fmt,
+ )
+ else:
+ # If we can, let's use simple query protocol,
+ # as it can execute more than one statement in a single query.
+ self._pgconn.send_query(query.query)
+
+ def _convert_query(
+ self, query: Query, params: Optional[Params] = None
+ ) -> PostgresQuery:
+ pgq = PostgresQuery(self._tx)
+ pgq.convert(query, params)
+ return pgq
+
+ def _check_results(self, results: List["PGresult"]) -> None:
+ """
+ Verify that the results of a query are valid.
+
+ Verify that the query returned at least one result and that they all
+ represent a valid result from the database.
+ """
+ if not results:
+ raise e.InternalError("got no result from the query")
+
+ for res in results:
+ status = res.status
+ if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
+ self._raise_for_result(res)
+
+ def _raise_for_result(self, result: "PGresult") -> NoReturn:
+ """
+ Raise an appropriate error message for an unexpected database result
+ """
+ status = result.status
+ if status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ elif status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+ raise e.ProgrammingError(
+ "COPY cannot be used with this method; use copy() instead"
+ )
+ else:
+ raise e.InternalError(
+ "unexpected result status from query:" f" {pq.ExecStatus(status).name}"
+ )
+
+ def _select_current_result(
+ self, i: int, format: Optional[pq.Format] = None
+ ) -> None:
+ """
+ Select one of the results in the cursor as the active one.
+ """
+ self._iresult = i
+ res = self.pgresult = self._results[i]
+
+ # Note: the only reason to override format is to correctly set
+ # binary loaders on server-side cursors, because send_describe_portal
+ # only returns a text result.
+ self._tx.set_pgresult(res, format=format)
+
+ self._pos = 0
+
+ if res.status == TUPLES_OK:
+ self._rowcount = self.pgresult.ntuples
+
+ # COPY_OUT has never info about nrows. We need such result for the
+ # columns in order to return a `description`, but not overwrite the
+ # cursor rowcount (which was set by the Copy object).
+ elif res.status != COPY_OUT:
+ nrows = self.pgresult.command_tuples
+ self._rowcount = nrows if nrows is not None else -1
+
+ self._make_row = self._make_row_maker()
+
+ def _set_results_from_pipeline(self, results: List["PGresult"]) -> None:
+ self._check_results(results)
+ first_batch = not self._results
+
+ if self._execmany_returning is None:
+ # Received from execute()
+ self._results.extend(results)
+ if first_batch:
+ self._select_current_result(0)
+
+ else:
+ # Received from executemany()
+ if self._execmany_returning:
+ self._results.extend(results)
+ if first_batch:
+ self._select_current_result(0)
+ self._rowcount = 0
+
+ # Override rowcount for the first result. Calls to nextset() will
+ # change it to the value of that result only, but we hope nobody
+ # will notice.
+ # You haven't read this comment.
+ if self._rowcount < 0:
+ self._rowcount = 0
+ for res in results:
+ self._rowcount += res.command_tuples or 0
+
+ def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_prepare,
+ name,
+ query.query,
+ param_types=query.types,
+ )
+ )
+ self._conn._pipeline.result_queue.append(None)
+ else:
+ self._pgconn.send_prepare(name, query.query, param_types=query.types)
+
+ def _send_query_prepared(
+ self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
+ ) -> None:
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_query_prepared,
+ name,
+ pgq.params,
+ param_formats=pgq.formats,
+ result_format=fmt,
+ )
+ )
+ else:
+ self._pgconn.send_query_prepared(
+ name, pgq.params, param_formats=pgq.formats, result_format=fmt
+ )
+
+ def _check_result_for_fetch(self) -> None:
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+ res = self.pgresult
+ if not res:
+ raise e.ProgrammingError("no result available")
+
+ status = res.status
+ if status == TUPLES_OK:
+ return
+ elif status == FATAL_ERROR:
+ raise e.error_from_result(res, encoding=self._encoding)
+ elif status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ else:
+ raise e.ProgrammingError("the last operation didn't produce a result")
+
+ def _check_copy_result(self, result: "PGresult") -> None:
+ """
+ Check that the value returned in a copy() operation is a legit COPY.
+ """
+ status = result.status
+ if status == COPY_IN or status == COPY_OUT:
+ return
+ elif status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ else:
+ raise e.ProgrammingError(
+ "copy() should be used only with COPY ... TO STDOUT or COPY ..."
+ f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
+ )
+
+ def _scroll(self, value: int, mode: str) -> None:
+ self._check_result_for_fetch()
+ assert self.pgresult
+ if mode == "relative":
+ newpos = self._pos + value
+ elif mode == "absolute":
+ newpos = value
+ else:
+ raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
+ if not 0 <= newpos < self.pgresult.ntuples:
+ raise IndexError("position out of bound")
+ self._pos = newpos
+
+ def _close(self) -> None:
+ """Non-blocking part of closing. Common to sync/async."""
+ # Don't reset the query because it may be useful to investigate after
+ # an error.
+ self._reset(reset_query=False)
+ self._closed = True
+
+ @property
+ def _encoding(self) -> str:
+ return pgconn_encoding(self._pgconn)
+
+
+class Cursor(BaseCursor["Connection[Any]", Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="Cursor[Any]")
+
+ @overload
+ def __init__(self: "Cursor[Row]", connection: "Connection[Row]"):
+ ...
+
+ @overload
+ def __init__(
+ self: "Cursor[Row]",
+ connection: "Connection[Any]",
+ *,
+ row_factory: RowFactory[Row],
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "Connection[Any]",
+ *,
+ row_factory: Optional[RowFactory[Row]] = None,
+ ):
+ super().__init__(connection)
+ self._row_factory = row_factory or connection.row_factory
+
+ def __enter__(self: _Self) -> _Self:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ self._close()
+
+ @property
+ def row_factory(self) -> RowFactory[Row]:
+ """Writable attribute to control how result rows are formed."""
+ return self._row_factory
+
+ @row_factory.setter
+ def row_factory(self, row_factory: RowFactory[Row]) -> None:
+ self._row_factory = row_factory
+ if self.pgresult:
+ self._make_row = row_factory(self)
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ return self._row_factory(self)
+
+ def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> _Self:
+ """
+ Execute a query or command to the database.
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(
+ self._execute_gen(query, params, prepare=prepare, binary=binary)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ return self
+
+ def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = False,
+ ) -> None:
+ """
+ Execute the same command with a sequence of input data.
+ """
+ try:
+ if Pipeline.is_supported():
+ # If there is already a pipeline, ride it, in order to avoid
+ # sending unnecessary Sync.
+ with self._conn.lock:
+ p = self._conn._pipeline
+ if p:
+ self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ # Otherwise, make a new one
+ if not p:
+ with self._conn.pipeline(), self._conn.lock:
+ self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ else:
+ with self._conn.lock:
+ self._conn.wait(
+ self._executemany_gen_no_pipeline(query, params_seq, returning)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def stream(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> Iterator[Row]:
+ """
+ Iterate row-by-row on a result from the database.
+ """
+ if self._pgconn.pipeline_status:
+ raise e.ProgrammingError("stream() cannot be used in pipeline mode")
+
+ with self._conn.lock:
+
+ try:
+ self._conn.wait(self._stream_send_gen(query, params, binary=binary))
+ first = True
+ while self._conn.wait(self._stream_fetchone_gen(first)):
+ # We know that, if we got a result, it has a single row.
+ rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
+ yield rec
+ first = False
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ finally:
+ if self._pgconn.transaction_status == ACTIVE:
+ # Try to cancel the query, then consume the results
+ # already received.
+ self._conn.cancel()
+ try:
+ while self._conn.wait(self._stream_fetchone_gen(first=False)):
+ pass
+ except Exception:
+ pass
+
+ # Try to get out of ACTIVE state. Just do a single attempt, which
+ # should work to recover from an error or query cancelled.
+ try:
+ self._conn.wait(self._stream_fetchone_gen(first=False))
+ except Exception:
+ pass
+
+ def fetchone(self) -> Optional[Row]:
+ """
+ Return the next record from the current recordset.
+
+ Return `!None` the recordset is finished.
+
+ :rtype: Optional[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ record = self._tx.load_row(self._pos, self._make_row)
+ if record is not None:
+ self._pos += 1
+ return record
+
+ def fetchmany(self, size: int = 0) -> List[Row]:
+ """
+ Return the next `!size` records from the current recordset.
+
+ `!size` default to `!self.arraysize` if not specified.
+
+ :rtype: Sequence[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+
+ if not size:
+ size = self.arraysize
+ records = self._tx.load_rows(
+ self._pos,
+ min(self._pos + size, self.pgresult.ntuples),
+ self._make_row,
+ )
+ self._pos += len(records)
+ return records
+
+ def fetchall(self) -> List[Row]:
+ """
+ Return all the remaining records from the current recordset.
+
+ :rtype: Sequence[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+ records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
+ self._pos = self.pgresult.ntuples
+ return records
+
+ def __iter__(self) -> Iterator[Row]:
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+
+ def load(pos: int) -> Optional[Row]:
+ return self._tx.load_row(pos, self._make_row)
+
+ while True:
+ row = load(self._pos)
+ if row is None:
+ break
+ self._pos += 1
+ yield row
+
+ def scroll(self, value: int, mode: str = "relative") -> None:
+ """
+ Move the cursor in the result set to a new position according to mode.
+
+ If `!mode` is ``'relative'`` (default), `!value` is taken as offset to
+ the current position in the result set; if set to ``'absolute'``,
+ `!value` states an absolute target position.
+
+ Raise `!IndexError` in case a scroll operation would leave the result
+ set. In this case the position will not change.
+ """
+ self._fetch_pipeline()
+ self._scroll(value, mode)
+
+ @contextmanager
+ def copy(
+ self,
+ statement: Query,
+ params: Optional[Params] = None,
+ *,
+ writer: Optional[CopyWriter] = None,
+ ) -> Iterator[Copy]:
+ """
+ Initiate a :sql:`COPY` operation and return an object to manage it.
+
+ :rtype: Copy
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._start_copy_gen(statement, params))
+
+ with Copy(self, writer=writer) as copy:
+ yield copy
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ # If a fresher result has been set on the cursor by the Copy object,
+ # read its properties (especially rowcount).
+ self._select_current_result(0)
+
+ def _fetch_pipeline(self) -> None:
+ if (
+ self._execmany_returning is not False
+ and not self.pgresult
+ and self._conn._pipeline
+ ):
+ with self._conn.lock:
+ self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py
new file mode 100644
index 0000000..8971d40
--- /dev/null
+++ b/psycopg/psycopg/cursor_async.py
@@ -0,0 +1,250 @@
+"""
+psycopg async cursor objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from types import TracebackType
+from typing import Any, AsyncIterator, Iterable, List
+from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload
+from contextlib import asynccontextmanager
+
+from . import pq
+from . import errors as e
+from .abc import Query, Params
+from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
+from .rows import Row, RowMaker, AsyncRowFactory
+from .cursor import BaseCursor
+from ._pipeline import Pipeline
+
+if TYPE_CHECKING:
+ from .connection_async import AsyncConnection
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+
+class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="AsyncCursor[Any]")
+
+ @overload
+ def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"):
+ ...
+
+ @overload
+ def __init__(
+ self: "AsyncCursor[Row]",
+ connection: "AsyncConnection[Any]",
+ *,
+ row_factory: AsyncRowFactory[Row],
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "AsyncConnection[Any]",
+ *,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ ):
+ super().__init__(connection)
+ self._row_factory = row_factory or connection.row_factory
+
+ async def __aenter__(self: _Self) -> _Self:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
+ async def close(self) -> None:
+ self._close()
+
+ @property
+ def row_factory(self) -> AsyncRowFactory[Row]:
+ return self._row_factory
+
+ @row_factory.setter
+ def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
+ self._row_factory = row_factory
+ if self.pgresult:
+ self._make_row = row_factory(self)
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ return self._row_factory(self)
+
+ async def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> _Self:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(
+ self._execute_gen(query, params, prepare=prepare, binary=binary)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ return self
+
+ async def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = False,
+ ) -> None:
+ try:
+ if Pipeline.is_supported():
+ # If there is already a pipeline, ride it, in order to avoid
+ # sending unnecessary Sync.
+ async with self._conn.lock:
+ p = self._conn._pipeline
+ if p:
+ await self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ # Otherwise, make a new one
+ if not p:
+ async with self._conn.pipeline(), self._conn.lock:
+ await self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ else:
+ await self._conn.wait(
+ self._executemany_gen_no_pipeline(query, params_seq, returning)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def stream(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> AsyncIterator[Row]:
+ if self._pgconn.pipeline_status:
+ raise e.ProgrammingError("stream() cannot be used in pipeline mode")
+
+ async with self._conn.lock:
+
+ try:
+ await self._conn.wait(
+ self._stream_send_gen(query, params, binary=binary)
+ )
+ first = True
+ while await self._conn.wait(self._stream_fetchone_gen(first)):
+ # We know that, if we got a result, it has a single row.
+ rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
+ yield rec
+ first = False
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ finally:
+ if self._pgconn.transaction_status == ACTIVE:
+ # Try to cancel the query, then consume the results
+ # already received.
+ self._conn.cancel()
+ try:
+ while await self._conn.wait(
+ self._stream_fetchone_gen(first=False)
+ ):
+ pass
+ except Exception:
+ pass
+
+ # Try to get out of ACTIVE state. Just do a single attempt, which
+ # should work to recover from an error or query cancelled.
+ try:
+ await self._conn.wait(self._stream_fetchone_gen(first=False))
+ except Exception:
+ pass
+
+ async def fetchone(self) -> Optional[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ rv = self._tx.load_row(self._pos, self._make_row)
+ if rv is not None:
+ self._pos += 1
+ return rv
+
+ async def fetchmany(self, size: int = 0) -> List[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+
+ if not size:
+ size = self.arraysize
+ records = self._tx.load_rows(
+ self._pos,
+ min(self._pos + size, self.pgresult.ntuples),
+ self._make_row,
+ )
+ self._pos += len(records)
+ return records
+
+ async def fetchall(self) -> List[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+ records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
+ self._pos = self.pgresult.ntuples
+ return records
+
+ async def __aiter__(self) -> AsyncIterator[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+
+ def load(pos: int) -> Optional[Row]:
+ return self._tx.load_row(pos, self._make_row)
+
+ while True:
+ row = load(self._pos)
+ if row is None:
+ break
+ self._pos += 1
+ yield row
+
+ async def scroll(self, value: int, mode: str = "relative") -> None:
+ self._scroll(value, mode)
+
+ @asynccontextmanager
+ async def copy(
+ self,
+ statement: Query,
+ params: Optional[Params] = None,
+ *,
+ writer: Optional[AsyncCopyWriter] = None,
+ ) -> AsyncIterator[AsyncCopy]:
+ """
+ :rtype: AsyncCopy
+ """
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._start_copy_gen(statement, params))
+
+ async with AsyncCopy(self, writer=writer) as copy:
+ yield copy
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ self._select_current_result(0)
+
+ async def _fetch_pipeline(self) -> None:
+ if (
+ self._execmany_returning is not False
+ and not self.pgresult
+ and self._conn._pipeline
+ ):
+ async with self._conn.lock:
+ await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py
new file mode 100644
index 0000000..3c3d8b7
--- /dev/null
+++ b/psycopg/psycopg/dbapi20.py
@@ -0,0 +1,112 @@
+"""
+Compatibility objects with DBAPI 2.0
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import time
+import datetime as dt
+from math import floor
+from typing import Any, Sequence, Union
+
+from . import postgres
+from .abc import AdaptContext, Buffer
+from .types.string import BytesDumper, BytesBinaryDumper
+
+
+class DBAPITypeObject:
+ def __init__(self, name: str, type_names: Sequence[str]):
+ self.name = name
+ self.values = tuple(postgres.types[n].oid for n in type_names)
+
+ def __repr__(self) -> str:
+ return f"psycopg.{self.name}"
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, int):
+ return other in self.values
+ else:
+ return NotImplemented
+
+ def __ne__(self, other: Any) -> bool:
+ if isinstance(other, int):
+ return other not in self.values
+ else:
+ return NotImplemented
+
+
+BINARY = DBAPITypeObject("BINARY", ("bytea",))
+DATETIME = DBAPITypeObject(
+ "DATETIME", "timestamp timestamptz date time timetz interval".split()
+)
+NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split())
+ROWID = DBAPITypeObject("ROWID", ("oid",))
+STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
+
+
+class Binary:
+ def __init__(self, obj: Any):
+ self.obj = obj
+
+ def __repr__(self) -> str:
+ sobj = repr(self.obj)
+ if len(sobj) > 40:
+ sobj = f"{sobj[:35]} ... ({len(sobj)} byteschars)"
+ return f"{self.__class__.__name__}({sobj})"
+
+
+class BinaryBinaryDumper(BytesBinaryDumper):
+ def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+ if isinstance(obj, Binary):
+ return super().dump(obj.obj)
+ else:
+ return super().dump(obj)
+
+
+class BinaryTextDumper(BytesDumper):
+ def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+ if isinstance(obj, Binary):
+ return super().dump(obj.obj)
+ else:
+ return super().dump(obj)
+
+
+def Date(year: int, month: int, day: int) -> dt.date:
+ return dt.date(year, month, day)
+
+
+def DateFromTicks(ticks: float) -> dt.date:
+ return TimestampFromTicks(ticks).date()
+
+
+def Time(hour: int, minute: int, second: int) -> dt.time:
+ return dt.time(hour, minute, second)
+
+
+def TimeFromTicks(ticks: float) -> dt.time:
+ return TimestampFromTicks(ticks).time()
+
+
+def Timestamp(
+ year: int, month: int, day: int, hour: int, minute: int, second: int
+) -> dt.datetime:
+ return dt.datetime(year, month, day, hour, minute, second)
+
+
+def TimestampFromTicks(ticks: float) -> dt.datetime:
+ secs = floor(ticks)
+ frac = ticks - secs
+ t = time.localtime(ticks)
+ tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
+ rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo)
+ return rv
+
+
+def register_dbapi20_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Binary, BinaryTextDumper)
+ adapters.register_dumper(Binary, BinaryBinaryDumper)
+
+ # Make them also the default dumpers when dumping by bytea oid
+ adapters.register_dumper(None, BinaryTextDumper)
+ adapters.register_dumper(None, BinaryBinaryDumper)
diff --git a/psycopg/psycopg/errors.py b/psycopg/psycopg/errors.py
new file mode 100644
index 0000000..e176954
--- /dev/null
+++ b/psycopg/psycopg/errors.py
@@ -0,0 +1,1535 @@
+"""
+psycopg exceptions
+
+DBAPI-defined Exceptions are defined in the following hierarchy::
+
+ Exceptions
+ |__Warning
+ |__Error
+ |__InterfaceError
+ |__DatabaseError
+ |__DataError
+ |__OperationalError
+ |__IntegrityError
+ |__InternalError
+ |__ProgrammingError
+ |__NotSupportedError
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
+from typing_extensions import TypeAlias
+
+from .pq.abc import PGconn, PGresult
+from .pq._enums import DiagnosticField
+from ._compat import TypeGuard
+
+ErrorInfo: TypeAlias = Union[None, PGresult, Dict[int, Optional[bytes]]]
+
+_sqlcodes: Dict[str, "Type[Error]"] = {}
+
+
+class Warning(Exception):
+ """
+ Exception raised for important warnings.
+
+ Defined for DBAPI compatibility, but never raised by ``psycopg``.
+ """
+
+ __module__ = "psycopg"
+
+
+class Error(Exception):
+ """
+ Base exception for all the errors psycopg will raise.
+
+ Exception that is the base class of all other error exceptions. You can
+ use this to catch all errors with one single `!except` statement.
+
+ This exception is guaranteed to be picklable.
+ """
+
+ __module__ = "psycopg"
+
+ sqlstate: Optional[str] = None
+
+ def __init__(
+ self,
+ *args: Sequence[Any],
+ info: ErrorInfo = None,
+ encoding: str = "utf-8",
+ pgconn: Optional[PGconn] = None
+ ):
+ super().__init__(*args)
+ self._info = info
+ self._encoding = encoding
+ self._pgconn = pgconn
+
+ # Handle sqlstate codes for which we don't have a class.
+ if not self.sqlstate and info:
+ self.sqlstate = self.diag.sqlstate
+
+ @property
+ def pgconn(self) -> Optional[PGconn]:
+ """The connection object, if the error was raised from a connection attempt.
+
+ :rtype: Optional[psycopg.pq.PGconn]
+ """
+ return self._pgconn if self._pgconn else None
+
+ @property
+ def pgresult(self) -> Optional[PGresult]:
+ """The result object, if the exception was raised after a failed query.
+
+ :rtype: Optional[psycopg.pq.PGresult]
+ """
+ return self._info if _is_pgresult(self._info) else None
+
+ @property
+ def diag(self) -> "Diagnostic":
+ """
+ A `Diagnostic` object to inspect details of the errors from the database.
+ """
+ return Diagnostic(self._info, encoding=self._encoding)
+
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
+ res = super().__reduce__()
+ if isinstance(res, tuple) and len(res) >= 3:
+ # To make the exception picklable
+ res[2]["_info"] = _info_to_dict(self._info)
+ res[2]["_pgconn"] = None
+
+ return res
+
+
+class InterfaceError(Error):
+ """
+ An error related to the database interface rather than the database itself.
+ """
+
+ __module__ = "psycopg"
+
+
+class DatabaseError(Error):
+ """
+ Exception raised for errors that are related to the database.
+ """
+
+ __module__ = "psycopg"
+
+ def __init_subclass__(cls, code: Optional[str] = None, name: Optional[str] = None):
+ if code:
+ _sqlcodes[code] = cls
+ cls.sqlstate = code
+ if name:
+ _sqlcodes[name] = cls
+
+
+class DataError(DatabaseError):
+ """
+ An error caused by problems with the processed data.
+
+ Examples may be division by zero, numeric value out of range, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class OperationalError(DatabaseError):
+ """
+ An error related to the database's operation.
+
+ These errors are not necessarily under the control of the programmer, e.g.
+ an unexpected disconnect occurs, the data source name is not found, a
+ transaction could not be processed, a memory allocation error occurred
+ during processing, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class IntegrityError(DatabaseError):
+ """
+ An error caused when the relational integrity of the database is affected.
+
+ An example may be a foreign key check failed.
+ """
+
+ __module__ = "psycopg"
+
+
+class InternalError(DatabaseError):
+ """
+ An error generated when the database encounters an internal error,
+
+ Examples could be the cursor is not valid anymore, the transaction is out
+ of sync, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class ProgrammingError(DatabaseError):
+ """
+ Exception raised for programming errors
+
+ Examples may be table not found or already exists, syntax error in the SQL
+ statement, wrong number of parameters specified, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class NotSupportedError(DatabaseError):
+ """
+ A method or database API was used which is not supported by the database.
+ """
+
+ __module__ = "psycopg"
+
+
+class ConnectionTimeout(OperationalError):
+ """
+ Exception raised on timeout of the `~psycopg.Connection.connect()` method.
+
+ The error is raised if the ``connect_timeout`` is specified and a
+ connection is not obtained in useful time.
+
+ Subclass of `~psycopg.OperationalError`.
+ """
+
+
+class PipelineAborted(OperationalError):
+ """
+ Raised when a operation fails because the current pipeline is in aborted state.
+
+ Subclass of `~psycopg.OperationalError`.
+ """
+
+
+class Diagnostic:
+ """Details from a database error report."""
+
+ def __init__(self, info: ErrorInfo, encoding: str = "utf-8"):
+ self._info = info
+ self._encoding = encoding
+
+ @property
+ def severity(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SEVERITY)
+
+ @property
+ def severity_nonlocalized(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SEVERITY_NONLOCALIZED)
+
+ @property
+ def sqlstate(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SQLSTATE)
+
+ @property
+ def message_primary(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_PRIMARY)
+
+ @property
+ def message_detail(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_DETAIL)
+
+ @property
+ def message_hint(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_HINT)
+
+ @property
+ def statement_position(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.STATEMENT_POSITION)
+
+ @property
+ def internal_position(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.INTERNAL_POSITION)
+
+ @property
+ def internal_query(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.INTERNAL_QUERY)
+
+ @property
+ def context(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.CONTEXT)
+
+ @property
+ def schema_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SCHEMA_NAME)
+
+ @property
+ def table_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.TABLE_NAME)
+
+ @property
+ def column_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.COLUMN_NAME)
+
+ @property
+ def datatype_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.DATATYPE_NAME)
+
+ @property
+ def constraint_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.CONSTRAINT_NAME)
+
+ @property
+ def source_file(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_FILE)
+
+ @property
+ def source_line(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_LINE)
+
+ @property
+ def source_function(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_FUNCTION)
+
+ def _error_message(self, field: DiagnosticField) -> Optional[str]:
+ if self._info:
+ if isinstance(self._info, dict):
+ val = self._info.get(field)
+ else:
+ val = self._info.error_field(field)
+
+ if val is not None:
+ return val.decode(self._encoding, "replace")
+
+ return None
+
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
+ res = super().__reduce__()
+ if isinstance(res, tuple) and len(res) >= 3:
+ res[2]["_info"] = _info_to_dict(self._info)
+
+ return res
+
+
+def _info_to_dict(info: ErrorInfo) -> ErrorInfo:
+ """
+ Convert a PGresult to a dictionary to make the info picklable.
+ """
+ # PGresult is a protocol, can't use isinstance
+ if _is_pgresult(info):
+ return {v: info.error_field(v) for v in DiagnosticField}
+ else:
+ return info
+
+
+def lookup(sqlstate: str) -> Type[Error]:
+ """Lookup an error code or `constant name`__ and return its exception class.
+
+ Raise `!KeyError` if the code is not found.
+
+ .. __: https://www.postgresql.org/docs/current/errcodes-appendix.html
+ #ERRCODES-TABLE
+ """
+ return _sqlcodes[sqlstate.upper()]
+
+
+def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error:
+ from psycopg import pq
+
+ state = result.error_field(DiagnosticField.SQLSTATE) or b""
+ cls = _class_for_state(state.decode("ascii"))
+ return cls(
+ pq.error_message(result, encoding=encoding),
+ info=result,
+ encoding=encoding,
+ )
+
+
+def _is_pgresult(info: ErrorInfo) -> TypeGuard[PGresult]:
+ """Return True if an ErrorInfo is a PGresult instance."""
+ # PGresult is a protocol, can't use isinstance
+ return hasattr(info, "error_field")
+
+
+def _class_for_state(sqlstate: str) -> Type[Error]:
+ try:
+ return lookup(sqlstate)
+ except KeyError:
+ return get_base_exception(sqlstate)
+
+
+def get_base_exception(sqlstate: str) -> Type[Error]:
+ return (
+ _base_exc_map.get(sqlstate[:2])
+ or _base_exc_map.get(sqlstate[:1])
+ or DatabaseError
+ )
+
+
+_base_exc_map = {
+ "08": OperationalError, # Connection Exception
+ "0A": NotSupportedError, # Feature Not Supported
+ "20": ProgrammingError, # Case Not Foud
+ "21": ProgrammingError, # Cardinality Violation
+ "22": DataError, # Data Exception
+ "23": IntegrityError, # Integrity Constraint Violation
+ "24": InternalError, # Invalid Cursor State
+ "25": InternalError, # Invalid Transaction State
+ "26": ProgrammingError, # Invalid SQL Statement Name *
+ "27": OperationalError, # Triggered Data Change Violation
+ "28": OperationalError, # Invalid Authorization Specification
+ "2B": InternalError, # Dependent Privilege Descriptors Still Exist
+ "2D": InternalError, # Invalid Transaction Termination
+ "2F": OperationalError, # SQL Routine Exception *
+ "34": ProgrammingError, # Invalid Cursor Name *
+ "38": OperationalError, # External Routine Exception *
+ "39": OperationalError, # External Routine Invocation Exception *
+ "3B": OperationalError, # Savepoint Exception *
+ "3D": ProgrammingError, # Invalid Catalog Name
+ "3F": ProgrammingError, # Invalid Schema Name
+ "40": OperationalError, # Transaction Rollback
+ "42": ProgrammingError, # Syntax Error or Access Rule Violation
+ "44": ProgrammingError, # WITH CHECK OPTION Violation
+ "53": OperationalError, # Insufficient Resources
+ "54": OperationalError, # Program Limit Exceeded
+ "55": OperationalError, # Object Not In Prerequisite State
+ "57": OperationalError, # Operator Intervention
+ "58": OperationalError, # System Error (errors external to PostgreSQL itself)
+ "F": OperationalError, # Configuration File Error
+ "H": OperationalError, # Foreign Data Wrapper Error (SQL/MED)
+ "P": ProgrammingError, # PL/pgSQL Error
+ "X": InternalError, # Internal Error
+}
+
+
+# Error classes generated by tools/update_errors.py
+
+# fmt: off
+# autogenerated: start
+
+
+# Class 02 - No Data (this is also a warning class per the SQL standard)
+
+class NoData(DatabaseError,
+ code='02000', name='NO_DATA'):
+ pass
+
+class NoAdditionalDynamicResultSetsReturned(DatabaseError,
+ code='02001', name='NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED'):
+ pass
+
+
+# Class 03 - SQL Statement Not Yet Complete
+
+class SqlStatementNotYetComplete(DatabaseError,
+ code='03000', name='SQL_STATEMENT_NOT_YET_COMPLETE'):
+ pass
+
+
+# Class 08 - Connection Exception
+
+class ConnectionException(OperationalError,
+ code='08000', name='CONNECTION_EXCEPTION'):
+ pass
+
+class SqlclientUnableToEstablishSqlconnection(OperationalError,
+ code='08001', name='SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION'):
+ pass
+
+class ConnectionDoesNotExist(OperationalError,
+ code='08003', name='CONNECTION_DOES_NOT_EXIST'):
+ pass
+
+class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError,
+ code='08004', name='SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION'):
+ pass
+
+class ConnectionFailure(OperationalError,
+ code='08006', name='CONNECTION_FAILURE'):
+ pass
+
+class TransactionResolutionUnknown(OperationalError,
+ code='08007', name='TRANSACTION_RESOLUTION_UNKNOWN'):
+ pass
+
+class ProtocolViolation(OperationalError,
+ code='08P01', name='PROTOCOL_VIOLATION'):
+ pass
+
+
+# Class 09 - Triggered Action Exception
+
+class TriggeredActionException(DatabaseError,
+ code='09000', name='TRIGGERED_ACTION_EXCEPTION'):
+ pass
+
+
+# Class 0A - Feature Not Supported
+
+class FeatureNotSupported(NotSupportedError,
+ code='0A000', name='FEATURE_NOT_SUPPORTED'):
+ pass
+
+
+# Class 0B - Invalid Transaction Initiation
+
+class InvalidTransactionInitiation(DatabaseError,
+ code='0B000', name='INVALID_TRANSACTION_INITIATION'):
+ pass
+
+
+# Class 0F - Locator Exception
+
+class LocatorException(DatabaseError,
+ code='0F000', name='LOCATOR_EXCEPTION'):
+ pass
+
+class InvalidLocatorSpecification(DatabaseError,
+ code='0F001', name='INVALID_LOCATOR_SPECIFICATION'):
+ pass
+
+
+# Class 0L - Invalid Grantor
+
+class InvalidGrantor(DatabaseError,
+ code='0L000', name='INVALID_GRANTOR'):
+ pass
+
+class InvalidGrantOperation(DatabaseError,
+ code='0LP01', name='INVALID_GRANT_OPERATION'):
+ pass
+
+
+# Class 0P - Invalid Role Specification
+
+class InvalidRoleSpecification(DatabaseError,
+ code='0P000', name='INVALID_ROLE_SPECIFICATION'):
+ pass
+
+
+# Class 0Z - Diagnostics Exception
+
+class DiagnosticsException(DatabaseError,
+ code='0Z000', name='DIAGNOSTICS_EXCEPTION'):
+ pass
+
+class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError,
+ code='0Z002', name='STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER'):
+ pass
+
+
+# Class 20 - Case Not Found
+
+class CaseNotFound(ProgrammingError,
+ code='20000', name='CASE_NOT_FOUND'):
+ pass
+
+
+# Class 21 - Cardinality Violation
+
+class CardinalityViolation(ProgrammingError,
+ code='21000', name='CARDINALITY_VIOLATION'):
+ pass
+
+
+# Class 22 - Data Exception
+
+class DataException(DataError,
+ code='22000', name='DATA_EXCEPTION'):
+ pass
+
+class StringDataRightTruncation(DataError,
+ code='22001', name='STRING_DATA_RIGHT_TRUNCATION'):
+ pass
+
+class NullValueNoIndicatorParameter(DataError,
+ code='22002', name='NULL_VALUE_NO_INDICATOR_PARAMETER'):
+ pass
+
+class NumericValueOutOfRange(DataError,
+ code='22003', name='NUMERIC_VALUE_OUT_OF_RANGE'):
+ pass
+
+class NullValueNotAllowed(DataError,
+ code='22004', name='NULL_VALUE_NOT_ALLOWED'):
+ pass
+
+class ErrorInAssignment(DataError,
+ code='22005', name='ERROR_IN_ASSIGNMENT'):
+ pass
+
+class InvalidDatetimeFormat(DataError,
+ code='22007', name='INVALID_DATETIME_FORMAT'):
+ pass
+
+class DatetimeFieldOverflow(DataError,
+ code='22008', name='DATETIME_FIELD_OVERFLOW'):
+ pass
+
+class InvalidTimeZoneDisplacementValue(DataError,
+ code='22009', name='INVALID_TIME_ZONE_DISPLACEMENT_VALUE'):
+ pass
+
+class EscapeCharacterConflict(DataError,
+ code='2200B', name='ESCAPE_CHARACTER_CONFLICT'):
+ pass
+
+class InvalidUseOfEscapeCharacter(DataError,
+ code='2200C', name='INVALID_USE_OF_ESCAPE_CHARACTER'):
+ pass
+
+class InvalidEscapeOctet(DataError,
+ code='2200D', name='INVALID_ESCAPE_OCTET'):
+ pass
+
+class ZeroLengthCharacterString(DataError,
+ code='2200F', name='ZERO_LENGTH_CHARACTER_STRING'):
+ pass
+
+class MostSpecificTypeMismatch(DataError,
+ code='2200G', name='MOST_SPECIFIC_TYPE_MISMATCH'):
+ pass
+
+class SequenceGeneratorLimitExceeded(DataError,
+ code='2200H', name='SEQUENCE_GENERATOR_LIMIT_EXCEEDED'):
+ pass
+
+class NotAnXmlDocument(DataError,
+ code='2200L', name='NOT_AN_XML_DOCUMENT'):
+ pass
+
+class InvalidXmlDocument(DataError,
+ code='2200M', name='INVALID_XML_DOCUMENT'):
+ pass
+
+class InvalidXmlContent(DataError,
+ code='2200N', name='INVALID_XML_CONTENT'):
+ pass
+
+class InvalidXmlComment(DataError,
+ code='2200S', name='INVALID_XML_COMMENT'):
+ pass
+
+class InvalidXmlProcessingInstruction(DataError,
+ code='2200T', name='INVALID_XML_PROCESSING_INSTRUCTION'):
+ pass
+
+class InvalidIndicatorParameterValue(DataError,
+ code='22010', name='INVALID_INDICATOR_PARAMETER_VALUE'):
+ pass
+
+class SubstringError(DataError,
+ code='22011', name='SUBSTRING_ERROR'):
+ pass
+
+class DivisionByZero(DataError,
+ code='22012', name='DIVISION_BY_ZERO'):
+ pass
+
+class InvalidPrecedingOrFollowingSize(DataError,
+ code='22013', name='INVALID_PRECEDING_OR_FOLLOWING_SIZE'):
+ pass
+
+class InvalidArgumentForNtileFunction(DataError,
+ code='22014', name='INVALID_ARGUMENT_FOR_NTILE_FUNCTION'):
+ pass
+
+class IntervalFieldOverflow(DataError,
+ code='22015', name='INTERVAL_FIELD_OVERFLOW'):
+ pass
+
+class InvalidArgumentForNthValueFunction(DataError,
+ code='22016', name='INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION'):
+ pass
+
+class InvalidCharacterValueForCast(DataError,
+ code='22018', name='INVALID_CHARACTER_VALUE_FOR_CAST'):
+ pass
+
+class InvalidEscapeCharacter(DataError,
+ code='22019', name='INVALID_ESCAPE_CHARACTER'):
+ pass
+
+class InvalidRegularExpression(DataError,
+ code='2201B', name='INVALID_REGULAR_EXPRESSION'):
+ pass
+
+class InvalidArgumentForLogarithm(DataError,
+ code='2201E', name='INVALID_ARGUMENT_FOR_LOGARITHM'):
+ pass
+
+class InvalidArgumentForPowerFunction(DataError,
+ code='2201F', name='INVALID_ARGUMENT_FOR_POWER_FUNCTION'):
+ pass
+
+class InvalidArgumentForWidthBucketFunction(DataError,
+ code='2201G', name='INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION'):
+ pass
+
+class InvalidRowCountInLimitClause(DataError,
+ code='2201W', name='INVALID_ROW_COUNT_IN_LIMIT_CLAUSE'):
+ pass
+
+class InvalidRowCountInResultOffsetClause(DataError,
+ code='2201X', name='INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE'):
+ pass
+
+class CharacterNotInRepertoire(DataError,
+ code='22021', name='CHARACTER_NOT_IN_REPERTOIRE'):
+ pass
+
+class IndicatorOverflow(DataError,
+ code='22022', name='INDICATOR_OVERFLOW'):
+ pass
+
+class InvalidParameterValue(DataError,
+ code='22023', name='INVALID_PARAMETER_VALUE'):
+ pass
+
+class UnterminatedCString(DataError,
+ code='22024', name='UNTERMINATED_C_STRING'):
+ pass
+
+class InvalidEscapeSequence(DataError,
+ code='22025', name='INVALID_ESCAPE_SEQUENCE'):
+ pass
+
+class StringDataLengthMismatch(DataError,
+ code='22026', name='STRING_DATA_LENGTH_MISMATCH'):
+ pass
+
+class TrimError(DataError,
+ code='22027', name='TRIM_ERROR'):
+ pass
+
+class ArraySubscriptError(DataError,
+ code='2202E', name='ARRAY_SUBSCRIPT_ERROR'):
+ pass
+
+class InvalidTablesampleRepeat(DataError,
+ code='2202G', name='INVALID_TABLESAMPLE_REPEAT'):
+ pass
+
+class InvalidTablesampleArgument(DataError,
+ code='2202H', name='INVALID_TABLESAMPLE_ARGUMENT'):
+ pass
+
+class DuplicateJsonObjectKeyValue(DataError,
+ code='22030', name='DUPLICATE_JSON_OBJECT_KEY_VALUE'):
+ pass
+
+class InvalidArgumentForSqlJsonDatetimeFunction(DataError,
+ code='22031', name='INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION'):
+ pass
+
+class InvalidJsonText(DataError,
+ code='22032', name='INVALID_JSON_TEXT'):
+ pass
+
+class InvalidSqlJsonSubscript(DataError,
+ code='22033', name='INVALID_SQL_JSON_SUBSCRIPT'):
+ pass
+
+class MoreThanOneSqlJsonItem(DataError,
+ code='22034', name='MORE_THAN_ONE_SQL_JSON_ITEM'):
+ pass
+
+class NoSqlJsonItem(DataError,
+ code='22035', name='NO_SQL_JSON_ITEM'):
+ pass
+
+class NonNumericSqlJsonItem(DataError,
+ code='22036', name='NON_NUMERIC_SQL_JSON_ITEM'):
+ pass
+
+class NonUniqueKeysInAJsonObject(DataError,
+ code='22037', name='NON_UNIQUE_KEYS_IN_A_JSON_OBJECT'):
+ pass
+
+class SingletonSqlJsonItemRequired(DataError,
+ code='22038', name='SINGLETON_SQL_JSON_ITEM_REQUIRED'):
+ pass
+
+class SqlJsonArrayNotFound(DataError,
+ code='22039', name='SQL_JSON_ARRAY_NOT_FOUND'):
+ pass
+
+class SqlJsonMemberNotFound(DataError,
+ code='2203A', name='SQL_JSON_MEMBER_NOT_FOUND'):
+ pass
+
+class SqlJsonNumberNotFound(DataError,
+ code='2203B', name='SQL_JSON_NUMBER_NOT_FOUND'):
+ pass
+
+class SqlJsonObjectNotFound(DataError,
+ code='2203C', name='SQL_JSON_OBJECT_NOT_FOUND'):
+ pass
+
+class TooManyJsonArrayElements(DataError,
+ code='2203D', name='TOO_MANY_JSON_ARRAY_ELEMENTS'):
+ pass
+
+class TooManyJsonObjectMembers(DataError,
+ code='2203E', name='TOO_MANY_JSON_OBJECT_MEMBERS'):
+ pass
+
+class SqlJsonScalarRequired(DataError,
+ code='2203F', name='SQL_JSON_SCALAR_REQUIRED'):
+ pass
+
+class SqlJsonItemCannotBeCastToTargetType(DataError,
+ code='2203G', name='SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE'):
+ pass
+
+class FloatingPointException(DataError,
+ code='22P01', name='FLOATING_POINT_EXCEPTION'):
+ pass
+
+class InvalidTextRepresentation(DataError,
+ code='22P02', name='INVALID_TEXT_REPRESENTATION'):
+ pass
+
+class InvalidBinaryRepresentation(DataError,
+ code='22P03', name='INVALID_BINARY_REPRESENTATION'):
+ pass
+
+class BadCopyFileFormat(DataError,
+ code='22P04', name='BAD_COPY_FILE_FORMAT'):
+ pass
+
+class UntranslatableCharacter(DataError,
+ code='22P05', name='UNTRANSLATABLE_CHARACTER'):
+ pass
+
+class NonstandardUseOfEscapeCharacter(DataError,
+ code='22P06', name='NONSTANDARD_USE_OF_ESCAPE_CHARACTER'):
+ pass
+
+
+# Class 23 - Integrity Constraint Violation
+
+class IntegrityConstraintViolation(IntegrityError,
+ code='23000', name='INTEGRITY_CONSTRAINT_VIOLATION'):
+ pass
+
+class RestrictViolation(IntegrityError,
+ code='23001', name='RESTRICT_VIOLATION'):
+ pass
+
+class NotNullViolation(IntegrityError,
+ code='23502', name='NOT_NULL_VIOLATION'):
+ pass
+
+class ForeignKeyViolation(IntegrityError,
+ code='23503', name='FOREIGN_KEY_VIOLATION'):
+ pass
+
+class UniqueViolation(IntegrityError,
+ code='23505', name='UNIQUE_VIOLATION'):
+ pass
+
+class CheckViolation(IntegrityError,
+ code='23514', name='CHECK_VIOLATION'):
+ pass
+
+class ExclusionViolation(IntegrityError,
+ code='23P01', name='EXCLUSION_VIOLATION'):
+ pass
+
+
+# Class 24 - Invalid Cursor State
+
+class InvalidCursorState(InternalError,
+ code='24000', name='INVALID_CURSOR_STATE'):
+ pass
+
+
+# Class 25 - Invalid Transaction State
+
+class InvalidTransactionState(InternalError,
+ code='25000', name='INVALID_TRANSACTION_STATE'):
+ pass
+
+class ActiveSqlTransaction(InternalError,
+ code='25001', name='ACTIVE_SQL_TRANSACTION'):
+ pass
+
+class BranchTransactionAlreadyActive(InternalError,
+ code='25002', name='BRANCH_TRANSACTION_ALREADY_ACTIVE'):
+ pass
+
+class InappropriateAccessModeForBranchTransaction(InternalError,
+ code='25003', name='INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class InappropriateIsolationLevelForBranchTransaction(InternalError,
+ code='25004', name='INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class NoActiveSqlTransactionForBranchTransaction(InternalError,
+ code='25005', name='NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class ReadOnlySqlTransaction(InternalError,
+ code='25006', name='READ_ONLY_SQL_TRANSACTION'):
+ pass
+
+class SchemaAndDataStatementMixingNotSupported(InternalError,
+ code='25007', name='SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED'):
+ pass
+
+class HeldCursorRequiresSameIsolationLevel(InternalError,
+ code='25008', name='HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL'):
+ pass
+
+class NoActiveSqlTransaction(InternalError,
+ code='25P01', name='NO_ACTIVE_SQL_TRANSACTION'):
+ pass
+
+class InFailedSqlTransaction(InternalError,
+ code='25P02', name='IN_FAILED_SQL_TRANSACTION'):
+ pass
+
+class IdleInTransactionSessionTimeout(InternalError,
+ code='25P03', name='IDLE_IN_TRANSACTION_SESSION_TIMEOUT'):
+ pass
+
+
+# Class 26 - Invalid SQL Statement Name
+
+class InvalidSqlStatementName(ProgrammingError,
+ code='26000', name='INVALID_SQL_STATEMENT_NAME'):
+ pass
+
+
+# Class 27 - Triggered Data Change Violation
+
+class TriggeredDataChangeViolation(OperationalError,
+ code='27000', name='TRIGGERED_DATA_CHANGE_VIOLATION'):
+ pass
+
+
+# Class 28 - Invalid Authorization Specification
+
+class InvalidAuthorizationSpecification(OperationalError,
+ code='28000', name='INVALID_AUTHORIZATION_SPECIFICATION'):
+ pass
+
+class InvalidPassword(OperationalError,
+ code='28P01', name='INVALID_PASSWORD'):
+ pass
+
+
+# Class 2B - Dependent Privilege Descriptors Still Exist
+
+class DependentPrivilegeDescriptorsStillExist(InternalError,
+ code='2B000', name='DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST'):
+ pass
+
+class DependentObjectsStillExist(InternalError,
+ code='2BP01', name='DEPENDENT_OBJECTS_STILL_EXIST'):
+ pass
+
+
+# Class 2D - Invalid Transaction Termination
+
+class InvalidTransactionTermination(InternalError,
+ code='2D000', name='INVALID_TRANSACTION_TERMINATION'):
+ pass
+
+
+# Class 2F - SQL Routine Exception
+
+class SqlRoutineException(OperationalError,
+ code='2F000', name='SQL_ROUTINE_EXCEPTION'):
+ pass
+
+class ModifyingSqlDataNotPermitted(OperationalError,
+ code='2F002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class ProhibitedSqlStatementAttempted(OperationalError,
+ code='2F003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'):
+ pass
+
+class ReadingSqlDataNotPermitted(OperationalError,
+ code='2F004', name='READING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class FunctionExecutedNoReturnStatement(OperationalError,
+ code='2F005', name='FUNCTION_EXECUTED_NO_RETURN_STATEMENT'):
+ pass
+
+
+# Class 34 - Invalid Cursor Name
+
+class InvalidCursorName(ProgrammingError,
+ code='34000', name='INVALID_CURSOR_NAME'):
+ pass
+
+
+# Class 38 - External Routine Exception
+
+class ExternalRoutineException(OperationalError,
+ code='38000', name='EXTERNAL_ROUTINE_EXCEPTION'):
+ pass
+
+class ContainingSqlNotPermitted(OperationalError,
+ code='38001', name='CONTAINING_SQL_NOT_PERMITTED'):
+ pass
+
+class ModifyingSqlDataNotPermittedExt(OperationalError,
+ code='38002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class ProhibitedSqlStatementAttemptedExt(OperationalError,
+ code='38003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'):
+ pass
+
+class ReadingSqlDataNotPermittedExt(OperationalError,
+ code='38004', name='READING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+
+# Class 39 - External Routine Invocation Exception
+
+class ExternalRoutineInvocationException(OperationalError,
+ code='39000', name='EXTERNAL_ROUTINE_INVOCATION_EXCEPTION'):
+ pass
+
+class InvalidSqlstateReturned(OperationalError,
+ code='39001', name='INVALID_SQLSTATE_RETURNED'):
+ pass
+
+class NullValueNotAllowedExt(OperationalError,
+ code='39004', name='NULL_VALUE_NOT_ALLOWED'):
+ pass
+
+class TriggerProtocolViolated(OperationalError,
+ code='39P01', name='TRIGGER_PROTOCOL_VIOLATED'):
+ pass
+
+class SrfProtocolViolated(OperationalError,
+ code='39P02', name='SRF_PROTOCOL_VIOLATED'):
+ pass
+
+class EventTriggerProtocolViolated(OperationalError,
+ code='39P03', name='EVENT_TRIGGER_PROTOCOL_VIOLATED'):
+ pass
+
+
+# Class 3B - Savepoint Exception
+
+class SavepointException(OperationalError,
+ code='3B000', name='SAVEPOINT_EXCEPTION'):
+ pass
+
+class InvalidSavepointSpecification(OperationalError,
+ code='3B001', name='INVALID_SAVEPOINT_SPECIFICATION'):
+ pass
+
+
+# Class 3D - Invalid Catalog Name
+
+class InvalidCatalogName(ProgrammingError,
+ code='3D000', name='INVALID_CATALOG_NAME'):
+ pass
+
+
+# Class 3F - Invalid Schema Name
+
+class InvalidSchemaName(ProgrammingError,
+ code='3F000', name='INVALID_SCHEMA_NAME'):
+ pass
+
+
+# Class 40 - Transaction Rollback
+
+class TransactionRollback(OperationalError,
+ code='40000', name='TRANSACTION_ROLLBACK'):
+ pass
+
+class SerializationFailure(OperationalError,
+ code='40001', name='SERIALIZATION_FAILURE'):
+ pass
+
+class TransactionIntegrityConstraintViolation(OperationalError,
+ code='40002', name='TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION'):
+ pass
+
+class StatementCompletionUnknown(OperationalError,
+ code='40003', name='STATEMENT_COMPLETION_UNKNOWN'):
+ pass
+
+class DeadlockDetected(OperationalError,
+ code='40P01', name='DEADLOCK_DETECTED'):
+ pass
+
+
+# Class 42 - Syntax Error or Access Rule Violation
+
+class SyntaxErrorOrAccessRuleViolation(ProgrammingError,
+ code='42000', name='SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION'):
+ pass
+
+class InsufficientPrivilege(ProgrammingError,
+ code='42501', name='INSUFFICIENT_PRIVILEGE'):
+ pass
+
+class SyntaxError(ProgrammingError,
+ code='42601', name='SYNTAX_ERROR'):
+ pass
+
+class InvalidName(ProgrammingError,
+ code='42602', name='INVALID_NAME'):
+ pass
+
+class InvalidColumnDefinition(ProgrammingError,
+ code='42611', name='INVALID_COLUMN_DEFINITION'):
+ pass
+
+class NameTooLong(ProgrammingError,
+ code='42622', name='NAME_TOO_LONG'):
+ pass
+
+class DuplicateColumn(ProgrammingError,
+ code='42701', name='DUPLICATE_COLUMN'):
+ pass
+
+class AmbiguousColumn(ProgrammingError,
+ code='42702', name='AMBIGUOUS_COLUMN'):
+ pass
+
+class UndefinedColumn(ProgrammingError,
+ code='42703', name='UNDEFINED_COLUMN'):
+ pass
+
+class UndefinedObject(ProgrammingError,
+ code='42704', name='UNDEFINED_OBJECT'):
+ pass
+
+class DuplicateObject(ProgrammingError,
+ code='42710', name='DUPLICATE_OBJECT'):
+ pass
+
+class DuplicateAlias(ProgrammingError,
+ code='42712', name='DUPLICATE_ALIAS'):
+ pass
+
+class DuplicateFunction(ProgrammingError,
+ code='42723', name='DUPLICATE_FUNCTION'):
+ pass
+
+class AmbiguousFunction(ProgrammingError,
+ code='42725', name='AMBIGUOUS_FUNCTION'):
+ pass
+
+class GroupingError(ProgrammingError,
+ code='42803', name='GROUPING_ERROR'):
+ pass
+
+class DatatypeMismatch(ProgrammingError,
+ code='42804', name='DATATYPE_MISMATCH'):
+ pass
+
+class WrongObjectType(ProgrammingError,
+ code='42809', name='WRONG_OBJECT_TYPE'):
+ pass
+
+class InvalidForeignKey(ProgrammingError,
+ code='42830', name='INVALID_FOREIGN_KEY'):
+ pass
+
+class CannotCoerce(ProgrammingError,
+ code='42846', name='CANNOT_COERCE'):
+ pass
+
+class UndefinedFunction(ProgrammingError,
+ code='42883', name='UNDEFINED_FUNCTION'):
+ pass
+
+class GeneratedAlways(ProgrammingError,
+ code='428C9', name='GENERATED_ALWAYS'):
+ pass
+
+class ReservedName(ProgrammingError,
+ code='42939', name='RESERVED_NAME'):
+ pass
+
+class UndefinedTable(ProgrammingError,
+ code='42P01', name='UNDEFINED_TABLE'):
+ pass
+
+class UndefinedParameter(ProgrammingError,
+ code='42P02', name='UNDEFINED_PARAMETER'):
+ pass
+
+class DuplicateCursor(ProgrammingError,
+ code='42P03', name='DUPLICATE_CURSOR'):
+ pass
+
+class DuplicateDatabase(ProgrammingError,
+ code='42P04', name='DUPLICATE_DATABASE'):
+ pass
+
+class DuplicatePreparedStatement(ProgrammingError,
+ code='42P05', name='DUPLICATE_PREPARED_STATEMENT'):
+ pass
+
+class DuplicateSchema(ProgrammingError,
+ code='42P06', name='DUPLICATE_SCHEMA'):
+ pass
+
+class DuplicateTable(ProgrammingError,
+ code='42P07', name='DUPLICATE_TABLE'):
+ pass
+
+class AmbiguousParameter(ProgrammingError,
+ code='42P08', name='AMBIGUOUS_PARAMETER'):
+ pass
+
+class AmbiguousAlias(ProgrammingError,
+ code='42P09', name='AMBIGUOUS_ALIAS'):
+ pass
+
+class InvalidColumnReference(ProgrammingError,
+ code='42P10', name='INVALID_COLUMN_REFERENCE'):
+ pass
+
+class InvalidCursorDefinition(ProgrammingError,
+ code='42P11', name='INVALID_CURSOR_DEFINITION'):
+ pass
+
+class InvalidDatabaseDefinition(ProgrammingError,
+ code='42P12', name='INVALID_DATABASE_DEFINITION'):
+ pass
+
+class InvalidFunctionDefinition(ProgrammingError,
+ code='42P13', name='INVALID_FUNCTION_DEFINITION'):
+ pass
+
+class InvalidPreparedStatementDefinition(ProgrammingError,
+ code='42P14', name='INVALID_PREPARED_STATEMENT_DEFINITION'):
+ pass
+
+class InvalidSchemaDefinition(ProgrammingError,
+ code='42P15', name='INVALID_SCHEMA_DEFINITION'):
+ pass
+
+class InvalidTableDefinition(ProgrammingError,
+ code='42P16', name='INVALID_TABLE_DEFINITION'):
+ pass
+
+class InvalidObjectDefinition(ProgrammingError,
+ code='42P17', name='INVALID_OBJECT_DEFINITION'):
+ pass
+
+class IndeterminateDatatype(ProgrammingError,
+ code='42P18', name='INDETERMINATE_DATATYPE'):
+ pass
+
+class InvalidRecursion(ProgrammingError,
+ code='42P19', name='INVALID_RECURSION'):
+ pass
+
+class WindowingError(ProgrammingError,
+ code='42P20', name='WINDOWING_ERROR'):
+ pass
+
+class CollationMismatch(ProgrammingError,
+ code='42P21', name='COLLATION_MISMATCH'):
+ pass
+
+class IndeterminateCollation(ProgrammingError,
+ code='42P22', name='INDETERMINATE_COLLATION'):
+ pass
+
+
+# Class 44 - WITH CHECK OPTION Violation
+
+class WithCheckOptionViolation(ProgrammingError,
+ code='44000', name='WITH_CHECK_OPTION_VIOLATION'):
+ pass
+
+
+# Class 53 - Insufficient Resources
+
+class InsufficientResources(OperationalError,
+ code='53000', name='INSUFFICIENT_RESOURCES'):
+ pass
+
+class DiskFull(OperationalError,
+ code='53100', name='DISK_FULL'):
+ pass
+
+class OutOfMemory(OperationalError,
+ code='53200', name='OUT_OF_MEMORY'):
+ pass
+
+class TooManyConnections(OperationalError,
+ code='53300', name='TOO_MANY_CONNECTIONS'):
+ pass
+
+class ConfigurationLimitExceeded(OperationalError,
+ code='53400', name='CONFIGURATION_LIMIT_EXCEEDED'):
+ pass
+
+
+# Class 54 - Program Limit Exceeded
+
+class ProgramLimitExceeded(OperationalError,
+ code='54000', name='PROGRAM_LIMIT_EXCEEDED'):
+ pass
+
+class StatementTooComplex(OperationalError,
+ code='54001', name='STATEMENT_TOO_COMPLEX'):
+ pass
+
+class TooManyColumns(OperationalError,
+ code='54011', name='TOO_MANY_COLUMNS'):
+ pass
+
+class TooManyArguments(OperationalError,
+ code='54023', name='TOO_MANY_ARGUMENTS'):
+ pass
+
+
+# Class 55 - Object Not In Prerequisite State
+
+class ObjectNotInPrerequisiteState(OperationalError,
+ code='55000', name='OBJECT_NOT_IN_PREREQUISITE_STATE'):
+ pass
+
+class ObjectInUse(OperationalError,
+ code='55006', name='OBJECT_IN_USE'):
+ pass
+
+class CantChangeRuntimeParam(OperationalError,
+ code='55P02', name='CANT_CHANGE_RUNTIME_PARAM'):
+ pass
+
+class LockNotAvailable(OperationalError,
+ code='55P03', name='LOCK_NOT_AVAILABLE'):
+ pass
+
+class UnsafeNewEnumValueUsage(OperationalError,
+ code='55P04', name='UNSAFE_NEW_ENUM_VALUE_USAGE'):
+ pass
+
+
+# Class 57 - Operator Intervention
+
+class OperatorIntervention(OperationalError,
+ code='57000', name='OPERATOR_INTERVENTION'):
+ pass
+
+class QueryCanceled(OperationalError,
+ code='57014', name='QUERY_CANCELED'):
+ pass
+
+class AdminShutdown(OperationalError,
+ code='57P01', name='ADMIN_SHUTDOWN'):
+ pass
+
+class CrashShutdown(OperationalError,
+ code='57P02', name='CRASH_SHUTDOWN'):
+ pass
+
+class CannotConnectNow(OperationalError,
+ code='57P03', name='CANNOT_CONNECT_NOW'):
+ pass
+
+class DatabaseDropped(OperationalError,
+ code='57P04', name='DATABASE_DROPPED'):
+ pass
+
+class IdleSessionTimeout(OperationalError,
+ code='57P05', name='IDLE_SESSION_TIMEOUT'):
+ pass
+
+
+# Class 58 - System Error (errors external to PostgreSQL itself)
+
+class SystemError(OperationalError,
+ code='58000', name='SYSTEM_ERROR'):
+ pass
+
+class IoError(OperationalError,
+ code='58030', name='IO_ERROR'):
+ pass
+
+class UndefinedFile(OperationalError,
+ code='58P01', name='UNDEFINED_FILE'):
+ pass
+
+class DuplicateFile(OperationalError,
+ code='58P02', name='DUPLICATE_FILE'):
+ pass
+
+
+# Class 72 - Snapshot Failure
+
+class SnapshotTooOld(DatabaseError,
+ code='72000', name='SNAPSHOT_TOO_OLD'):
+ pass
+
+
+# Class F0 - Configuration File Error
+
+class ConfigFileError(OperationalError,
+ code='F0000', name='CONFIG_FILE_ERROR'):
+ pass
+
+class LockFileExists(OperationalError,
+ code='F0001', name='LOCK_FILE_EXISTS'):
+ pass
+
+
+# Class HV - Foreign Data Wrapper Error (SQL/MED)
+
+class FdwError(OperationalError,
+ code='HV000', name='FDW_ERROR'):
+ pass
+
+class FdwOutOfMemory(OperationalError,
+ code='HV001', name='FDW_OUT_OF_MEMORY'):
+ pass
+
+class FdwDynamicParameterValueNeeded(OperationalError,
+ code='HV002', name='FDW_DYNAMIC_PARAMETER_VALUE_NEEDED'):
+ pass
+
+class FdwInvalidDataType(OperationalError,
+ code='HV004', name='FDW_INVALID_DATA_TYPE'):
+ pass
+
+class FdwColumnNameNotFound(OperationalError,
+ code='HV005', name='FDW_COLUMN_NAME_NOT_FOUND'):
+ pass
+
+class FdwInvalidDataTypeDescriptors(OperationalError,
+ code='HV006', name='FDW_INVALID_DATA_TYPE_DESCRIPTORS'):
+ pass
+
+class FdwInvalidColumnName(OperationalError,
+ code='HV007', name='FDW_INVALID_COLUMN_NAME'):
+ pass
+
+class FdwInvalidColumnNumber(OperationalError,
+ code='HV008', name='FDW_INVALID_COLUMN_NUMBER'):
+ pass
+
+class FdwInvalidUseOfNullPointer(OperationalError,
+ code='HV009', name='FDW_INVALID_USE_OF_NULL_POINTER'):
+ pass
+
+class FdwInvalidStringFormat(OperationalError,
+ code='HV00A', name='FDW_INVALID_STRING_FORMAT'):
+ pass
+
+class FdwInvalidHandle(OperationalError,
+ code='HV00B', name='FDW_INVALID_HANDLE'):
+ pass
+
+class FdwInvalidOptionIndex(OperationalError,
+ code='HV00C', name='FDW_INVALID_OPTION_INDEX'):
+ pass
+
+class FdwInvalidOptionName(OperationalError,
+ code='HV00D', name='FDW_INVALID_OPTION_NAME'):
+ pass
+
+class FdwOptionNameNotFound(OperationalError,
+ code='HV00J', name='FDW_OPTION_NAME_NOT_FOUND'):
+ pass
+
+class FdwReplyHandle(OperationalError,
+ code='HV00K', name='FDW_REPLY_HANDLE'):
+ pass
+
+class FdwUnableToCreateExecution(OperationalError,
+ code='HV00L', name='FDW_UNABLE_TO_CREATE_EXECUTION'):
+ pass
+
+class FdwUnableToCreateReply(OperationalError,
+ code='HV00M', name='FDW_UNABLE_TO_CREATE_REPLY'):
+ pass
+
+class FdwUnableToEstablishConnection(OperationalError,
+ code='HV00N', name='FDW_UNABLE_TO_ESTABLISH_CONNECTION'):
+ pass
+
+class FdwNoSchemas(OperationalError,
+ code='HV00P', name='FDW_NO_SCHEMAS'):
+ pass
+
+class FdwSchemaNotFound(OperationalError,
+ code='HV00Q', name='FDW_SCHEMA_NOT_FOUND'):
+ pass
+
+class FdwTableNotFound(OperationalError,
+ code='HV00R', name='FDW_TABLE_NOT_FOUND'):
+ pass
+
+class FdwFunctionSequenceError(OperationalError,
+ code='HV010', name='FDW_FUNCTION_SEQUENCE_ERROR'):
+ pass
+
+class FdwTooManyHandles(OperationalError,
+ code='HV014', name='FDW_TOO_MANY_HANDLES'):
+ pass
+
+class FdwInconsistentDescriptorInformation(OperationalError,
+ code='HV021', name='FDW_INCONSISTENT_DESCRIPTOR_INFORMATION'):
+ pass
+
+class FdwInvalidAttributeValue(OperationalError,
+ code='HV024', name='FDW_INVALID_ATTRIBUTE_VALUE'):
+ pass
+
+class FdwInvalidStringLengthOrBufferLength(OperationalError,
+ code='HV090', name='FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH'):
+ pass
+
+class FdwInvalidDescriptorFieldIdentifier(OperationalError,
+ code='HV091', name='FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER'):
+ pass
+
+
+# Class P0 - PL/pgSQL Error
+
+class PlpgsqlError(ProgrammingError,
+ code='P0000', name='PLPGSQL_ERROR'):
+ pass
+
+class RaiseException(ProgrammingError,
+ code='P0001', name='RAISE_EXCEPTION'):
+ pass
+
+class NoDataFound(ProgrammingError,
+ code='P0002', name='NO_DATA_FOUND'):
+ pass
+
+class TooManyRows(ProgrammingError,
+ code='P0003', name='TOO_MANY_ROWS'):
+ pass
+
+class AssertFailure(ProgrammingError,
+ code='P0004', name='ASSERT_FAILURE'):
+ pass
+
+
+# Class XX - Internal Error
+
+class InternalError_(InternalError,
+ code='XX000', name='INTERNAL_ERROR'):
+ pass
+
+class DataCorrupted(InternalError,
+ code='XX001', name='DATA_CORRUPTED'):
+ pass
+
+class IndexCorrupted(InternalError,
+ code='XX002', name='INDEX_CORRUPTED'):
+ pass
+
+
+# autogenerated: end
+# fmt: on
diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py
new file mode 100644
index 0000000..584fe47
--- /dev/null
+++ b/psycopg/psycopg/generators.py
@@ -0,0 +1,320 @@
+"""
+Generators implementing communication protocols with the libpq
+
+Certain operations (connection, querying) are an interleave of libpq calls and
+waiting for the socket to be ready. This module contains the code to execute
+the operations, yielding a polling state whenever there is to wait. The
+functions in the `waiting` module are the ones who wait more or less
+cooperatively for the socket to be ready and make these generators continue.
+
+All these generators yield pairs (fileno, `Wait`) whenever an operation would
+block. The generator can be restarted sending the appropriate `Ready` state
+when the file descriptor is ready.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import List, Optional, Union
+
+from . import pq
+from . import errors as e
+from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
+from .pq.abc import PGconn, PGresult
+from .waiting import Wait, Ready
+from ._compat import Deque
+from ._cmodule import _psycopg
+from ._encodings import pgconn_encoding, conninfo_encoding
+
+OK = pq.ConnStatus.OK
+BAD = pq.ConnStatus.BAD
+
+POLL_OK = pq.PollingStatus.OK
+POLL_READING = pq.PollingStatus.READING
+POLL_WRITING = pq.PollingStatus.WRITING
+POLL_FAILED = pq.PollingStatus.FAILED
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC
+
+WAIT_R = Wait.R
+WAIT_W = Wait.W
+WAIT_RW = Wait.RW
+READY_R = Ready.R
+READY_W = Ready.W
+READY_RW = Ready.RW
+
+logger = logging.getLogger(__name__)
+
+
+def _connect(conninfo: str) -> PQGenConn[PGconn]:
+ """
+ Generator to create a database connection without blocking.
+
+ """
+ conn = pq.PGconn.connect_start(conninfo.encode())
+ while True:
+ if conn.status == BAD:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection is bad: {pq.error_message(conn, encoding=encoding)}",
+ pgconn=conn,
+ )
+
+ status = conn.connect_poll()
+ if status == POLL_OK:
+ break
+ elif status == POLL_READING:
+ yield conn.socket, WAIT_R
+ elif status == POLL_WRITING:
+ yield conn.socket, WAIT_W
+ elif status == POLL_FAILED:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection failed: {pq.error_message(conn, encoding=encoding)}",
+ pgconn=conn,
+ )
+ else:
+ raise e.InternalError(f"unexpected poll status: {status}", pgconn=conn)
+
+ conn.nonblocking = 1
+ return conn
+
+
+def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
+ """
+ Generator sending a query and returning results without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ yield from _send(pgconn)
+ rv = yield from _fetch_many(pgconn)
+ return rv
+
+
+def _send(pgconn: PGconn) -> PQGen[None]:
+ """
+ Generator to send a query to the server without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ After this generator has finished you may want to cycle using `fetch()`
+ to retrieve the results available.
+ """
+ while True:
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ ready = yield WAIT_RW
+ if ready & READY_R:
+ # This call may read notifies: they will be saved in the
+ # PGconn buffer and passed to Python later, in `fetch()`.
+ pgconn.consume_input()
+
+
+def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
+ """
+ Generator retrieving results from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ results: List[PGresult] = []
+ while True:
+ res = yield from _fetch(pgconn)
+ if not res:
+ break
+
+ results.append(res)
+ status = res.status
+ if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+ # After entering copy mode the libpq will create a phony result
+ # for every request so let's break the endless loop.
+ break
+
+ if status == PIPELINE_SYNC:
+ # PIPELINE_SYNC is not followed by a NULL, but we return it alone
+ # similarly to other result sets.
+ assert len(results) == 1, results
+ break
+
+ return results
+
+
+def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
+ """
+ Generator retrieving a single result from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return a result from the database (whether success or error).
+ """
+ if pgconn.is_busy():
+ yield WAIT_R
+ while True:
+ pgconn.consume_input()
+ if not pgconn.is_busy():
+ break
+ yield WAIT_R
+
+ _consume_notifies(pgconn)
+
+ return pgconn.get_result()
+
+
+def _pipeline_communicate(
+ pgconn: PGconn, commands: Deque[PipelineCommand]
+) -> PQGen[List[List[PGresult]]]:
+ """Generator to send queries from a connection in pipeline mode while also
+ receiving results.
+
+ Return a list results, including single PIPELINE_SYNC elements.
+ """
+ results = []
+
+ while True:
+ ready = yield WAIT_RW
+
+ if ready & READY_R:
+ pgconn.consume_input()
+ _consume_notifies(pgconn)
+
+ res: List[PGresult] = []
+ while not pgconn.is_busy():
+ r = pgconn.get_result()
+ if r is None:
+ if not res:
+ break
+ results.append(res)
+ res = []
+ elif r.status == PIPELINE_SYNC:
+ assert not res
+ results.append([r])
+ else:
+ res.append(r)
+
+ if ready & READY_W:
+ pgconn.flush()
+ if not commands:
+ break
+ commands.popleft()()
+
+ return results
+
+
+def _consume_notifies(pgconn: PGconn) -> None:
+ # Consume notifies
+ while True:
+ n = pgconn.notifies()
+ if not n:
+ break
+ if pgconn.notify_handler:
+ pgconn.notify_handler(n)
+
+
+def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
+ yield WAIT_R
+ pgconn.consume_input()
+
+ ns = []
+ while True:
+ n = pgconn.notifies()
+ if n:
+ ns.append(n)
+ else:
+ break
+
+ return ns
+
+
+def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
+ while True:
+ nbytes, data = pgconn.get_copy_data(1)
+ if nbytes != 0:
+ break
+
+ # would block
+ yield WAIT_R
+ pgconn.consume_input()
+
+ if nbytes > 0:
+ # some data
+ return data
+
+ # Retrieve the final result of copy
+ results = yield from _fetch_many(pgconn)
+ if len(results) > 1:
+ # TODO: too brutal? Copy worked.
+ raise e.ProgrammingError("you cannot mix COPY with other operations")
+ result = results[0]
+ if result.status != COMMAND_OK:
+ encoding = pgconn_encoding(pgconn)
+ raise e.error_from_result(result, encoding=encoding)
+
+ return result
+
+
+def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
+ # Retry enqueuing data until successful.
+ #
+ # WARNING! This can cause an infinite loop if the buffer is too large. (see
+ # ticket #255). We avoid it in the Copy object by splitting a large buffer
+ # into smaller ones. We prefer to do it there instead of here in order to
+ # do it upstream the queue decoupling the writer task from the producer one.
+ while pgconn.put_copy_data(buffer) == 0:
+ yield WAIT_W
+
+
+def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
+ # Retry enqueuing end copy message until successful
+ while pgconn.put_copy_end(error) == 0:
+ yield WAIT_W
+
+ # Repeat until it the message is flushed to the server
+ while True:
+ yield WAIT_W
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ # Retrieve the final result of copy
+ (result,) = yield from _fetch_many(pgconn)
+ if result.status != COMMAND_OK:
+ encoding = pgconn_encoding(pgconn)
+ raise e.error_from_result(result, encoding=encoding)
+
+ return result
+
+
+# Override functions with fast versions if available
+if _psycopg:
+ connect = _psycopg.connect
+ execute = _psycopg.execute
+ send = _psycopg.send
+ fetch_many = _psycopg.fetch_many
+ fetch = _psycopg.fetch
+ pipeline_communicate = _psycopg.pipeline_communicate
+
+else:
+ connect = _connect
+ execute = _execute
+ send = _send
+ fetch_many = _fetch_many
+ fetch = _fetch
+ pipeline_communicate = _pipeline_communicate
diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py
new file mode 100644
index 0000000..792a9c8
--- /dev/null
+++ b/psycopg/psycopg/postgres.py
@@ -0,0 +1,125 @@
+"""
+Types configuration specific to PostgreSQL.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry
+from .abc import AdaptContext
+from ._adapters_map import AdaptersMap
+
+# Global objects with PostgreSQL builtins and globally registered user types.
+types = TypesRegistry()
+
+# Global adapter maps with PostgreSQL types configuration
+adapters = AdaptersMap(types=types)
+
+# Use tools/update_oids.py to update this data.
+for t in [
+ TypeInfo('"char"', 18, 1002),
+ # autogenerated: start
+ # Generated from PostgreSQL 15.0
+ TypeInfo("aclitem", 1033, 1034),
+ TypeInfo("bit", 1560, 1561),
+ TypeInfo("bool", 16, 1000, regtype="boolean"),
+ TypeInfo("box", 603, 1020, delimiter=";"),
+ TypeInfo("bpchar", 1042, 1014, regtype="character"),
+ TypeInfo("bytea", 17, 1001),
+ TypeInfo("cid", 29, 1012),
+ TypeInfo("cidr", 650, 651),
+ TypeInfo("circle", 718, 719),
+ TypeInfo("date", 1082, 1182),
+ TypeInfo("float4", 700, 1021, regtype="real"),
+ TypeInfo("float8", 701, 1022, regtype="double precision"),
+ TypeInfo("gtsvector", 3642, 3644),
+ TypeInfo("inet", 869, 1041),
+ TypeInfo("int2", 21, 1005, regtype="smallint"),
+ TypeInfo("int2vector", 22, 1006),
+ TypeInfo("int4", 23, 1007, regtype="integer"),
+ TypeInfo("int8", 20, 1016, regtype="bigint"),
+ TypeInfo("interval", 1186, 1187),
+ TypeInfo("json", 114, 199),
+ TypeInfo("jsonb", 3802, 3807),
+ TypeInfo("jsonpath", 4072, 4073),
+ TypeInfo("line", 628, 629),
+ TypeInfo("lseg", 601, 1018),
+ TypeInfo("macaddr", 829, 1040),
+ TypeInfo("macaddr8", 774, 775),
+ TypeInfo("money", 790, 791),
+ TypeInfo("name", 19, 1003),
+ TypeInfo("numeric", 1700, 1231),
+ TypeInfo("oid", 26, 1028),
+ TypeInfo("oidvector", 30, 1013),
+ TypeInfo("path", 602, 1019),
+ TypeInfo("pg_lsn", 3220, 3221),
+ TypeInfo("point", 600, 1017),
+ TypeInfo("polygon", 604, 1027),
+ TypeInfo("record", 2249, 2287),
+ TypeInfo("refcursor", 1790, 2201),
+ TypeInfo("regclass", 2205, 2210),
+ TypeInfo("regcollation", 4191, 4192),
+ TypeInfo("regconfig", 3734, 3735),
+ TypeInfo("regdictionary", 3769, 3770),
+ TypeInfo("regnamespace", 4089, 4090),
+ TypeInfo("regoper", 2203, 2208),
+ TypeInfo("regoperator", 2204, 2209),
+ TypeInfo("regproc", 24, 1008),
+ TypeInfo("regprocedure", 2202, 2207),
+ TypeInfo("regrole", 4096, 4097),
+ TypeInfo("regtype", 2206, 2211),
+ TypeInfo("text", 25, 1009),
+ TypeInfo("tid", 27, 1010),
+ TypeInfo("time", 1083, 1183, regtype="time without time zone"),
+ TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
+ TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
+ TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
+ TypeInfo("tsquery", 3615, 3645),
+ TypeInfo("tsvector", 3614, 3643),
+ TypeInfo("txid_snapshot", 2970, 2949),
+ TypeInfo("uuid", 2950, 2951),
+ TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
+ TypeInfo("varchar", 1043, 1015, regtype="character varying"),
+ TypeInfo("xid", 28, 1011),
+ TypeInfo("xid8", 5069, 271),
+ TypeInfo("xml", 142, 143),
+ RangeInfo("daterange", 3912, 3913, subtype_oid=1082),
+ RangeInfo("int4range", 3904, 3905, subtype_oid=23),
+ RangeInfo("int8range", 3926, 3927, subtype_oid=20),
+ RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
+ RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
+ RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
+ MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082),
+ MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23),
+ MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20),
+ MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700),
+ MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114),
+ MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184),
+ # autogenerated: end
+]:
+ types.add(t)
+
+
+# A few oids used a bit everywhere
+INVALID_OID = 0
+TEXT_OID = types["text"].oid
+TEXT_ARRAY_OID = types["text"].array_oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+
+ from .types import array, bool, composite, datetime, enum, json, multirange
+ from .types import net, none, numeric, range, string, uuid
+
+ array.register_default_adapters(context)
+ bool.register_default_adapters(context)
+ composite.register_default_adapters(context)
+ datetime.register_default_adapters(context)
+ enum.register_default_adapters(context)
+ json.register_default_adapters(context)
+ multirange.register_default_adapters(context)
+ net.register_default_adapters(context)
+ none.register_default_adapters(context)
+ numeric.register_default_adapters(context)
+ range.register_default_adapters(context)
+ string.register_default_adapters(context)
+ uuid.register_default_adapters(context)
diff --git a/psycopg/psycopg/pq/__init__.py b/psycopg/psycopg/pq/__init__.py
new file mode 100644
index 0000000..d5180b1
--- /dev/null
+++ b/psycopg/psycopg/pq/__init__.py
@@ -0,0 +1,133 @@
+"""
+psycopg libpq wrapper
+
+This package exposes the libpq functionalities as Python objects and functions.
+
+The real implementation (the binding to the C library) is
+implementation-dependant but all the implementations share the same interface.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import logging
+from typing import Callable, List, Type
+
+from . import abc
+from .misc import ConninfoOption, PGnotify, PGresAttDesc
+from .misc import error_message
+from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace
+from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus
+
+logger = logging.getLogger(__name__)
+
+__impl__: str
+"""The currently loaded implementation of the `!psycopg.pq` package.
+
+Possible values include ``python``, ``c``, ``binary``.
+"""
+
+__build_version__: int
+"""The libpq version the C package was built with.
+
+A number in the same format of `~psycopg.ConnectionInfo.server_version`
+representing the libpq used to build the speedup module (``c``, ``binary``) if
+available.
+
+Certain features might not be available if the built version is too old.
+"""
+
+version: Callable[[], int]
+PGconn: Type[abc.PGconn]
+PGresult: Type[abc.PGresult]
+Conninfo: Type[abc.Conninfo]
+Escaping: Type[abc.Escaping]
+PGcancel: Type[abc.PGcancel]
+
+
+def import_from_libpq() -> None:
+ """
+ Import pq objects implementation from the best libpq wrapper available.
+
+ If an implementation is requested try to import only it, otherwise
+ try to import the best implementation available.
+ """
+ # import these names into the module on success as side effect
+ global __impl__, version, __build_version__
+ global PGconn, PGresult, Conninfo, Escaping, PGcancel
+
+ impl = os.environ.get("PSYCOPG_IMPL", "").lower()
+ module = None
+ attempts: List[str] = []
+
+ def handle_error(name: str, e: Exception) -> None:
+ if not impl:
+ msg = f"couldn't import psycopg '{name}' implementation: {e}"
+ logger.debug(msg)
+ attempts.append(msg)
+ else:
+ msg = f"couldn't import requested psycopg '{name}' implementation: {e}"
+ raise ImportError(msg) from e
+
+ # The best implementation: fast but requires the system libpq installed
+ if not impl or impl == "c":
+ try:
+ from psycopg_c import pq as module # type: ignore
+ except Exception as e:
+ handle_error("c", e)
+
+ # Second best implementation: fast and stand-alone
+ if not module and (not impl or impl == "binary"):
+ try:
+ from psycopg_binary import pq as module # type: ignore
+ except Exception as e:
+ handle_error("binary", e)
+
+ # Pure Python implementation, slow and requires the system libpq installed.
+ if not module and (not impl or impl == "python"):
+ try:
+ from . import pq_ctypes as module # type: ignore[no-redef]
+ except Exception as e:
+ handle_error("python", e)
+
+ if module:
+ __impl__ = module.__impl__
+ version = module.version
+ PGconn = module.PGconn
+ PGresult = module.PGresult
+ Conninfo = module.Conninfo
+ Escaping = module.Escaping
+ PGcancel = module.PGcancel
+ __build_version__ = module.__build_version__
+ elif impl:
+ raise ImportError(f"requested psycopg implementation '{impl}' unknown")
+ else:
+ sattempts = "\n".join(f"- {attempt}" for attempt in attempts)
+ raise ImportError(
+ f"""\
+no pq wrapper available.
+Attempts made:
+{sattempts}"""
+ )
+
+
+import_from_libpq()
+
+__all__ = (
+ "ConnStatus",
+ "PipelineStatus",
+ "PollingStatus",
+ "TransactionStatus",
+ "ExecStatus",
+ "Ping",
+ "DiagnosticField",
+ "Format",
+ "Trace",
+ "PGconn",
+ "PGnotify",
+ "Conninfo",
+ "PGresAttDesc",
+ "error_message",
+ "ConninfoOption",
+ "version",
+)
diff --git a/psycopg/psycopg/pq/_debug.py b/psycopg/psycopg/pq/_debug.py
new file mode 100644
index 0000000..f35d09f
--- /dev/null
+++ b/psycopg/psycopg/pq/_debug.py
@@ -0,0 +1,106 @@
+"""
+libpq debugging tools
+
+These functionalities are exposed here for convenience, but are not part of
+the public interface and are subject to change at any moment.
+
+Suggested usage::
+
+ import logging
+ import psycopg
+ from psycopg import pq
+ from psycopg.pq._debug import PGconnDebug
+
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
+ logger = logging.getLogger("psycopg.debug")
+ logger.setLevel(logging.INFO)
+
+ assert pq.__impl__ == "python"
+ pq.PGconn = PGconnDebug
+
+ with psycopg.connect("") as conn:
+ conn.pgconn.trace(2)
+ conn.pgconn.set_trace_flags(
+ pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ ...
+
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import inspect
+import logging
+from typing import Any, Callable, Type, TypeVar, TYPE_CHECKING
+from functools import wraps
+
+from . import PGconn
+from .misc import connection_summary
+
+if TYPE_CHECKING:
+ from . import abc
+
+Func = TypeVar("Func", bound=Callable[..., Any])
+
+logger = logging.getLogger("psycopg.debug")
+
+
+class PGconnDebug:
+ """Wrapper for a PQconn logging all its access."""
+
+ _Self = TypeVar("_Self", bound="PGconnDebug")
+ _pgconn: "abc.PGconn"
+
+ def __init__(self, pgconn: "abc.PGconn"):
+ super().__setattr__("_pgconn", pgconn)
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = connection_summary(self._pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ def __getattr__(self, attr: str) -> Any:
+ value = getattr(self._pgconn, attr)
+ if callable(value):
+ return debugging(value)
+ else:
+ logger.info("PGconn.%s -> %s", attr, value)
+ return value
+
+ def __setattr__(self, attr: str, value: Any) -> None:
+ setattr(self._pgconn, attr, value)
+ logger.info("PGconn.%s <- %s", attr, value)
+
+ @classmethod
+ def connect(cls: Type[_Self], conninfo: bytes) -> _Self:
+ return cls(debugging(PGconn.connect)(conninfo))
+
+ @classmethod
+ def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self:
+ return cls(debugging(PGconn.connect_start)(conninfo))
+
+ @classmethod
+ def ping(self, conninfo: bytes) -> int:
+ return debugging(PGconn.ping)(conninfo)
+
+
+def debugging(f: Func) -> Func:
+ """Wrap a function in order to log its arguments and return value on call."""
+
+ @wraps(f)
+ def debugging_(*args: Any, **kwargs: Any) -> Any:
+ reprs = []
+ for arg in args:
+ reprs.append(f"{arg!r}")
+ for (k, v) in kwargs.items():
+ reprs.append(f"{k}={v!r}")
+
+ logger.info("PGconn.%s(%s)", f.__name__, ", ".join(reprs))
+ rv = f(*args, **kwargs)
+ # Display the return value only if the function is declared to return
+ # something else than None.
+ ra = inspect.signature(f).return_annotation
+ if ra is not None or rv is not None:
+ logger.info(" <- %r", rv)
+ return rv
+
+ return debugging_ # type: ignore
diff --git a/psycopg/psycopg/pq/_enums.py b/psycopg/psycopg/pq/_enums.py
new file mode 100644
index 0000000..e0d4018
--- /dev/null
+++ b/psycopg/psycopg/pq/_enums.py
@@ -0,0 +1,249 @@
+"""
+libpq enum definitions for psycopg
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import IntEnum, IntFlag, auto
+
+
+class ConnStatus(IntEnum):
+ """
+ Current status of the connection.
+ """
+
+ __module__ = "psycopg.pq"
+
+ OK = 0
+ """The connection is in a working state."""
+ BAD = auto()
+ """The connection is closed."""
+
+ STARTED = auto()
+ MADE = auto()
+ AWAITING_RESPONSE = auto()
+ AUTH_OK = auto()
+ SETENV = auto()
+ SSL_STARTUP = auto()
+ NEEDED = auto()
+ CHECK_WRITABLE = auto()
+ CONSUME = auto()
+ GSS_STARTUP = auto()
+ CHECK_TARGET = auto()
+ CHECK_STANDBY = auto()
+
+
+class PollingStatus(IntEnum):
+ """
+ The status of the socket during a connection.
+
+ If ``READING`` or ``WRITING`` you may select before polling again.
+ """
+
+ __module__ = "psycopg.pq"
+
+ FAILED = 0
+ """Connection attempt failed."""
+ READING = auto()
+ """Will have to wait before reading new data."""
+ WRITING = auto()
+ """Will have to wait before writing new data."""
+ OK = auto()
+ """Connection completed."""
+
+ ACTIVE = auto()
+
+
+class ExecStatus(IntEnum):
+ """
+ The status of a command.
+ """
+
+ __module__ = "psycopg.pq"
+
+ EMPTY_QUERY = 0
+ """The string sent to the server was empty."""
+
+ COMMAND_OK = auto()
+ """Successful completion of a command returning no data."""
+
+ TUPLES_OK = auto()
+ """
+ Successful completion of a command returning data (such as a SELECT or SHOW).
+ """
+
+ COPY_OUT = auto()
+ """Copy Out (from server) data transfer started."""
+
+ COPY_IN = auto()
+ """Copy In (to server) data transfer started."""
+
+ BAD_RESPONSE = auto()
+ """The server's response was not understood."""
+
+ NONFATAL_ERROR = auto()
+ """A nonfatal error (a notice or warning) occurred."""
+
+ FATAL_ERROR = auto()
+ """A fatal error occurred."""
+
+ COPY_BOTH = auto()
+ """
+ Copy In/Out (to and from server) data transfer started.
+
+ This feature is currently used only for streaming replication, so this
+ status should not occur in ordinary applications.
+ """
+
+ SINGLE_TUPLE = auto()
+ """
+ The PGresult contains a single result tuple from the current command.
+
+ This status occurs only when single-row mode has been selected for the
+ query.
+ """
+
+ PIPELINE_SYNC = auto()
+ """
+ The PGresult represents a synchronization point in pipeline mode,
+ requested by PQpipelineSync.
+
+ This status occurs only when pipeline mode has been selected.
+ """
+
+ PIPELINE_ABORTED = auto()
+ """
+ The PGresult represents a pipeline that has received an error from the server.
+
+ PQgetResult must be called repeatedly, and each time it will return this
+ status code until the end of the current pipeline, at which point it will
+ return PGRES_PIPELINE_SYNC and normal processing can resume.
+ """
+
+
+class TransactionStatus(IntEnum):
+ """
+ The transaction status of a connection.
+ """
+
+ __module__ = "psycopg.pq"
+
+ IDLE = 0
+ """Connection ready, no transaction active."""
+
+ ACTIVE = auto()
+ """A command is in progress."""
+
+ INTRANS = auto()
+ """Connection idle in an open transaction."""
+
+ INERROR = auto()
+ """An error happened in the current transaction."""
+
+ UNKNOWN = auto()
+ """Unknown connection state, broken connection."""
+
+
+class Ping(IntEnum):
+ """Response from a ping attempt."""
+
+ __module__ = "psycopg.pq"
+
+ OK = 0
+ """
+ The server is running and appears to be accepting connections.
+ """
+
+ REJECT = auto()
+ """
+ The server is running but is in a state that disallows connections.
+ """
+
+ NO_RESPONSE = auto()
+ """
+ The server could not be contacted.
+ """
+
+ NO_ATTEMPT = auto()
+ """
+ No attempt was made to contact the server.
+ """
+
+
+class PipelineStatus(IntEnum):
+ """Pipeline mode status of the libpq connection."""
+
+ __module__ = "psycopg.pq"
+
+ OFF = 0
+ """
+ The libpq connection is *not* in pipeline mode.
+ """
+ ON = auto()
+ """
+ The libpq connection is in pipeline mode.
+ """
+ ABORTED = auto()
+ """
+ The libpq connection is in pipeline mode and an error occurred while
+ processing the current pipeline. The aborted flag is cleared when
+ PQgetResult returns a result of type PGRES_PIPELINE_SYNC.
+ """
+
+
+class DiagnosticField(IntEnum):
+ """
+ Fields in an error report.
+ """
+
+ __module__ = "psycopg.pq"
+
+ # from postgres_ext.h
+ SEVERITY = ord("S")
+ SEVERITY_NONLOCALIZED = ord("V")
+ SQLSTATE = ord("C")
+ MESSAGE_PRIMARY = ord("M")
+ MESSAGE_DETAIL = ord("D")
+ MESSAGE_HINT = ord("H")
+ STATEMENT_POSITION = ord("P")
+ INTERNAL_POSITION = ord("p")
+ INTERNAL_QUERY = ord("q")
+ CONTEXT = ord("W")
+ SCHEMA_NAME = ord("s")
+ TABLE_NAME = ord("t")
+ COLUMN_NAME = ord("c")
+ DATATYPE_NAME = ord("d")
+ CONSTRAINT_NAME = ord("n")
+ SOURCE_FILE = ord("F")
+ SOURCE_LINE = ord("L")
+ SOURCE_FUNCTION = ord("R")
+
+
+class Format(IntEnum):
+ """
+ Enum representing the format of a query argument or return value.
+
+ These values are only the ones managed by the libpq. `~psycopg` may also
+ support automatically-chosen values: see `psycopg.adapt.PyFormat`.
+ """
+
+ __module__ = "psycopg.pq"
+
+ TEXT = 0
+ """Text parameter."""
+ BINARY = 1
+ """Binary parameter."""
+
+
+class Trace(IntFlag):
+ """
+ Enum to control tracing of the client/server communication.
+ """
+
+ __module__ = "psycopg.pq"
+
+ SUPPRESS_TIMESTAMPS = 1
+ """Do not include timestamps in messages."""
+
+ REGRESS_MODE = 2
+ """Redact some fields, e.g. OIDs, from messages."""
diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py
new file mode 100644
index 0000000..9ca1d12
--- /dev/null
+++ b/psycopg/psycopg/pq/_pq_ctypes.py
@@ -0,0 +1,804 @@
+"""
+libpq access using ctypes
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import ctypes
+import ctypes.util
+from ctypes import Structure, CFUNCTYPE, POINTER
+from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p
+from typing import List, Optional, Tuple
+
+from .misc import find_libpq_full_path
+from ..errors import NotSupportedError
+
+libname = find_libpq_full_path()
+if not libname:
+ raise ImportError("libpq library not found")
+
+pq = ctypes.cdll.LoadLibrary(libname)
+
+
+class FILE(Structure):
+ pass
+
+
+FILE_ptr = POINTER(FILE)
+
+if sys.platform == "linux":
+ libcname = ctypes.util.find_library("c")
+ assert libcname
+ libc = ctypes.cdll.LoadLibrary(libcname)
+
+ fdopen = libc.fdopen
+ fdopen.argtypes = (c_int, c_char_p)
+ fdopen.restype = FILE_ptr
+
+
+# Get the libpq version to define what functions are available.
+
+PQlibVersion = pq.PQlibVersion
+PQlibVersion.argtypes = []
+PQlibVersion.restype = c_int
+
+libpq_version = PQlibVersion()
+
+
+# libpq data types
+
+
+Oid = c_uint
+
+
+class PGconn_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PGresult_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PQconninfoOption_struct(Structure):
+ _fields_ = [
+ ("keyword", c_char_p),
+ ("envvar", c_char_p),
+ ("compiled", c_char_p),
+ ("val", c_char_p),
+ ("label", c_char_p),
+ ("dispchar", c_char_p),
+ ("dispsize", c_int),
+ ]
+
+
+class PGnotify_struct(Structure):
+ _fields_ = [
+ ("relname", c_char_p),
+ ("be_pid", c_int),
+ ("extra", c_char_p),
+ ]
+
+
+class PGcancel_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PGresAttDesc_struct(Structure):
+ _fields_ = [
+ ("name", c_char_p),
+ ("tableid", Oid),
+ ("columnid", c_int),
+ ("format", c_int),
+ ("typid", Oid),
+ ("typlen", c_int),
+ ("atttypmod", c_int),
+ ]
+
+
+PGconn_ptr = POINTER(PGconn_struct)
+PGresult_ptr = POINTER(PGresult_struct)
+PQconninfoOption_ptr = POINTER(PQconninfoOption_struct)
+PGnotify_ptr = POINTER(PGnotify_struct)
+PGcancel_ptr = POINTER(PGcancel_struct)
+PGresAttDesc_ptr = POINTER(PGresAttDesc_struct)
+
+
+# Function definitions as explained in PostgreSQL 12 documentation
+
+# 33.1. Database Connection Control Functions
+
+# PQconnectdbParams: doesn't seem useful, won't wrap for now
+
+PQconnectdb = pq.PQconnectdb
+PQconnectdb.argtypes = [c_char_p]
+PQconnectdb.restype = PGconn_ptr
+
+# PQsetdbLogin: not useful
+# PQsetdb: not useful
+
+# PQconnectStartParams: not useful
+
+PQconnectStart = pq.PQconnectStart
+PQconnectStart.argtypes = [c_char_p]
+PQconnectStart.restype = PGconn_ptr
+
+PQconnectPoll = pq.PQconnectPoll
+PQconnectPoll.argtypes = [PGconn_ptr]
+PQconnectPoll.restype = c_int
+
+PQconndefaults = pq.PQconndefaults
+PQconndefaults.argtypes = []
+PQconndefaults.restype = PQconninfoOption_ptr
+
+PQconninfoFree = pq.PQconninfoFree
+PQconninfoFree.argtypes = [PQconninfoOption_ptr]
+PQconninfoFree.restype = None
+
+PQconninfo = pq.PQconninfo
+PQconninfo.argtypes = [PGconn_ptr]
+PQconninfo.restype = PQconninfoOption_ptr
+
+PQconninfoParse = pq.PQconninfoParse
+PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)]
+PQconninfoParse.restype = PQconninfoOption_ptr
+
+PQfinish = pq.PQfinish
+PQfinish.argtypes = [PGconn_ptr]
+PQfinish.restype = None
+
+PQreset = pq.PQreset
+PQreset.argtypes = [PGconn_ptr]
+PQreset.restype = None
+
+PQresetStart = pq.PQresetStart
+PQresetStart.argtypes = [PGconn_ptr]
+PQresetStart.restype = c_int
+
+PQresetPoll = pq.PQresetPoll
+PQresetPoll.argtypes = [PGconn_ptr]
+PQresetPoll.restype = c_int
+
+PQping = pq.PQping
+PQping.argtypes = [c_char_p]
+PQping.restype = c_int
+
+
+# 33.2. Connection Status Functions
+
+PQdb = pq.PQdb
+PQdb.argtypes = [PGconn_ptr]
+PQdb.restype = c_char_p
+
+PQuser = pq.PQuser
+PQuser.argtypes = [PGconn_ptr]
+PQuser.restype = c_char_p
+
+PQpass = pq.PQpass
+PQpass.argtypes = [PGconn_ptr]
+PQpass.restype = c_char_p
+
+PQhost = pq.PQhost
+PQhost.argtypes = [PGconn_ptr]
+PQhost.restype = c_char_p
+
+_PQhostaddr = None
+
+if libpq_version >= 120000:
+ _PQhostaddr = pq.PQhostaddr
+ _PQhostaddr.argtypes = [PGconn_ptr]
+ _PQhostaddr.restype = c_char_p
+
+
+def PQhostaddr(pgconn: PGconn_struct) -> bytes:
+ if not _PQhostaddr:
+ raise NotSupportedError(
+ "PQhostaddr requires libpq from PostgreSQL 12,"
+ f" {libpq_version} available instead"
+ )
+
+ return _PQhostaddr(pgconn)
+
+
+PQport = pq.PQport
+PQport.argtypes = [PGconn_ptr]
+PQport.restype = c_char_p
+
+PQtty = pq.PQtty
+PQtty.argtypes = [PGconn_ptr]
+PQtty.restype = c_char_p
+
+PQoptions = pq.PQoptions
+PQoptions.argtypes = [PGconn_ptr]
+PQoptions.restype = c_char_p
+
+PQstatus = pq.PQstatus
+PQstatus.argtypes = [PGconn_ptr]
+PQstatus.restype = c_int
+
+PQtransactionStatus = pq.PQtransactionStatus
+PQtransactionStatus.argtypes = [PGconn_ptr]
+PQtransactionStatus.restype = c_int
+
+PQparameterStatus = pq.PQparameterStatus
+PQparameterStatus.argtypes = [PGconn_ptr, c_char_p]
+PQparameterStatus.restype = c_char_p
+
+PQprotocolVersion = pq.PQprotocolVersion
+PQprotocolVersion.argtypes = [PGconn_ptr]
+PQprotocolVersion.restype = c_int
+
+PQserverVersion = pq.PQserverVersion
+PQserverVersion.argtypes = [PGconn_ptr]
+PQserverVersion.restype = c_int
+
+PQerrorMessage = pq.PQerrorMessage
+PQerrorMessage.argtypes = [PGconn_ptr]
+PQerrorMessage.restype = c_char_p
+
+PQsocket = pq.PQsocket
+PQsocket.argtypes = [PGconn_ptr]
+PQsocket.restype = c_int
+
+PQbackendPID = pq.PQbackendPID
+PQbackendPID.argtypes = [PGconn_ptr]
+PQbackendPID.restype = c_int
+
+PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword
+PQconnectionNeedsPassword.argtypes = [PGconn_ptr]
+PQconnectionNeedsPassword.restype = c_int
+
+PQconnectionUsedPassword = pq.PQconnectionUsedPassword
+PQconnectionUsedPassword.argtypes = [PGconn_ptr]
+PQconnectionUsedPassword.restype = c_int
+
+PQsslInUse = pq.PQsslInUse
+PQsslInUse.argtypes = [PGconn_ptr]
+PQsslInUse.restype = c_int
+
+# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl
+
+
+# 33.3. Command Execution Functions
+
+PQexec = pq.PQexec
+PQexec.argtypes = [PGconn_ptr, c_char_p]
+PQexec.restype = PGresult_ptr
+
+PQexecParams = pq.PQexecParams
+PQexecParams.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(Oid),
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQexecParams.restype = PGresult_ptr
+
+PQprepare = pq.PQprepare
+PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
+PQprepare.restype = PGresult_ptr
+
+PQexecPrepared = pq.PQexecPrepared
+PQexecPrepared.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQexecPrepared.restype = PGresult_ptr
+
+PQdescribePrepared = pq.PQdescribePrepared
+PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p]
+PQdescribePrepared.restype = PGresult_ptr
+
+PQdescribePortal = pq.PQdescribePortal
+PQdescribePortal.argtypes = [PGconn_ptr, c_char_p]
+PQdescribePortal.restype = PGresult_ptr
+
+PQresultStatus = pq.PQresultStatus
+PQresultStatus.argtypes = [PGresult_ptr]
+PQresultStatus.restype = c_int
+
+# PQresStatus: not needed, we have pretty enums
+
+PQresultErrorMessage = pq.PQresultErrorMessage
+PQresultErrorMessage.argtypes = [PGresult_ptr]
+PQresultErrorMessage.restype = c_char_p
+
+# TODO: PQresultVerboseErrorMessage
+
+PQresultErrorField = pq.PQresultErrorField
+PQresultErrorField.argtypes = [PGresult_ptr, c_int]
+PQresultErrorField.restype = c_char_p
+
+PQclear = pq.PQclear
+PQclear.argtypes = [PGresult_ptr]
+PQclear.restype = None
+
+
+# 33.3.2. Retrieving Query Result Information
+
+PQntuples = pq.PQntuples
+PQntuples.argtypes = [PGresult_ptr]
+PQntuples.restype = c_int
+
+PQnfields = pq.PQnfields
+PQnfields.argtypes = [PGresult_ptr]
+PQnfields.restype = c_int
+
+PQfname = pq.PQfname
+PQfname.argtypes = [PGresult_ptr, c_int]
+PQfname.restype = c_char_p
+
+# PQfnumber: useless and hard to use
+
+PQftable = pq.PQftable
+PQftable.argtypes = [PGresult_ptr, c_int]
+PQftable.restype = Oid
+
+PQftablecol = pq.PQftablecol
+PQftablecol.argtypes = [PGresult_ptr, c_int]
+PQftablecol.restype = c_int
+
+PQfformat = pq.PQfformat
+PQfformat.argtypes = [PGresult_ptr, c_int]
+PQfformat.restype = c_int
+
+PQftype = pq.PQftype
+PQftype.argtypes = [PGresult_ptr, c_int]
+PQftype.restype = Oid
+
+PQfmod = pq.PQfmod
+PQfmod.argtypes = [PGresult_ptr, c_int]
+PQfmod.restype = c_int
+
+PQfsize = pq.PQfsize
+PQfsize.argtypes = [PGresult_ptr, c_int]
+PQfsize.restype = c_int
+
+PQbinaryTuples = pq.PQbinaryTuples
+PQbinaryTuples.argtypes = [PGresult_ptr]
+PQbinaryTuples.restype = c_int
+
+PQgetvalue = pq.PQgetvalue
+PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetvalue.restype = POINTER(c_char) # not a null-terminated string
+
+PQgetisnull = pq.PQgetisnull
+PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetisnull.restype = c_int
+
+PQgetlength = pq.PQgetlength
+PQgetlength.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetlength.restype = c_int
+
+PQnparams = pq.PQnparams
+PQnparams.argtypes = [PGresult_ptr]
+PQnparams.restype = c_int
+
+PQparamtype = pq.PQparamtype
+PQparamtype.argtypes = [PGresult_ptr, c_int]
+PQparamtype.restype = Oid
+
+# PQprint: pretty useless
+
+# 33.3.3. Retrieving Other Result Information
+
+PQcmdStatus = pq.PQcmdStatus
+PQcmdStatus.argtypes = [PGresult_ptr]
+PQcmdStatus.restype = c_char_p
+
+PQcmdTuples = pq.PQcmdTuples
+PQcmdTuples.argtypes = [PGresult_ptr]
+PQcmdTuples.restype = c_char_p
+
+PQoidValue = pq.PQoidValue
+PQoidValue.argtypes = [PGresult_ptr]
+PQoidValue.restype = Oid
+
+
+# 33.3.4. Escaping Strings for Inclusion in SQL Commands
+
+PQescapeLiteral = pq.PQescapeLiteral
+PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeLiteral.restype = POINTER(c_char)
+
+PQescapeIdentifier = pq.PQescapeIdentifier
+PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeIdentifier.restype = POINTER(c_char)
+
+PQescapeStringConn = pq.PQescapeStringConn
+# TODO: raises "wrong type" error
+# PQescapeStringConn.argtypes = [
+# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int)
+# ]
+PQescapeStringConn.restype = c_size_t
+
+PQescapeString = pq.PQescapeString
+# TODO: raises "wrong type" error
+# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t]
+PQescapeString.restype = c_size_t
+
+PQescapeByteaConn = pq.PQescapeByteaConn
+PQescapeByteaConn.argtypes = [
+ PGconn_ptr,
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ c_size_t,
+ POINTER(c_size_t),
+]
+PQescapeByteaConn.restype = POINTER(c_ubyte)
+
+PQescapeBytea = pq.PQescapeBytea
+PQescapeBytea.argtypes = [
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ c_size_t,
+ POINTER(c_size_t),
+]
+PQescapeBytea.restype = POINTER(c_ubyte)
+
+
+PQunescapeBytea = pq.PQunescapeBytea
+PQunescapeBytea.argtypes = [
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ POINTER(c_size_t),
+]
+PQunescapeBytea.restype = POINTER(c_ubyte)
+
+
+# 33.4. Asynchronous Command Processing
+
+PQsendQuery = pq.PQsendQuery
+PQsendQuery.argtypes = [PGconn_ptr, c_char_p]
+PQsendQuery.restype = c_int
+
+PQsendQueryParams = pq.PQsendQueryParams
+PQsendQueryParams.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(Oid),
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQsendQueryParams.restype = c_int
+
+PQsendPrepare = pq.PQsendPrepare
+PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
+PQsendPrepare.restype = c_int
+
+PQsendQueryPrepared = pq.PQsendQueryPrepared
+PQsendQueryPrepared.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQsendQueryPrepared.restype = c_int
+
+PQsendDescribePrepared = pq.PQsendDescribePrepared
+PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p]
+PQsendDescribePrepared.restype = c_int
+
+PQsendDescribePortal = pq.PQsendDescribePortal
+PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p]
+PQsendDescribePortal.restype = c_int
+
+PQgetResult = pq.PQgetResult
+PQgetResult.argtypes = [PGconn_ptr]
+PQgetResult.restype = PGresult_ptr
+
+PQconsumeInput = pq.PQconsumeInput
+PQconsumeInput.argtypes = [PGconn_ptr]
+PQconsumeInput.restype = c_int
+
+PQisBusy = pq.PQisBusy
+PQisBusy.argtypes = [PGconn_ptr]
+PQisBusy.restype = c_int
+
+PQsetnonblocking = pq.PQsetnonblocking
+PQsetnonblocking.argtypes = [PGconn_ptr, c_int]
+PQsetnonblocking.restype = c_int
+
+PQisnonblocking = pq.PQisnonblocking
+PQisnonblocking.argtypes = [PGconn_ptr]
+PQisnonblocking.restype = c_int
+
+PQflush = pq.PQflush
+PQflush.argtypes = [PGconn_ptr]
+PQflush.restype = c_int
+
+
+# 33.5. Retrieving Query Results Row-by-Row
+PQsetSingleRowMode = pq.PQsetSingleRowMode
+PQsetSingleRowMode.argtypes = [PGconn_ptr]
+PQsetSingleRowMode.restype = c_int
+
+
+# 33.6. Canceling Queries in Progress
+
+PQgetCancel = pq.PQgetCancel
+PQgetCancel.argtypes = [PGconn_ptr]
+PQgetCancel.restype = PGcancel_ptr
+
+PQfreeCancel = pq.PQfreeCancel
+PQfreeCancel.argtypes = [PGcancel_ptr]
+PQfreeCancel.restype = None
+
+PQcancel = pq.PQcancel
+# TODO: raises "wrong type" error
+# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int]
+PQcancel.restype = c_int
+
+
+# 33.8. Asynchronous Notification
+
+PQnotifies = pq.PQnotifies
+PQnotifies.argtypes = [PGconn_ptr]
+PQnotifies.restype = PGnotify_ptr
+
+
+# 33.9. Functions Associated with the COPY Command
+
+PQputCopyData = pq.PQputCopyData
+PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int]
+PQputCopyData.restype = c_int
+
+PQputCopyEnd = pq.PQputCopyEnd
+PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
+PQputCopyEnd.restype = c_int
+
+PQgetCopyData = pq.PQgetCopyData
+PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int]
+PQgetCopyData.restype = c_int
+
+
+# 33.10. Control Functions
+
+PQtrace = pq.PQtrace
+PQtrace.argtypes = [PGconn_ptr, FILE_ptr]
+PQtrace.restype = None
+
+_PQsetTraceFlags = None
+
+if libpq_version >= 140000:
+ _PQsetTraceFlags = pq.PQsetTraceFlags
+ _PQsetTraceFlags.argtypes = [PGconn_ptr, c_int]
+ _PQsetTraceFlags.restype = None
+
+
+def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None:
+ if not _PQsetTraceFlags:
+ raise NotSupportedError(
+ "PQsetTraceFlags requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+
+ _PQsetTraceFlags(pgconn, flags)
+
+
+PQuntrace = pq.PQuntrace
+PQuntrace.argtypes = [PGconn_ptr]
+PQuntrace.restype = None
+
+# 33.11. Miscellaneous Functions
+
+PQfreemem = pq.PQfreemem
+PQfreemem.argtypes = [c_void_p]
+PQfreemem.restype = None
+
+if libpq_version >= 100000:
+ _PQencryptPasswordConn = pq.PQencryptPasswordConn
+ _PQencryptPasswordConn.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_char_p,
+ c_char_p,
+ ]
+ _PQencryptPasswordConn.restype = POINTER(c_char)
+
+
+def PQencryptPasswordConn(
+ pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes
+) -> Optional[bytes]:
+ if not _PQencryptPasswordConn:
+ raise NotSupportedError(
+ "PQencryptPasswordConn requires libpq from PostgreSQL 10,"
+ f" {libpq_version} available instead"
+ )
+
+ return _PQencryptPasswordConn(pgconn, passwd, user, algorithm)
+
+
+PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult
+PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int]
+PQmakeEmptyPGresult.restype = PGresult_ptr
+
+PQsetResultAttrs = pq.PQsetResultAttrs
+PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr]
+PQsetResultAttrs.restype = c_int
+
+
+# 33.12. Notice Processing
+
+PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr)
+
+PQsetNoticeReceiver = pq.PQsetNoticeReceiver
+PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p]
+PQsetNoticeReceiver.restype = PQnoticeReceiver
+
+# 34.5 Pipeline Mode
+
+_PQpipelineStatus = None
+_PQenterPipelineMode = None
+_PQexitPipelineMode = None
+_PQpipelineSync = None
+_PQsendFlushRequest = None
+
+if libpq_version >= 140000:
+ _PQpipelineStatus = pq.PQpipelineStatus
+ _PQpipelineStatus.argtypes = [PGconn_ptr]
+ _PQpipelineStatus.restype = c_int
+
+ _PQenterPipelineMode = pq.PQenterPipelineMode
+ _PQenterPipelineMode.argtypes = [PGconn_ptr]
+ _PQenterPipelineMode.restype = c_int
+
+ _PQexitPipelineMode = pq.PQexitPipelineMode
+ _PQexitPipelineMode.argtypes = [PGconn_ptr]
+ _PQexitPipelineMode.restype = c_int
+
+ _PQpipelineSync = pq.PQpipelineSync
+ _PQpipelineSync.argtypes = [PGconn_ptr]
+ _PQpipelineSync.restype = c_int
+
+ _PQsendFlushRequest = pq.PQsendFlushRequest
+ _PQsendFlushRequest.argtypes = [PGconn_ptr]
+ _PQsendFlushRequest.restype = c_int
+
+
+def PQpipelineStatus(pgconn: PGconn_struct) -> int:
+ if not _PQpipelineStatus:
+ raise NotSupportedError(
+ "PQpipelineStatus requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQpipelineStatus(pgconn)
+
+
+def PQenterPipelineMode(pgconn: PGconn_struct) -> int:
+ if not _PQenterPipelineMode:
+ raise NotSupportedError(
+ "PQenterPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQenterPipelineMode(pgconn)
+
+
+def PQexitPipelineMode(pgconn: PGconn_struct) -> int:
+ if not _PQexitPipelineMode:
+ raise NotSupportedError(
+ "PQexitPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQexitPipelineMode(pgconn)
+
+
+def PQpipelineSync(pgconn: PGconn_struct) -> int:
+ if not _PQpipelineSync:
+ raise NotSupportedError(
+ "PQpipelineSync requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQpipelineSync(pgconn)
+
+
+def PQsendFlushRequest(pgconn: PGconn_struct) -> int:
+ if not _PQsendFlushRequest:
+ raise NotSupportedError(
+ "PQsendFlushRequest requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQsendFlushRequest(pgconn)
+
+
+# 33.18. SSL Support
+
+PQinitOpenSSL = pq.PQinitOpenSSL
+PQinitOpenSSL.argtypes = [c_int, c_int]
+PQinitOpenSSL.restype = None
+
+
+def generate_stub() -> None:
+ import re
+ from ctypes import _CFuncPtr # type: ignore
+
+ def type2str(fname, narg, t):
+ if t is None:
+ return "None"
+ elif t is c_void_p:
+ return "Any"
+ elif t is c_int or t is c_uint or t is c_size_t:
+ return "int"
+ elif t is c_char_p or t.__name__ == "LP_c_char":
+ if narg is not None:
+ return "bytes"
+ else:
+ return "Optional[bytes]"
+
+ elif t.__name__ in (
+ "LP_PGconn_struct",
+ "LP_PGresult_struct",
+ "LP_PGcancel_struct",
+ ):
+ if narg is not None:
+ return f"Optional[{t.__name__[3:]}]"
+ else:
+ return t.__name__[3:]
+
+ elif t.__name__ in ("LP_PQconninfoOption_struct",):
+ return f"Sequence[{t.__name__[3:]}]"
+
+ elif t.__name__ in (
+ "LP_c_ubyte",
+ "LP_c_char_p",
+ "LP_c_int",
+ "LP_c_uint",
+ "LP_c_ulong",
+ "LP_FILE",
+ ):
+ return f"_Pointer[{t.__name__[3:]}]"
+
+ else:
+ assert False, f"can't deal with {t} in {fname}"
+
+ fn = __file__ + "i"
+ with open(fn) as f:
+ lines = f.read().splitlines()
+
+ istart, iend = (
+ i
+ for i, line in enumerate(lines)
+ if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
+ )
+
+ known = {
+ line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ")
+ }
+
+ signatures = []
+
+ for name, obj in globals().items():
+ if name in known:
+ continue
+ if not isinstance(obj, _CFuncPtr):
+ continue
+
+ params = []
+ for i, t in enumerate(obj.argtypes):
+ params.append(f"arg{i + 1}: {type2str(name, i, t)}")
+
+ resname = type2str(name, None, obj.restype)
+
+ signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...")
+
+ lines[istart + 1 : iend] = signatures
+
+ with open(fn, "w") as f:
+ f.write("\n".join(lines))
+ f.write("\n")
+
+
+if __name__ == "__main__":
+ generate_stub()
diff --git a/psycopg/psycopg/pq/_pq_ctypes.pyi b/psycopg/psycopg/pq/_pq_ctypes.pyi
new file mode 100644
index 0000000..5d2ee3f
--- /dev/null
+++ b/psycopg/psycopg/pq/_pq_ctypes.pyi
@@ -0,0 +1,216 @@
+"""
+types stub for ctypes functions
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, Optional, Sequence
+from ctypes import Array, pointer, _Pointer
+from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
+
+class FILE: ...
+
+def fdopen(fd: int, mode: bytes) -> _Pointer[FILE]: ... # type: ignore[type-var]
+
+Oid = c_uint
+
+class PGconn_struct: ...
+class PGresult_struct: ...
+class PGcancel_struct: ...
+
+class PQconninfoOption_struct:
+ keyword: bytes
+ envvar: bytes
+ compiled: bytes
+ val: bytes
+ label: bytes
+ dispchar: bytes
+ dispsize: int
+
+class PGnotify_struct:
+ be_pid: int
+ relname: bytes
+ extra: bytes
+
+class PGresAttDesc_struct:
+ name: bytes
+ tableid: int
+ columnid: int
+ format: int
+ typid: int
+ typlen: int
+ atttypmod: int
+
+def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQexecPrepared(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: int,
+ arg4: Optional[Array[c_char_p]],
+ arg5: Optional[Array[c_int]],
+ arg6: Optional[Array[c_int]],
+ arg7: int,
+) -> PGresult_struct: ...
+def PQprepare(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: int,
+ arg5: Optional[Array[c_uint]],
+) -> PGresult_struct: ...
+def PQgetvalue(
+ arg1: Optional[PGresult_struct], arg2: int, arg3: int
+) -> _Pointer[c_char]: ...
+def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQescapeStringConn(
+ arg1: Optional[PGconn_struct],
+ arg2: c_char_p,
+ arg3: bytes,
+ arg4: int,
+ arg5: _Pointer[c_int],
+) -> int: ...
+def PQescapeString(arg1: c_char_p, arg2: bytes, arg3: int) -> int: ...
+def PQsendPrepare(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: int,
+ arg5: Optional[Array[c_uint]],
+) -> int: ...
+def PQsendQueryPrepared(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: int,
+ arg4: Optional[Array[c_char_p]],
+ arg5: Optional[Array[c_int]],
+ arg6: Optional[Array[c_int]],
+ arg7: int,
+) -> int: ...
+def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ...
+def PQsetNoticeReceiver(
+ arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
+) -> Callable[[Any], PGresult_struct]: ...
+
+# TODO: Ignoring type as getting an error on mypy/ctypes:
+# Type argument "psycopg.pq._pq_ctypes.PGnotify_struct" of "pointer" must be
+# a subtype of "ctypes._CData"
+def PQnotifies(
+ arg1: Optional[PGconn_struct],
+) -> Optional[_Pointer[PGnotify_struct]]: ... # type: ignore
+def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ...
+
+# Arg 2 is a _Pointer, reported as _CArgObject by mypy
+def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ...
+def PQsetResultAttrs(
+ arg1: Optional[PGresult_struct],
+ arg2: int,
+ arg3: Array[PGresAttDesc_struct], # type: ignore
+) -> int: ...
+def PQtrace(
+ arg1: Optional[PGconn_struct],
+ arg2: _Pointer[FILE], # type: ignore[type-var]
+) -> None: ...
+def PQencryptPasswordConn(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: Optional[bytes],
+) -> bytes: ...
+def PQpipelineStatus(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQenterPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQexitPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQpipelineSync(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQsendFlushRequest(pgconn: Optional[PGconn_struct]) -> int: ...
+
+# fmt: off
+# autogenerated: start
+def PQlibVersion() -> int: ...
+def PQconnectdb(arg1: bytes) -> PGconn_struct: ...
+def PQconnectStart(arg1: bytes) -> PGconn_struct: ...
+def PQconnectPoll(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconndefaults() -> Sequence[PQconninfoOption_struct]: ...
+def PQconninfoFree(arg1: Sequence[PQconninfoOption_struct]) -> None: ...
+def PQconninfo(arg1: Optional[PGconn_struct]) -> Sequence[PQconninfoOption_struct]: ...
+def PQconninfoParse(arg1: bytes, arg2: _Pointer[c_char_p]) -> Sequence[PQconninfoOption_struct]: ...
+def PQfinish(arg1: Optional[PGconn_struct]) -> None: ...
+def PQreset(arg1: Optional[PGconn_struct]) -> None: ...
+def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ...
+def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ...
+def PQping(arg1: bytes) -> int: ...
+def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQstatus(arg1: Optional[PGconn_struct]) -> int: ...
+def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ...
+def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ...
+def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ...
+def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsocket(arg1: Optional[PGconn_struct]) -> int: ...
+def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconnectionUsedPassword(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsslInUse(arg1: Optional[PGconn_struct]) -> int: ...
+def PQexec(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> PGresult_struct: ...
+def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ...
+def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
+def PQclear(arg1: Optional[PGresult_struct]) -> None: ...
+def PQntuples(arg1: Optional[PGresult_struct]) -> int: ...
+def PQnfields(arg1: Optional[PGresult_struct]) -> int: ...
+def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
+def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ...
+def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
+def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
+def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
+def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
+def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
+def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
+def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
+def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQescapeBytea(arg1: bytes, arg2: int, arg3: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQunescapeBytea(arg1: bytes, arg2: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> int: ...
+def PQsendDescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQsendDescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ...
+def PQconsumeInput(arg1: Optional[PGconn_struct]) -> int: ...
+def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ...
+def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
+def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ...
+def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
+def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
+def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ...
+def PQsetTraceFlags(arg1: Optional[PGconn_struct], arg2: int) -> None: ...
+def PQuntrace(arg1: Optional[PGconn_struct]) -> None: ...
+def PQfreemem(arg1: Any) -> None: ...
+def _PQencryptPasswordConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: bytes, arg4: bytes) -> Optional[bytes]: ...
+def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
+def _PQpipelineStatus(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQenterPipelineMode(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQexitPipelineMode(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQpipelineSync(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQsendFlushRequest(arg1: Optional[PGconn_struct]) -> int: ...
+def PQinitOpenSSL(arg1: int, arg2: int) -> None: ...
+# autogenerated: end
+# fmt: on
+
+# vim: set syntax=python:
diff --git a/psycopg/psycopg/pq/abc.py b/psycopg/psycopg/pq/abc.py
new file mode 100644
index 0000000..9c45f64
--- /dev/null
+++ b/psycopg/psycopg/pq/abc.py
@@ -0,0 +1,385 @@
+"""
+Protocol objects to represent objects exposed by different pq implementations.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from ._enums import Format, Trace
+from .._compat import Protocol
+
+if TYPE_CHECKING:
+ from .misc import PGnotify, ConninfoOption, PGresAttDesc
+
+# An object implementing the buffer protocol (ish)
+Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
+
+
+class PGconn(Protocol):
+
+ notice_handler: Optional[Callable[["PGresult"], None]]
+ notify_handler: Optional[Callable[["PGnotify"], None]]
+
+ @classmethod
+ def connect(cls, conninfo: bytes) -> "PGconn":
+ ...
+
+ @classmethod
+ def connect_start(cls, conninfo: bytes) -> "PGconn":
+ ...
+
+ def connect_poll(self) -> int:
+ ...
+
+ def finish(self) -> None:
+ ...
+
+ @property
+ def info(self) -> List["ConninfoOption"]:
+ ...
+
+ def reset(self) -> None:
+ ...
+
+ def reset_start(self) -> None:
+ ...
+
+ def reset_poll(self) -> int:
+ ...
+
+ @classmethod
+ def ping(self, conninfo: bytes) -> int:
+ ...
+
+ @property
+ def db(self) -> bytes:
+ ...
+
+ @property
+ def user(self) -> bytes:
+ ...
+
+ @property
+ def password(self) -> bytes:
+ ...
+
+ @property
+ def host(self) -> bytes:
+ ...
+
+ @property
+ def hostaddr(self) -> bytes:
+ ...
+
+ @property
+ def port(self) -> bytes:
+ ...
+
+ @property
+ def tty(self) -> bytes:
+ ...
+
+ @property
+ def options(self) -> bytes:
+ ...
+
+ @property
+ def status(self) -> int:
+ ...
+
+ @property
+ def transaction_status(self) -> int:
+ ...
+
+ def parameter_status(self, name: bytes) -> Optional[bytes]:
+ ...
+
+ @property
+ def error_message(self) -> bytes:
+ ...
+
+ @property
+ def server_version(self) -> int:
+ ...
+
+ @property
+ def socket(self) -> int:
+ ...
+
+ @property
+ def backend_pid(self) -> int:
+ ...
+
+ @property
+ def needs_password(self) -> bool:
+ ...
+
+ @property
+ def used_password(self) -> bool:
+ ...
+
+ @property
+ def ssl_in_use(self) -> bool:
+ ...
+
+ def exec_(self, command: bytes) -> "PGresult":
+ ...
+
+ def send_query(self, command: bytes) -> None:
+ ...
+
+ def exec_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> "PGresult":
+ ...
+
+ def send_query_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ ...
+
+ def send_prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ ...
+
+ def send_query_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ ...
+
+ def prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> "PGresult":
+ ...
+
+ def exec_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Buffer]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = 0,
+ ) -> "PGresult":
+ ...
+
+ def describe_prepared(self, name: bytes) -> "PGresult":
+ ...
+
+ def send_describe_prepared(self, name: bytes) -> None:
+ ...
+
+ def describe_portal(self, name: bytes) -> "PGresult":
+ ...
+
+ def send_describe_portal(self, name: bytes) -> None:
+ ...
+
+ def get_result(self) -> Optional["PGresult"]:
+ ...
+
+ def consume_input(self) -> None:
+ ...
+
+ def is_busy(self) -> int:
+ ...
+
+ @property
+ def nonblocking(self) -> int:
+ ...
+
+ @nonblocking.setter
+ def nonblocking(self, arg: int) -> None:
+ ...
+
+ def flush(self) -> int:
+ ...
+
+ def set_single_row_mode(self) -> None:
+ ...
+
+ def get_cancel(self) -> "PGcancel":
+ ...
+
+ def notifies(self) -> Optional["PGnotify"]:
+ ...
+
+ def put_copy_data(self, buffer: Buffer) -> int:
+ ...
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ ...
+
+ def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
+ ...
+
+ def trace(self, fileno: int) -> None:
+ ...
+
+ def set_trace_flags(self, flags: Trace) -> None:
+ ...
+
+ def untrace(self) -> None:
+ ...
+
+ def encrypt_password(
+ self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
+ ) -> bytes:
+ ...
+
+ def make_empty_result(self, exec_status: int) -> "PGresult":
+ ...
+
+ @property
+ def pipeline_status(self) -> int:
+ ...
+
+ def enter_pipeline_mode(self) -> None:
+ ...
+
+ def exit_pipeline_mode(self) -> None:
+ ...
+
+ def pipeline_sync(self) -> None:
+ ...
+
+ def send_flush_request(self) -> None:
+ ...
+
+
+class PGresult(Protocol):
+ def clear(self) -> None:
+ ...
+
+ @property
+ def status(self) -> int:
+ ...
+
+ @property
+ def error_message(self) -> bytes:
+ ...
+
+ def error_field(self, fieldcode: int) -> Optional[bytes]:
+ ...
+
+ @property
+ def ntuples(self) -> int:
+ ...
+
+ @property
+ def nfields(self) -> int:
+ ...
+
+ def fname(self, column_number: int) -> Optional[bytes]:
+ ...
+
+ def ftable(self, column_number: int) -> int:
+ ...
+
+ def ftablecol(self, column_number: int) -> int:
+ ...
+
+ def fformat(self, column_number: int) -> int:
+ ...
+
+ def ftype(self, column_number: int) -> int:
+ ...
+
+ def fmod(self, column_number: int) -> int:
+ ...
+
+ def fsize(self, column_number: int) -> int:
+ ...
+
+ @property
+ def binary_tuples(self) -> int:
+ ...
+
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
+ ...
+
+ @property
+ def nparams(self) -> int:
+ ...
+
+ def param_type(self, param_number: int) -> int:
+ ...
+
+ @property
+ def command_status(self) -> Optional[bytes]:
+ ...
+
+ @property
+ def command_tuples(self) -> Optional[int]:
+ ...
+
+ @property
+ def oid_value(self) -> int:
+ ...
+
+ def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None:
+ ...
+
+
+class PGcancel(Protocol):
+ def free(self) -> None:
+ ...
+
+ def cancel(self) -> None:
+ ...
+
+
+class Conninfo(Protocol):
+ @classmethod
+ def get_defaults(cls) -> List["ConninfoOption"]:
+ ...
+
+ @classmethod
+ def parse(cls, conninfo: bytes) -> List["ConninfoOption"]:
+ ...
+
+ @classmethod
+ def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
+ ...
+
+
+class Escaping(Protocol):
+ def __init__(self, conn: Optional[PGconn] = None):
+ ...
+
+ def escape_literal(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_identifier(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_string(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_bytea(self, data: Buffer) -> bytes:
+ ...
+
+ def unescape_bytea(self, data: Buffer) -> bytes:
+ ...
diff --git a/psycopg/psycopg/pq/misc.py b/psycopg/psycopg/pq/misc.py
new file mode 100644
index 0000000..3a43133
--- /dev/null
+++ b/psycopg/psycopg/pq/misc.py
@@ -0,0 +1,146 @@
+"""
+Various functionalities to make easier to work with the libpq.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import sys
+import logging
+import ctypes.util
+from typing import cast, NamedTuple, Optional, Union
+
+from .abc import PGconn, PGresult
+from ._enums import ConnStatus, TransactionStatus, PipelineStatus
+from .._compat import cache
+from .._encodings import pgconn_encoding
+
+logger = logging.getLogger("psycopg.pq")
+
+OK = ConnStatus.OK
+
+
+class PGnotify(NamedTuple):
+ relname: bytes
+ be_pid: int
+ extra: bytes
+
+
+class ConninfoOption(NamedTuple):
+ keyword: bytes
+ envvar: Optional[bytes]
+ compiled: Optional[bytes]
+ val: Optional[bytes]
+ label: bytes
+ dispchar: bytes
+ dispsize: int
+
+
+class PGresAttDesc(NamedTuple):
+ name: bytes
+ tableid: int
+ columnid: int
+ format: int
+ typid: int
+ typlen: int
+ atttypmod: int
+
+
+@cache
+def find_libpq_full_path() -> Optional[str]:
+ if sys.platform == "win32":
+ libname = ctypes.util.find_library("libpq.dll")
+
+ elif sys.platform == "darwin":
+ libname = ctypes.util.find_library("libpq.dylib")
+ # (hopefully) temporary hack: libpq not in a standard place
+ # https://github.com/orgs/Homebrew/discussions/3595
+ # If pg_config is available and agrees, let's use its indications.
+ if not libname:
+ try:
+ import subprocess as sp
+
+ libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode()
+ libname = os.path.join(libdir, "libpq.dylib")
+ if not os.path.exists(libname):
+ libname = None
+ except Exception as ex:
+ logger.debug("couldn't use pg_config to find libpq: %s", ex)
+
+ else:
+ libname = ctypes.util.find_library("pq")
+
+ return libname
+
+
+def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
+ """
+ Return an error message from a `PGconn` or `PGresult`.
+
+ The return value is a `!str` (unlike pq data which is usually `!bytes`):
+ use the connection encoding if available, otherwise the `!encoding`
+ parameter as a fallback for decoding. Don't raise exceptions on decoding
+ errors.
+
+ """
+ bmsg: bytes
+
+ if hasattr(obj, "error_field"):
+ # obj is a PGresult
+ obj = cast(PGresult, obj)
+ bmsg = obj.error_message
+
+ # strip severity and whitespaces
+ if bmsg:
+ bmsg = bmsg.split(b":", 1)[-1].strip()
+
+ elif hasattr(obj, "error_message"):
+ # obj is a PGconn
+ if obj.status == OK:
+ encoding = pgconn_encoding(obj)
+ bmsg = obj.error_message
+
+ # strip severity and whitespaces
+ if bmsg:
+ bmsg = bmsg.split(b":", 1)[-1].strip()
+
+ else:
+ raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}")
+
+ if bmsg:
+ msg = bmsg.decode(encoding, "replace")
+ else:
+ msg = "no details available"
+
+ return msg
+
+
+def connection_summary(pgconn: PGconn) -> str:
+ """
+ Return summary information on a connection.
+
+ Useful for __repr__
+ """
+ parts = []
+ if pgconn.status == OK:
+ # Put together the [STATUS]
+ status = TransactionStatus(pgconn.transaction_status).name
+ if pgconn.pipeline_status:
+ status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}"
+
+ # Put together the (CONNECTION)
+ if not pgconn.host.startswith(b"/"):
+ parts.append(("host", pgconn.host.decode()))
+ if pgconn.port != b"5432":
+ parts.append(("port", pgconn.port.decode()))
+ if pgconn.user != pgconn.db:
+ parts.append(("user", pgconn.user.decode()))
+ parts.append(("database", pgconn.db.decode()))
+
+ else:
+ status = ConnStatus(pgconn.status).name
+
+ sparts = " ".join("%s=%s" % part for part in parts)
+ if sparts:
+ sparts = f" ({sparts})"
+ return f"[{status}]{sparts}"
diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py
new file mode 100644
index 0000000..8b87c19
--- /dev/null
+++ b/psycopg/psycopg/pq/pq_ctypes.py
@@ -0,0 +1,1086 @@
+"""
+libpq Python wrapper using ctypes bindings.
+
+Clients shouldn't use this module directly, unless for testing: they should use
+the `pq` module instead, which is in charge of choosing the best
+implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import logging
+from os import getpid
+from weakref import ref
+
+from ctypes import Array, POINTER, cast, string_at, create_string_buffer, byref
+from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong, c_void_p, py_object
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import cast as t_cast, TYPE_CHECKING
+
+from .. import errors as e
+from . import _pq_ctypes as impl
+from .misc import PGnotify, ConninfoOption, PGresAttDesc
+from .misc import error_message, connection_summary
+from ._enums import Format, ExecStatus, Trace
+
+# Imported locally to call them from __del__ methods
+from ._pq_ctypes import PQclear, PQfinish, PQfreeCancel, PQstatus
+
+if TYPE_CHECKING:
+ from . import abc
+
+__impl__ = "python"
+
+logger = logging.getLogger("psycopg")
+
+
+def version() -> int:
+ """Return the version number of the libpq currently loaded.
+
+ The number is in the same format of `~psycopg.ConnectionInfo.server_version`.
+
+ Certain features might not be available if the libpq library used is too old.
+ """
+ return impl.PQlibVersion()
+
+
+@impl.PQnoticeReceiver # type: ignore
+def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None:
+ pgconn = cast(arg, POINTER(py_object)).contents.value()
+ if not (pgconn and pgconn.notice_handler):
+ return
+
+ res = PGresult(result_ptr)
+ try:
+ pgconn.notice_handler(res)
+ except Exception as exc:
+ logger.exception("error in notice receiver: %s", exc)
+ finally:
+ res._pgresult_ptr = None # avoid destroying the pgresult_ptr
+
+
+class PGconn:
+ """
+ Python representation of a libpq connection.
+ """
+
+ __slots__ = (
+ "_pgconn_ptr",
+ "notice_handler",
+ "notify_handler",
+ "_self_ptr",
+ "_procpid",
+ "__weakref__",
+ )
+
+ def __init__(self, pgconn_ptr: impl.PGconn_struct):
+ self._pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr
+ self.notice_handler: Optional[Callable[["abc.PGresult"], None]] = None
+ self.notify_handler: Optional[Callable[[PGnotify], None]] = None
+
+ # Keep alive for the lifetime of PGconn
+ self._self_ptr = py_object(ref(self))
+ impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, byref(self._self_ptr))
+
+ self._procpid = getpid()
+
+ def __del__(self) -> None:
+ # Close the connection only if it was created in this process,
+ # not if this object is being GC'd after fork.
+ if getpid() == self._procpid:
+ self.finish()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = connection_summary(self)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @classmethod
+ def connect(cls, conninfo: bytes) -> "PGconn":
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ pgconn_ptr = impl.PQconnectdb(conninfo)
+ if not pgconn_ptr:
+ raise MemoryError("couldn't allocate PGconn")
+ return cls(pgconn_ptr)
+
+ @classmethod
+ def connect_start(cls, conninfo: bytes) -> "PGconn":
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ pgconn_ptr = impl.PQconnectStart(conninfo)
+ if not pgconn_ptr:
+ raise MemoryError("couldn't allocate PGconn")
+ return cls(pgconn_ptr)
+
+ def connect_poll(self) -> int:
+ return self._call_int(impl.PQconnectPoll)
+
+ def finish(self) -> None:
+ self._pgconn_ptr, p = None, self._pgconn_ptr
+ if p:
+ PQfinish(p)
+
+ @property
+ def pgconn_ptr(self) -> Optional[int]:
+ """The pointer to the underlying `!PGconn` structure, as integer.
+
+ `!None` if the connection is closed.
+
+ The value can be used to pass the structure to libpq functions which
+ psycopg doesn't (currently) wrap, either in C or in Python using FFI
+ libraries such as `ctypes`.
+ """
+ if self._pgconn_ptr is None:
+ return None
+
+ return addressof(self._pgconn_ptr.contents) # type: ignore[attr-defined]
+
+ @property
+ def info(self) -> List["ConninfoOption"]:
+ self._ensure_pgconn()
+ opts = impl.PQconninfo(self._pgconn_ptr)
+ if not opts:
+ raise MemoryError("couldn't allocate connection info")
+ try:
+ return Conninfo._options_from_array(opts)
+ finally:
+ impl.PQconninfoFree(opts)
+
+ def reset(self) -> None:
+ self._ensure_pgconn()
+ impl.PQreset(self._pgconn_ptr)
+
+ def reset_start(self) -> None:
+ if not impl.PQresetStart(self._pgconn_ptr):
+ raise e.OperationalError("couldn't reset connection")
+
+ def reset_poll(self) -> int:
+ return self._call_int(impl.PQresetPoll)
+
+ @classmethod
+ def ping(self, conninfo: bytes) -> int:
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ return impl.PQping(conninfo)
+
+ @property
+ def db(self) -> bytes:
+ return self._call_bytes(impl.PQdb)
+
+ @property
+ def user(self) -> bytes:
+ return self._call_bytes(impl.PQuser)
+
+ @property
+ def password(self) -> bytes:
+ return self._call_bytes(impl.PQpass)
+
+ @property
+ def host(self) -> bytes:
+ return self._call_bytes(impl.PQhost)
+
+ @property
+ def hostaddr(self) -> bytes:
+ return self._call_bytes(impl.PQhostaddr)
+
+ @property
+ def port(self) -> bytes:
+ return self._call_bytes(impl.PQport)
+
+ @property
+ def tty(self) -> bytes:
+ return self._call_bytes(impl.PQtty)
+
+ @property
+ def options(self) -> bytes:
+ return self._call_bytes(impl.PQoptions)
+
+ @property
+ def status(self) -> int:
+ return PQstatus(self._pgconn_ptr)
+
+ @property
+ def transaction_status(self) -> int:
+ return impl.PQtransactionStatus(self._pgconn_ptr)
+
+ def parameter_status(self, name: bytes) -> Optional[bytes]:
+ self._ensure_pgconn()
+ return impl.PQparameterStatus(self._pgconn_ptr, name)
+
+ @property
+ def error_message(self) -> bytes:
+ return impl.PQerrorMessage(self._pgconn_ptr)
+
+ @property
+ def protocol_version(self) -> int:
+ return self._call_int(impl.PQprotocolVersion)
+
+ @property
+ def server_version(self) -> int:
+ return self._call_int(impl.PQserverVersion)
+
+ @property
+ def socket(self) -> int:
+ rv = self._call_int(impl.PQsocket)
+ if rv == -1:
+ raise e.OperationalError("the connection is lost")
+ return rv
+
+ @property
+ def backend_pid(self) -> int:
+ return self._call_int(impl.PQbackendPID)
+
+ @property
+ def needs_password(self) -> bool:
+ """True if the connection authentication method required a password,
+ but none was available.
+
+ See :pq:`PQconnectionNeedsPassword` for details.
+ """
+ return bool(impl.PQconnectionNeedsPassword(self._pgconn_ptr))
+
+ @property
+ def used_password(self) -> bool:
+ """True if the connection authentication method used a password.
+
+ See :pq:`PQconnectionUsedPassword` for details.
+ """
+ return bool(impl.PQconnectionUsedPassword(self._pgconn_ptr))
+
+ @property
+ def ssl_in_use(self) -> bool:
+ return self._call_bool(impl.PQsslInUse)
+
+ def exec_(self, command: bytes) -> "PGresult":
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQexec(self._pgconn_ptr, command)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_query(self, command: bytes) -> None:
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendQuery(self._pgconn_ptr, command):
+ raise e.OperationalError(f"sending query failed: {error_message(self)}")
+
+ def exec_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> "PGresult":
+ args = self._query_params_args(
+ command, param_values, param_types, param_formats, result_format
+ )
+ self._ensure_pgconn()
+ rv = impl.PQexecParams(*args)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_query_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ args = self._query_params_args(
+ command, param_values, param_types, param_formats, result_format
+ )
+ self._ensure_pgconn()
+ if not impl.PQsendQueryParams(*args):
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ atypes: Optional[Array[impl.Oid]]
+ if not param_types:
+ nparams = 0
+ atypes = None
+ else:
+ nparams = len(param_types)
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ self._ensure_pgconn()
+ if not impl.PQsendPrepare(self._pgconn_ptr, name, command, nparams, atypes):
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_query_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ # repurpose this function with a cheeky replacement of query with name,
+ # drop the param_types from the result
+ args = self._query_params_args(
+ name, param_values, None, param_formats, result_format
+ )
+ args = args[:3] + args[4:]
+
+ self._ensure_pgconn()
+ if not impl.PQsendQueryPrepared(*args):
+ raise e.OperationalError(
+ f"sending prepared query failed: {error_message(self)}"
+ )
+
+ def _query_params_args(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> Any:
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+
+ aparams: Optional[Array[c_char_p]]
+ alenghts: Optional[Array[c_int]]
+ if param_values:
+ nparams = len(param_values)
+ aparams = (c_char_p * nparams)(
+ *(
+ # convert bytearray/memoryview to bytes
+ b if b is None or isinstance(b, bytes) else bytes(b)
+ for b in param_values
+ )
+ )
+ alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
+ else:
+ nparams = 0
+ aparams = alenghts = None
+
+ atypes: Optional[Array[impl.Oid]]
+ if not param_types:
+ atypes = None
+ else:
+ if len(param_types) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_types"
+ % (nparams, len(param_types))
+ )
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ if not param_formats:
+ aformats = None
+ else:
+ if len(param_formats) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_formats"
+ % (nparams, len(param_formats))
+ )
+ aformats = (c_int * nparams)(*param_formats)
+
+ return (
+ self._pgconn_ptr,
+ command,
+ nparams,
+ atypes,
+ aparams,
+ alenghts,
+ aformats,
+ result_format,
+ )
+
+ def prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+
+ if not isinstance(command, bytes):
+ raise TypeError(f"'command' must be bytes, got {type(command)} instead")
+
+ if not param_types:
+ nparams = 0
+ atypes = None
+ else:
+ nparams = len(param_types)
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ self._ensure_pgconn()
+ rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def exec_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence["abc.Buffer"]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = 0,
+ ) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+
+ aparams: Optional[Array[c_char_p]]
+ alenghts: Optional[Array[c_int]]
+ if param_values:
+ nparams = len(param_values)
+ aparams = (c_char_p * nparams)(
+ *(
+ # convert bytearray/memoryview to bytes
+ b if b is None or isinstance(b, bytes) else bytes(b)
+ for b in param_values
+ )
+ )
+ alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
+ else:
+ nparams = 0
+ aparams = alenghts = None
+
+ if not param_formats:
+ aformats = None
+ else:
+ if len(param_formats) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_types"
+ % (nparams, len(param_formats))
+ )
+ aformats = (c_int * nparams)(*param_formats)
+
+ self._ensure_pgconn()
+ rv = impl.PQexecPrepared(
+ self._pgconn_ptr,
+ name,
+ nparams,
+ aparams,
+ alenghts,
+ aformats,
+ result_format,
+ )
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def describe_prepared(self, name: bytes) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQdescribePrepared(self._pgconn_ptr, name)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_describe_prepared(self, name: bytes) -> None:
+ if not isinstance(name, bytes):
+ raise TypeError(f"bytes expected, got {type(name)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendDescribePrepared(self._pgconn_ptr, name):
+ raise e.OperationalError(
+ f"sending describe prepared failed: {error_message(self)}"
+ )
+
+ def describe_portal(self, name: bytes) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQdescribePortal(self._pgconn_ptr, name)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_describe_portal(self, name: bytes) -> None:
+ if not isinstance(name, bytes):
+ raise TypeError(f"bytes expected, got {type(name)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendDescribePortal(self._pgconn_ptr, name):
+ raise e.OperationalError(
+ f"sending describe portal failed: {error_message(self)}"
+ )
+
+ def get_result(self) -> Optional["PGresult"]:
+ rv = impl.PQgetResult(self._pgconn_ptr)
+ return PGresult(rv) if rv else None
+
+ def consume_input(self) -> None:
+ if 1 != impl.PQconsumeInput(self._pgconn_ptr):
+ raise e.OperationalError(f"consuming input failed: {error_message(self)}")
+
+ def is_busy(self) -> int:
+ return impl.PQisBusy(self._pgconn_ptr)
+
+ @property
+ def nonblocking(self) -> int:
+ return impl.PQisnonblocking(self._pgconn_ptr)
+
+ @nonblocking.setter
+ def nonblocking(self, arg: int) -> None:
+ if 0 > impl.PQsetnonblocking(self._pgconn_ptr, arg):
+ raise e.OperationalError(
+ f"setting nonblocking failed: {error_message(self)}"
+ )
+
+ def flush(self) -> int:
+ # PQflush segfaults if it receives a NULL connection
+ if not self._pgconn_ptr:
+ raise e.OperationalError("flushing failed: the connection is closed")
+ rv: int = impl.PQflush(self._pgconn_ptr)
+ if rv < 0:
+ raise e.OperationalError(f"flushing failed: {error_message(self)}")
+ return rv
+
+ def set_single_row_mode(self) -> None:
+ if not impl.PQsetSingleRowMode(self._pgconn_ptr):
+ raise e.OperationalError("setting single row mode failed")
+
+ def get_cancel(self) -> "PGcancel":
+ """
+ Create an object with the information needed to cancel a command.
+
+ See :pq:`PQgetCancel` for details.
+ """
+ rv = impl.PQgetCancel(self._pgconn_ptr)
+ if not rv:
+ raise e.OperationalError("couldn't create cancel object")
+ return PGcancel(rv)
+
+ def notifies(self) -> Optional[PGnotify]:
+ ptr = impl.PQnotifies(self._pgconn_ptr)
+ if ptr:
+ c = ptr.contents
+ return PGnotify(c.relname, c.be_pid, c.extra)
+ impl.PQfreemem(ptr)
+ else:
+ return None
+
+ def put_copy_data(self, buffer: "abc.Buffer") -> int:
+ if not isinstance(buffer, bytes):
+ buffer = bytes(buffer)
+ rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))
+ if rv < 0:
+ raise e.OperationalError(f"sending copy data failed: {error_message(self)}")
+ return rv
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ rv = impl.PQputCopyEnd(self._pgconn_ptr, error)
+ if rv < 0:
+ raise e.OperationalError(f"sending copy end failed: {error_message(self)}")
+ return rv
+
+ def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
+ buffer_ptr = c_char_p()
+ nbytes = impl.PQgetCopyData(self._pgconn_ptr, byref(buffer_ptr), async_)
+ if nbytes == -2:
+ raise e.OperationalError(
+ f"receiving copy data failed: {error_message(self)}"
+ )
+ if buffer_ptr:
+ # TODO: do it without copy
+ data = string_at(buffer_ptr, nbytes)
+ impl.PQfreemem(buffer_ptr)
+ return nbytes, memoryview(data)
+ else:
+ return nbytes, memoryview(b"")
+
+ def trace(self, fileno: int) -> None:
+ """
+ Enable tracing of the client/server communication to a file stream.
+
+ See :pq:`PQtrace` for details.
+ """
+ if sys.platform != "linux":
+ raise e.NotSupportedError("currently only supported on Linux")
+ stream = impl.fdopen(fileno, b"w")
+ impl.PQtrace(self._pgconn_ptr, stream)
+
+ def set_trace_flags(self, flags: Trace) -> None:
+ """
+ Configure tracing behavior of client/server communication.
+
+ :param flags: operating mode of tracing.
+
+ See :pq:`PQsetTraceFlags` for details.
+ """
+ impl.PQsetTraceFlags(self._pgconn_ptr, flags)
+
+ def untrace(self) -> None:
+ """
+ Disable tracing, previously enabled through `trace()`.
+
+ See :pq:`PQuntrace` for details.
+ """
+ impl.PQuntrace(self._pgconn_ptr)
+
+ def encrypt_password(
+ self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
+ ) -> bytes:
+ """
+ Return the encrypted form of a PostgreSQL password.
+
+ See :pq:`PQencryptPasswordConn` for details.
+ """
+ out = impl.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, algorithm)
+ if not out:
+ raise e.OperationalError(
+ f"password encryption failed: {error_message(self)}"
+ )
+
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def make_empty_result(self, exec_status: int) -> "PGresult":
+ rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status)
+ if not rv:
+ raise MemoryError("couldn't allocate empty PGresult")
+ return PGresult(rv)
+
+ @property
+ def pipeline_status(self) -> int:
+ if version() < 140000:
+ return 0
+ return impl.PQpipelineStatus(self._pgconn_ptr)
+
+ def enter_pipeline_mode(self) -> None:
+ """Enter pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to enter the pipeline
+ mode.
+ """
+ if impl.PQenterPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError("failed to enter pipeline mode")
+
+ def exit_pipeline_mode(self) -> None:
+ """Exit pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to exit the pipeline
+ mode.
+ """
+ if impl.PQexitPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError(error_message(self))
+
+ def pipeline_sync(self) -> None:
+ """Mark a synchronization point in a pipeline.
+
+ :raises ~e.OperationalError: if the connection is not in pipeline mode
+ or if sync failed.
+ """
+ rv = impl.PQpipelineSync(self._pgconn_ptr)
+ if rv == 0:
+ raise e.OperationalError("connection not in pipeline mode")
+ if rv != 1:
+ raise e.OperationalError("failed to sync pipeline")
+
+ def send_flush_request(self) -> None:
+ """Sends a request for the server to flush its output buffer.
+
+ :raises ~e.OperationalError: if the flush request failed.
+ """
+ if impl.PQsendFlushRequest(self._pgconn_ptr) == 0:
+ raise e.OperationalError(f"flush request failed: {error_message(self)}")
+
+ def _call_bytes(
+ self, func: Callable[[impl.PGconn_struct], Optional[bytes]]
+ ) -> bytes:
+ """
+ Call one of the pgconn libpq functions returning a bytes pointer.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ rv = func(self._pgconn_ptr)
+ assert rv is not None
+ return rv
+
+ def _call_int(self, func: Callable[[impl.PGconn_struct], int]) -> int:
+ """
+ Call one of the pgconn libpq functions returning an int.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ return func(self._pgconn_ptr)
+
+ def _call_bool(self, func: Callable[[impl.PGconn_struct], int]) -> bool:
+ """
+ Call one of the pgconn libpq functions returning a logical value.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ return bool(func(self._pgconn_ptr))
+
+ def _ensure_pgconn(self) -> None:
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+
+
+class PGresult:
+ """
+ Python representation of a libpq result.
+ """
+
+ __slots__ = ("_pgresult_ptr",)
+
+ def __init__(self, pgresult_ptr: impl.PGresult_struct):
+ self._pgresult_ptr: Optional[impl.PGresult_struct] = pgresult_ptr
+
+ def __del__(self) -> None:
+ self.clear()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ status = ExecStatus(self.status)
+ return f"<{cls} [{status.name}] at 0x{id(self):x}>"
+
+ def clear(self) -> None:
+ self._pgresult_ptr, p = None, self._pgresult_ptr
+ if p:
+ PQclear(p)
+
+ @property
+ def pgresult_ptr(self) -> Optional[int]:
+ """The pointer to the underlying `!PGresult` structure, as integer.
+
+ `!None` if the result was cleared.
+
+ The value can be used to pass the structure to libpq functions which
+ psycopg doesn't (currently) wrap, either in C or in Python using FFI
+ libraries such as `ctypes`.
+ """
+ if self._pgresult_ptr is None:
+ return None
+
+ return addressof(self._pgresult_ptr.contents) # type: ignore[attr-defined]
+
+ @property
+ def status(self) -> int:
+ return impl.PQresultStatus(self._pgresult_ptr)
+
+ @property
+ def error_message(self) -> bytes:
+ return impl.PQresultErrorMessage(self._pgresult_ptr)
+
+ def error_field(self, fieldcode: int) -> Optional[bytes]:
+ return impl.PQresultErrorField(self._pgresult_ptr, fieldcode)
+
+ @property
+ def ntuples(self) -> int:
+ return impl.PQntuples(self._pgresult_ptr)
+
+ @property
+ def nfields(self) -> int:
+ return impl.PQnfields(self._pgresult_ptr)
+
+ def fname(self, column_number: int) -> Optional[bytes]:
+ return impl.PQfname(self._pgresult_ptr, column_number)
+
+ def ftable(self, column_number: int) -> int:
+ return impl.PQftable(self._pgresult_ptr, column_number)
+
+ def ftablecol(self, column_number: int) -> int:
+ return impl.PQftablecol(self._pgresult_ptr, column_number)
+
+ def fformat(self, column_number: int) -> int:
+ return impl.PQfformat(self._pgresult_ptr, column_number)
+
+ def ftype(self, column_number: int) -> int:
+ return impl.PQftype(self._pgresult_ptr, column_number)
+
+ def fmod(self, column_number: int) -> int:
+ return impl.PQfmod(self._pgresult_ptr, column_number)
+
+ def fsize(self, column_number: int) -> int:
+ return impl.PQfsize(self._pgresult_ptr, column_number)
+
+ @property
+ def binary_tuples(self) -> int:
+ return impl.PQbinaryTuples(self._pgresult_ptr)
+
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
+ length: int = impl.PQgetlength(self._pgresult_ptr, row_number, column_number)
+ if length:
+ v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number)
+ return string_at(v, length)
+ else:
+ if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number):
+ return None
+ else:
+ return b""
+
+ @property
+ def nparams(self) -> int:
+ return impl.PQnparams(self._pgresult_ptr)
+
+ def param_type(self, param_number: int) -> int:
+ return impl.PQparamtype(self._pgresult_ptr, param_number)
+
+ @property
+ def command_status(self) -> Optional[bytes]:
+ return impl.PQcmdStatus(self._pgresult_ptr)
+
+ @property
+ def command_tuples(self) -> Optional[int]:
+ rv = impl.PQcmdTuples(self._pgresult_ptr)
+ return int(rv) if rv else None
+
+ @property
+ def oid_value(self) -> int:
+ return impl.PQoidValue(self._pgresult_ptr)
+
+ def set_attributes(self, descriptions: List[PGresAttDesc]) -> None:
+ structs = [
+ impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore
+ ]
+ array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore
+ rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array)
+ if rv == 0:
+ raise e.OperationalError("PQsetResultAttrs failed")
+
+
+class PGcancel:
+ """
+ Token to cancel the current operation on a connection.
+
+ Created by `PGconn.get_cancel()`.
+ """
+
+ __slots__ = ("pgcancel_ptr",)
+
+ def __init__(self, pgcancel_ptr: impl.PGcancel_struct):
+ self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr
+
+ def __del__(self) -> None:
+ self.free()
+
+ def free(self) -> None:
+ """
+ Free the data structure created by :pq:`PQgetCancel()`.
+
+ Automatically invoked by `!__del__()`.
+
+ See :pq:`PQfreeCancel()` for details.
+ """
+ self.pgcancel_ptr, p = None, self.pgcancel_ptr
+ if p:
+ PQfreeCancel(p)
+
+ def cancel(self) -> None:
+ """Requests that the server abandon processing of the current command.
+
+ See :pq:`PQcancel()` for details.
+ """
+ buf = create_string_buffer(256)
+ res = impl.PQcancel(
+ self.pgcancel_ptr,
+ byref(buf), # type: ignore[arg-type]
+ len(buf),
+ )
+ if not res:
+ raise e.OperationalError(
+ f"cancel failed: {buf.value.decode('utf8', 'ignore')}"
+ )
+
+
+class Conninfo:
+ """
+ Utility object to manipulate connection strings.
+ """
+
+ @classmethod
+ def get_defaults(cls) -> List[ConninfoOption]:
+ opts = impl.PQconndefaults()
+ if not opts:
+ raise MemoryError("couldn't allocate connection defaults")
+ try:
+ return cls._options_from_array(opts)
+ finally:
+ impl.PQconninfoFree(opts)
+
+ @classmethod
+ def parse(cls, conninfo: bytes) -> List[ConninfoOption]:
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ errmsg = c_char_p()
+ rv = impl.PQconninfoParse(conninfo, byref(errmsg)) # type: ignore[arg-type]
+ if not rv:
+ if not errmsg:
+ raise MemoryError("couldn't allocate on conninfo parse")
+ else:
+ exc = e.OperationalError(
+ (errmsg.value or b"").decode("utf8", "replace")
+ )
+ impl.PQfreemem(errmsg)
+ raise exc
+
+ try:
+ return cls._options_from_array(rv)
+ finally:
+ impl.PQconninfoFree(rv)
+
+ @classmethod
+ def _options_from_array(
+ cls, opts: Sequence[impl.PQconninfoOption_struct]
+ ) -> List[ConninfoOption]:
+ rv = []
+ skws = "keyword envvar compiled val label dispchar".split()
+ for opt in opts:
+ if not opt.keyword:
+ break
+ d = {kw: getattr(opt, kw) for kw in skws}
+ d["dispsize"] = opt.dispsize
+ rv.append(ConninfoOption(**d))
+
+ return rv
+
+
+class Escaping:
+ """
+ Utility object to escape strings for SQL interpolation.
+ """
+
+ def __init__(self, conn: Optional[PGconn] = None):
+ self.conn = conn
+
+ def escape_literal(self, data: "abc.Buffer") -> bytes:
+ if not self.conn:
+ raise e.OperationalError("escape_literal failed: no connection provided")
+
+ self.conn._ensure_pgconn()
+ # TODO: might be done without copy (however C does that)
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data))
+ if not out:
+ raise e.OperationalError(
+ f"escape_literal failed: {error_message(self.conn)} bytes"
+ )
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def escape_identifier(self, data: "abc.Buffer") -> bytes:
+ if not self.conn:
+ raise e.OperationalError("escape_identifier failed: no connection provided")
+
+ self.conn._ensure_pgconn()
+
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data))
+ if not out:
+ raise e.OperationalError(
+ f"escape_identifier failed: {error_message(self.conn)} bytes"
+ )
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def escape_string(self, data: "abc.Buffer") -> bytes:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+
+ if self.conn:
+ self.conn._ensure_pgconn()
+ error = c_int()
+ out = create_string_buffer(len(data) * 2 + 1)
+ impl.PQescapeStringConn(
+ self.conn._pgconn_ptr,
+ byref(out), # type: ignore[arg-type]
+ data,
+ len(data),
+ byref(error), # type: ignore[arg-type]
+ )
+
+ if error:
+ raise e.OperationalError(
+ f"escape_string failed: {error_message(self.conn)} bytes"
+ )
+
+ else:
+ out = create_string_buffer(len(data) * 2 + 1)
+ impl.PQescapeString(
+ byref(out), # type: ignore[arg-type]
+ data,
+ len(data),
+ )
+
+ return out.value
+
+ def escape_bytea(self, data: "abc.Buffer") -> bytes:
+ len_out = c_size_t()
+ # TODO: might be able to do without a copy but it's a mess.
+ # the C library does it better anyway, so maybe not worth optimising
+ # https://mail.python.org/pipermail/python-dev/2012-September/121780.html
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ if self.conn:
+ self.conn._ensure_pgconn()
+ out = impl.PQescapeByteaConn(
+ self.conn._pgconn_ptr,
+ data,
+ len(data),
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ else:
+ out = impl.PQescapeBytea(
+ data,
+ len(data),
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ if not out:
+ raise MemoryError(
+ f"couldn't allocate for escape_bytea of {len(data)} bytes"
+ )
+
+ rv = string_at(out, len_out.value - 1) # out includes final 0
+ impl.PQfreemem(out)
+ return rv
+
+ def unescape_bytea(self, data: "abc.Buffer") -> bytes:
+ # not needed, but let's keep it symmetric with the escaping:
+ # if a connection is passed in, it must be valid.
+ if self.conn:
+ self.conn._ensure_pgconn()
+
+ len_out = c_size_t()
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQunescapeBytea(
+ data,
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ if not out:
+ raise MemoryError(
+ f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+ )
+
+ rv = string_at(out, len_out.value)
+ impl.PQfreemem(out)
+ return rv
+
+
+# importing the ssl module sets up Python's libcrypto callbacks
+import ssl # noqa
+
+# disable libcrypto setup in libpq, so it won't stomp on the callbacks
+# that have already been set up
+impl.PQinitOpenSSL(1, 0)
+
+__build_version__ = version()
diff --git a/psycopg/psycopg/py.typed b/psycopg/psycopg/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/psycopg/psycopg/py.typed
diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py
new file mode 100644
index 0000000..cb28b57
--- /dev/null
+++ b/psycopg/psycopg/rows.py
@@ -0,0 +1,256 @@
+"""
+psycopg row factories
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import functools
+from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn
+from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar
+from collections import namedtuple
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from ._compat import Protocol
+from ._encodings import _as_python_identifier
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor, Cursor
+ from .cursor_async import AsyncCursor
+ from psycopg.pq.abc import PGresult
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+
+T = TypeVar("T", covariant=True)
+
+# Row factories
+
+Row = TypeVar("Row", covariant=True)
+
+
+class RowMaker(Protocol[Row]):
+ """
+ Callable protocol taking a sequence of value and returning an object.
+
+ The sequence of value is what is returned from a database query, already
+ adapted to the right Python types. The return value is the object that your
+ program would like to receive: by default (`tuple_row()`) it is a simple
+ tuple, but it may be any type of object.
+
+ Typically, `!RowMaker` functions are returned by `RowFactory`.
+ """
+
+ def __call__(self, __values: Sequence[Any]) -> Row:
+ ...
+
+
+class RowFactory(Protocol[Row]):
+ """
+ Callable protocol taking a `~psycopg.Cursor` and returning a `RowMaker`.
+
+ A `!RowFactory` is typically called when a `!Cursor` receives a result.
+ This way it can inspect the cursor state (for instance the
+ `~psycopg.Cursor.description` attribute) and help a `!RowMaker` to create
+ a complete object.
+
+ For instance the `dict_row()` `!RowFactory` uses the names of the column to
+ define the dictionary key and returns a `!RowMaker` function which would
+ use the values to create a dictionary for each record.
+ """
+
+ def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]:
+ ...
+
+
+class AsyncRowFactory(Protocol[Row]):
+ """
+ Like `RowFactory`, taking an async cursor as argument.
+ """
+
+ def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]:
+ ...
+
+
+class BaseRowFactory(Protocol[Row]):
+ """
+ Like `RowFactory`, taking either type of cursor as argument.
+ """
+
+ def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]:
+ ...
+
+
+TupleRow: TypeAlias = Tuple[Any, ...]
+"""
+An alias for the type returned by `tuple_row()` (i.e. a tuple of any content).
+"""
+
+
+DictRow: TypeAlias = Dict[str, Any]
+"""
+An alias for the type returned by `dict_row()`
+
+A `!DictRow` is a dictionary with keys as string and any value returned by the
+database.
+"""
+
+
+def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]":
+ r"""Row factory to represent rows as simple tuples.
+
+ This is the default factory, used when `~psycopg.Connection.connect()` or
+ `~psycopg.Connection.cursor()` are called without a `!row_factory`
+ parameter.
+
+ """
+ # Implementation detail: make sure this is the tuple type itself, not an
+ # equivalent function, because the C code fast-paths on it.
+ return tuple
+
+
+def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]":
+ """Row factory to represent rows as dictionaries.
+
+ The dictionary keys are taken from the column names of the returned columns.
+ """
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def dict_row_(values: Sequence[Any]) -> Dict[str, Any]:
+ # https://github.com/python/mypy/issues/2608
+ return dict(zip(names, values)) # type: ignore[arg-type]
+
+ return dict_row_
+
+
+def namedtuple_row(
+ cursor: "BaseCursor[Any, Any]",
+) -> "RowMaker[NamedTuple]":
+ """Row factory to represent rows as `~collections.namedtuple`.
+
+ The field names are taken from the column names of the returned columns,
+ with some mangling to deal with invalid names.
+ """
+ res = cursor.pgresult
+ if not res:
+ return no_result
+
+ nfields = _get_nfields(res)
+ if nfields is None:
+ return no_result
+
+ nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
+ return nt._make
+
+
+@functools.lru_cache(512)
+def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]:
+ snames = tuple(_as_python_identifier(n.decode(enc)) for n in names)
+ return namedtuple("Row", snames) # type: ignore[return-value]
+
+
+def class_row(cls: Type[T]) -> BaseRowFactory[T]:
+ r"""Generate a row factory to represent rows as instances of the class `!cls`.
+
+ The class must support every output column name as a keyword parameter.
+
+ :param cls: The class to return for each row. It must support the fields
+ returned by the query as keyword arguments.
+ :rtype: `!Callable[[Cursor],` `RowMaker`\[~T]]
+ """
+
+ def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]":
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def class_row__(values: Sequence[Any]) -> T:
+ return cls(**dict(zip(names, values))) # type: ignore[arg-type]
+
+ return class_row__
+
+ return class_row_
+
+
+def args_row(func: Callable[..., T]) -> BaseRowFactory[T]:
+ """Generate a row factory calling `!func` with positional parameters for every row.
+
+ :param func: The function to call for each row. It must support the fields
+ returned by the query as positional arguments.
+ """
+
+ def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]":
+ def args_row__(values: Sequence[Any]) -> T:
+ return func(*values)
+
+ return args_row__
+
+ return args_row_
+
+
+def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
+ """Generate a row factory calling `!func` with keyword parameters for every row.
+
+ :param func: The function to call for each row. It must support the fields
+ returned by the query as keyword arguments.
+ """
+
+ def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]":
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def kwargs_row__(values: Sequence[Any]) -> T:
+ return func(**dict(zip(names, values))) # type: ignore[arg-type]
+
+ return kwargs_row__
+
+ return kwargs_row_
+
+
+def no_result(values: Sequence[Any]) -> NoReturn:
+ """A `RowMaker` that always fail.
+
+ It can be used as return value for a `RowFactory` called with no result.
+ Note that the `!RowFactory` *will* be called with no result, but the
+ resulting `!RowMaker` never should.
+ """
+ raise e.InterfaceError("the cursor doesn't have a result")
+
+
+def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]:
+ res = cursor.pgresult
+ if not res:
+ return None
+
+ nfields = _get_nfields(res)
+ if nfields is None:
+ return None
+
+ enc = cursor._encoding
+ return [
+ res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr]
+ ]
+
+
+def _get_nfields(res: "PGresult") -> Optional[int]:
+ """
+ Return the number of columns in a result, if it returns tuples else None
+
+ Take into account the special case of results with zero columns.
+ """
+ nfields = res.nfields
+
+ if (
+ res.status == TUPLES_OK
+ or res.status == SINGLE_TUPLE
+ # "describe" in named cursors
+ or (res.status == COMMAND_OK and nfields)
+ ):
+ return nfields
+ else:
+ return None
diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py
new file mode 100644
index 0000000..b890d77
--- /dev/null
+++ b/psycopg/psycopg/server_cursor.py
@@ -0,0 +1,479 @@
+"""
+psycopg server-side cursor objects.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, AsyncIterator, List, Iterable, Iterator
+from typing import Optional, TypeVar, TYPE_CHECKING, overload
+from warnings import warn
+
+from . import pq
+from . import sql
+from . import errors as e
+from .abc import ConnectionType, Query, Params, PQGen
+from .rows import Row, RowFactory, AsyncRowFactory
+from .cursor import BaseCursor, Cursor
+from .generators import execute
+from .cursor_async import AsyncCursor
+
+if TYPE_CHECKING:
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+
+DEFAULT_ITERSIZE = 100
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+
+class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
+ """Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
+
+ __slots__ = "_name _scrollable _withhold _described itersize _format".split()
+
+ def __init__(
+ self,
+ name: str,
+ scrollable: Optional[bool],
+ withhold: bool,
+ ):
+ self._name = name
+ self._scrollable = scrollable
+ self._withhold = withhold
+ self._described = False
+ self.itersize: int = DEFAULT_ITERSIZE
+ self._format = TEXT
+
+ def __repr__(self) -> str:
+ # Insert the name as the second word
+ parts = super().__repr__().split(None, 1)
+ parts.insert(1, f"{self._name!r}")
+ return " ".join(parts)
+
+ @property
+ def name(self) -> str:
+ """The name of the cursor."""
+ return self._name
+
+ @property
+ def scrollable(self) -> Optional[bool]:
+ """
+ Whether the cursor is scrollable or not.
+
+ If `!None` leave the choice to the server. Use `!True` if you want to
+ use `scroll()` on the cursor.
+ """
+ return self._scrollable
+
+ @property
+ def withhold(self) -> bool:
+ """
+ If the cursor can be used after the creating transaction has committed.
+ """
+ return self._withhold
+
+ @property
+ def rownumber(self) -> Optional[int]:
+ """Index of the next row to fetch in the current result.
+
+ `!None` if there is no result to fetch.
+ """
+ res = self.pgresult
+ # command_status is empty if the result comes from
+ # describe_portal, which means that we have just executed the DECLARE,
+ # so we can assume we are at the first row.
+ tuples = res and (res.status == TUPLES_OK or res.command_status == b"")
+ return self._pos if tuples else None
+
+ def _declare_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator implementing `ServerCursor.execute()`."""
+
+ query = self._make_declare_statement(query)
+
+ # If the cursor is being reused, the previous one must be closed.
+ if self._described:
+ yield from self._close_gen()
+ self._described = False
+
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ self._execute_send(pgq, force_extended=True)
+ results = yield from execute(self._conn.pgconn)
+ if results[-1].status != COMMAND_OK:
+ self._raise_for_result(results[-1])
+
+ # Set the format, which will be used by describe and fetch operations
+ if binary is None:
+ self._format = self.format
+ else:
+ self._format = BINARY if binary else TEXT
+
+ # The above result only returned COMMAND_OK. Get the cursor shape
+ yield from self._describe_gen()
+
+ def _describe_gen(self) -> PQGen[None]:
+ self._pgconn.send_describe_portal(self._name.encode(self._encoding))
+ results = yield from execute(self._pgconn)
+ self._check_results(results)
+ self._results = results
+ self._select_current_result(0, format=self._format)
+ self._described = True
+
+ def _close_gen(self) -> PQGen[None]:
+ ts = self._conn.pgconn.transaction_status
+
+ # if the connection is not in a sane state, don't even try
+ if ts != IDLE and ts != INTRANS:
+ return
+
+ # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
+ if not self._withhold and ts == IDLE:
+ return
+
+ # if we didn't declare the cursor ourselves we still have to close it
+ # but we must make sure it exists.
+ if not self._described:
+ query = sql.SQL(
+ "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
+ ).format(sql.Literal(self._name))
+ res = yield from self._conn._exec_command(query)
+ # pipeline mode otherwise, unsupported here.
+ assert res is not None
+ if res.ntuples == 0:
+ return
+
+ query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name))
+ yield from self._conn._exec_command(query)
+
+ def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]:
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+ # If we are stealing the cursor, make sure we know its shape
+ if not self._described:
+ yield from self._start_query()
+ yield from self._describe_gen()
+
+ query = sql.SQL("FETCH FORWARD {} FROM {}").format(
+ sql.SQL("ALL") if num is None else sql.Literal(num),
+ sql.Identifier(self._name),
+ )
+ res = yield from self._conn._exec_command(query, result_format=self._format)
+ # pipeline mode otherwise, unsupported here.
+ assert res is not None
+
+ self.pgresult = res
+ self._tx.set_pgresult(res, set_loaders=False)
+ return self._tx.load_rows(0, res.ntuples, self._make_row)
+
+ def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
+ if mode not in ("relative", "absolute"):
+ raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
+ query = sql.SQL("MOVE{} {} FROM {}").format(
+ sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
+ sql.Literal(value),
+ sql.Identifier(self._name),
+ )
+ yield from self._conn._exec_command(query)
+
+ def _make_declare_statement(self, query: Query) -> sql.Composed:
+
+ if isinstance(query, bytes):
+ query = query.decode(self._encoding)
+ if not isinstance(query, sql.Composable):
+ query = sql.SQL(query)
+
+ parts = [
+ sql.SQL("DECLARE"),
+ sql.Identifier(self._name),
+ ]
+ if self._scrollable is not None:
+ parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
+ parts.append(sql.SQL("CURSOR"))
+ if self._withhold:
+ parts.append(sql.SQL("WITH HOLD"))
+ parts.append(sql.SQL("FOR"))
+ parts.append(query)
+
+ return sql.SQL(" ").join(parts)
+
+
+class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="ServerCursor[Any]")
+
+ @overload
+ def __init__(
+ self: "ServerCursor[Row]",
+ connection: "Connection[Row]",
+ name: str,
+ *,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "ServerCursor[Row]",
+ connection: "Connection[Any]",
+ name: str,
+ *,
+ row_factory: RowFactory[Row],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "Connection[Any]",
+ name: str,
+ *,
+ row_factory: Optional[RowFactory[Row]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ Cursor.__init__(
+ self, connection, row_factory=row_factory or connection.row_factory
+ )
+ ServerCursorMixin.__init__(self, name, scrollable, withhold)
+
+ def __del__(self) -> None:
+ if not self.closed:
+ warn(
+ f"the server-side cursor {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ with self._conn.lock:
+ if self.closed:
+ return
+ if not self._conn.closed:
+ self._conn.wait(self._close_gen())
+ super().close()
+
+ def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> _Self:
+ """
+ Open a cursor to execute a query to the database.
+ """
+ if kwargs:
+ raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
+ if self._pgconn.pipeline_status:
+ raise e.NotSupportedError(
+ "server-side cursors not supported in pipeline mode"
+ )
+
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._declare_gen(query, params, binary))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ return self
+
+ def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = True,
+ ) -> None:
+ """Method not implemented for server-side cursors."""
+ raise e.NotSupportedError("executemany not supported on server-side cursors")
+
+ def fetchone(self) -> Optional[Row]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(1))
+ if recs:
+ self._pos += 1
+ return recs[0]
+ else:
+ return None
+
+ def fetchmany(self, size: int = 0) -> List[Row]:
+ if not size:
+ size = self.arraysize
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(size))
+ self._pos += len(recs)
+ return recs
+
+ def fetchall(self) -> List[Row]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(None))
+ self._pos += len(recs)
+ return recs
+
+ def __iter__(self) -> Iterator[Row]:
+ while True:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(self.itersize))
+ for rec in recs:
+ self._pos += 1
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
+ def scroll(self, value: int, mode: str = "relative") -> None:
+ with self._conn.lock:
+ self._conn.wait(self._scroll_gen(value, mode))
+ # Postgres doesn't have a reliable way to report a cursor out of bound
+ if mode == "relative":
+ self._pos += value
+ else:
+ self._pos = value
+
+
+class AsyncServerCursor(
+ ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
+):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="AsyncServerCursor[Any]")
+
+ @overload
+ def __init__(
+ self: "AsyncServerCursor[Row]",
+ connection: "AsyncConnection[Row]",
+ name: str,
+ *,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "AsyncServerCursor[Row]",
+ connection: "AsyncConnection[Any]",
+ name: str,
+ *,
+ row_factory: AsyncRowFactory[Row],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "AsyncConnection[Any]",
+ name: str,
+ *,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ AsyncCursor.__init__(
+ self, connection, row_factory=row_factory or connection.row_factory
+ )
+ ServerCursorMixin.__init__(self, name, scrollable, withhold)
+
+ def __del__(self) -> None:
+ if not self.closed:
+ warn(
+ f"the server-side cursor {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ async def close(self) -> None:
+ async with self._conn.lock:
+ if self.closed:
+ return
+ if not self._conn.closed:
+ await self._conn.wait(self._close_gen())
+ await super().close()
+
+ async def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> _Self:
+ if kwargs:
+ raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
+ if self._pgconn.pipeline_status:
+ raise e.NotSupportedError(
+ "server-side cursors not supported in pipeline mode"
+ )
+
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._declare_gen(query, params, binary))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ return self
+
+ async def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = True,
+ ) -> None:
+ raise e.NotSupportedError("executemany not supported on server-side cursors")
+
+ async def fetchone(self) -> Optional[Row]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(1))
+ if recs:
+ self._pos += 1
+ return recs[0]
+ else:
+ return None
+
+ async def fetchmany(self, size: int = 0) -> List[Row]:
+ if not size:
+ size = self.arraysize
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(size))
+ self._pos += len(recs)
+ return recs
+
+ async def fetchall(self) -> List[Row]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(None))
+ self._pos += len(recs)
+ return recs
+
+ async def __aiter__(self) -> AsyncIterator[Row]:
+ while True:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(self.itersize))
+ for rec in recs:
+ self._pos += 1
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
+ async def scroll(self, value: int, mode: str = "relative") -> None:
+ async with self._conn.lock:
+ await self._conn.wait(self._scroll_gen(value, mode))
diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py
new file mode 100644
index 0000000..099a01c
--- /dev/null
+++ b/psycopg/psycopg/sql.py
@@ -0,0 +1,467 @@
+"""
+SQL composition utility module
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import codecs
+import string
+from abc import ABC, abstractmethod
+from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
+
+from .pq import Escaping
+from .abc import AdaptContext
+from .adapt import Transformer, PyFormat
+from ._compat import LiteralString
+from ._encodings import conn_encoding
+
+
+def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
+ """
+ Adapt a Python object to a quoted SQL string.
+
+ Use this function only if you absolutely want to convert a Python string to
+ an SQL quoted literal to use e.g. to generate batch SQL and you won't have
+ a connection available when you will need to use it.
+
+ This function is relatively inefficient, because it doesn't cache the
+ adaptation rules. If you pass a `!context` you can adapt the adaptation
+ rules used, otherwise only global rules are used.
+
+ """
+ return Literal(obj).as_string(context)
+
+
+class Composable(ABC):
+ """
+ Abstract base class for objects that can be used to compose an SQL string.
+
+ `!Composable` objects can be passed directly to
+ `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
+ `~psycopg.Cursor.copy()` in place of the query string.
+
+ `!Composable` objects can be joined using the ``+`` operator: the result
+ will be a `Composed` instance containing the objects joined. The operator
+ ``*`` is also supported with an integer argument: the result is a
+ `!Composed` instance containing the left argument repeated as many times as
+ requested.
+ """
+
+ def __init__(self, obj: Any):
+ self._obj = obj
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self._obj!r})"
+
+ @abstractmethod
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ """
+ Return the value of the object as bytes.
+
+ :param context: the context to evaluate the object into.
+ :type context: `connection` or `cursor`
+
+ The method is automatically invoked by `~psycopg.Cursor.execute()`,
+ `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
+ `!Composable` is passed instead of the query string.
+
+ """
+ raise NotImplementedError
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ """
+ Return the value of the object as string.
+
+ :param context: the context to evaluate the string into.
+ :type context: `connection` or `cursor`
+
+ """
+ conn = context.connection if context else None
+ enc = conn_encoding(conn)
+ b = self.as_bytes(context)
+ if isinstance(b, bytes):
+ return b.decode(enc)
+ else:
+ # buffer object
+ return codecs.lookup(enc).decode(b)[0]
+
+ def __add__(self, other: "Composable") -> "Composed":
+ if isinstance(other, Composed):
+ return Composed([self]) + other
+ if isinstance(other, Composable):
+ return Composed([self]) + Composed([other])
+ else:
+ return NotImplemented
+
+ def __mul__(self, n: int) -> "Composed":
+ return Composed([self] * n)
+
+ def __eq__(self, other: Any) -> bool:
+ return type(self) is type(other) and self._obj == other._obj
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+
+class Composed(Composable):
+ """
+ A `Composable` object made of a sequence of `!Composable`.
+
+ The object is usually created using `!Composable` operators and methods.
+ However it is possible to create a `!Composed` directly specifying a
+ sequence of objects as arguments: if they are not `!Composable` they will
+ be wrapped in a `Literal`.
+
+ Example::
+
+ >>> comp = sql.Composed(
+ ... [sql.SQL("INSERT INTO "), sql.Identifier("table")])
+ >>> print(comp.as_string(conn))
+ INSERT INTO "table"
+
+ `!Composed` objects are iterable (so they can be used in `SQL.join` for
+ instance).
+ """
+
+ _obj: List[Composable]
+
+ def __init__(self, seq: Sequence[Any]):
+ seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
+ super().__init__(seq)
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ return b"".join(obj.as_bytes(context) for obj in self._obj)
+
+ def __iter__(self) -> Iterator[Composable]:
+ return iter(self._obj)
+
+ def __add__(self, other: Composable) -> "Composed":
+ if isinstance(other, Composed):
+ return Composed(self._obj + other._obj)
+ if isinstance(other, Composable):
+ return Composed(self._obj + [other])
+ else:
+ return NotImplemented
+
+ def join(self, joiner: Union["SQL", LiteralString]) -> "Composed":
+ """
+ Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.
+
+ The `!joiner` must be a `SQL` or a string which will be interpreted as
+ an `SQL`.
+
+ Example::
+
+ >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed
+ >>> print(fields.join(', ').as_string(conn))
+ "foo", "bar"
+
+ """
+ if isinstance(joiner, str):
+ joiner = SQL(joiner)
+ elif not isinstance(joiner, SQL):
+ raise TypeError(
+ "Composed.join() argument must be strings or SQL,"
+ f" got {joiner!r} instead"
+ )
+
+ return joiner.join(self._obj)
+
+
+class SQL(Composable):
+ """
+ A `Composable` representing a snippet of SQL statement.
+
+ `!SQL` exposes `join()` and `format()` methods useful to create a template
+ where to merge variable parts of a query (for instance field or table
+ names).
+
+ The `!obj` string doesn't undergo any form of escaping, so it is not
+ suitable to represent variable identifiers or values: you should only use
+ it to pass constant strings representing templates or snippets of SQL
+ statements; use other objects such as `Identifier` or `Literal` to
+ represent variable parts.
+
+ Example::
+
+ >>> query = sql.SQL("SELECT {0} FROM {1}").format(
+ ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
+ ... sql.Identifier('table'))
+ >>> print(query.as_string(conn))
+ SELECT "foo", "bar" FROM "table"
+ """
+
+ _obj: LiteralString
+ _formatter = string.Formatter()
+
+ def __init__(self, obj: LiteralString):
+ super().__init__(obj)
+ if not isinstance(obj, str):
+ raise TypeError(f"SQL values must be strings, got {obj!r} instead")
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ return self._obj
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ enc = "utf-8"
+ if context:
+ enc = conn_encoding(context.connection)
+ return self._obj.encode(enc)
+
+ def format(self, *args: Any, **kwargs: Any) -> Composed:
+ """
+ Merge `Composable` objects into a template.
+
+ :param args: parameters to replace to numbered (``{0}``, ``{1}``) or
+ auto-numbered (``{}``) placeholders
+ :param kwargs: parameters to replace to named (``{name}``) placeholders
+ :return: the union of the `!SQL` string with placeholders replaced
+ :rtype: `Composed`
+
+ The method is similar to the Python `str.format()` method: the string
+ template supports auto-numbered (``{}``), numbered (``{0}``,
+ ``{1}``...), and named placeholders (``{name}``), with positional
+ arguments replacing the numbered placeholders and keywords replacing
+ the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
+ are not supported.
+
+ If a `!Composable` objects is passed to the template it will be merged
+ according to its `as_string()` method. If any other Python object is
+ passed, it will be wrapped in a `Literal` object and so escaped
+ according to SQL rules.
+
+ Example::
+
+ >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
+ ... .format(sql.Identifier('people'), sql.Identifier('id'))
+ ... .as_string(conn))
+ SELECT * FROM "people" WHERE "id" = %s
+
+ >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
+ ... .format(tbl=sql.Identifier('people'), name="O'Rourke"))
+ ... .as_string(conn))
+ SELECT * FROM "people" WHERE name = 'O''Rourke'
+
+ """
+ rv: List[Composable] = []
+ autonum: Optional[int] = 0
+ # TODO: this is probably not the right way to whitelist pre
+ # pyre complains. Will wait for mypy to complain too to fix.
+ pre: LiteralString
+ for pre, name, spec, conv in self._formatter.parse(self._obj):
+ if spec:
+ raise ValueError("no format specification supported by SQL")
+ if conv:
+ raise ValueError("no format conversion supported by SQL")
+ if pre:
+ rv.append(SQL(pre))
+
+ if name is None:
+ continue
+
+ if name.isdigit():
+ if autonum:
+ raise ValueError(
+ "cannot switch from automatic field numbering to manual"
+ )
+ rv.append(args[int(name)])
+ autonum = None
+
+ elif not name:
+ if autonum is None:
+ raise ValueError(
+ "cannot switch from manual field numbering to automatic"
+ )
+ rv.append(args[autonum])
+ autonum += 1
+
+ else:
+ rv.append(kwargs[name])
+
+ return Composed(rv)
+
+ def join(self, seq: Iterable[Composable]) -> Composed:
+ """
+ Join a sequence of `Composable`.
+
+ :param seq: the elements to join.
+ :type seq: iterable of `!Composable`
+
+ Use the `!SQL` object's string to separate the elements in `!seq`.
+ Note that `Composed` objects are iterable too, so they can be used as
+ argument for this method.
+
+ Example::
+
+ >>> snip = sql.SQL(', ').join(
+ ... sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
+ >>> print(snip.as_string(conn))
+ "foo", "bar", "baz"
+ """
+ rv = []
+ it = iter(seq)
+ try:
+ rv.append(next(it))
+ except StopIteration:
+ pass
+ else:
+ for i in it:
+ rv.append(self)
+ rv.append(i)
+
+ return Composed(rv)
+
+
+class Identifier(Composable):
+ """
+ A `Composable` representing an SQL identifier or a dot-separated sequence.
+
+ Identifiers usually represent names of database objects, such as tables or
+ fields. PostgreSQL identifiers follow `different rules`__ than SQL string
+ literals for escaping (e.g. they use double quotes instead of single).
+
+ .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
+ SQL-SYNTAX-IDENTIFIERS
+
+ Example::
+
+ >>> t1 = sql.Identifier("foo")
+ >>> t2 = sql.Identifier("ba'r")
+ >>> t3 = sql.Identifier('ba"z')
+ >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
+ "foo", "ba'r", "ba""z"
+
+ Multiple strings can be passed to the object to represent a qualified name,
+ i.e. a dot-separated sequence of identifiers.
+
+ Example::
+
+ >>> query = sql.SQL("SELECT {} FROM {}").format(
+ ... sql.Identifier("table", "field"),
+ ... sql.Identifier("schema", "table"))
+ >>> print(query.as_string(conn))
+ SELECT "table"."field" FROM "schema"."table"
+
+ """
+
+ _obj: Sequence[str]
+
+ def __init__(self, *strings: str):
+ # init super() now to make the __repr__ not explode in case of error
+ super().__init__(strings)
+
+ if not strings:
+ raise TypeError("Identifier cannot be empty")
+
+ for s in strings:
+ if not isinstance(s, str):
+ raise TypeError(
+ f"SQL identifier parts must be strings, got {s!r} instead"
+ )
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ conn = context.connection if context else None
+ if not conn:
+ raise ValueError("a connection is necessary for Identifier")
+ esc = Escaping(conn.pgconn)
+ enc = conn_encoding(conn)
+ escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+ return b".".join(escs)
+
+
+class Literal(Composable):
+ """
+ A `Composable` representing an SQL value to include in a query.
+
+ Usually you will want to include placeholders in the query and pass values
+ as `~cursor.execute()` arguments. If however you really really need to
+ include a literal value in the query you can use this object.
+
+ The string returned by `!as_string()` follows the normal :ref:`adaptation
+ rules <types-adaptation>` for Python objects.
+
+ Example::
+
+ >>> s1 = sql.Literal("fo'o")
+ >>> s2 = sql.Literal(42)
+ >>> s3 = sql.Literal(date(2000, 1, 1))
+ >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
+ 'fo''o', 42, '2000-01-01'::date
+
+ """
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ tx = Transformer.from_context(context)
+ return tx.as_literal(self._obj)
+
+
+class Placeholder(Composable):
+ """A `Composable` representing a placeholder for query parameters.
+
+ If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
+ ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
+ ``%b``).
+
+ The object is useful to generate SQL queries with a variable number of
+ arguments.
+
+ Examples::
+
+ >>> names = ['foo', 'bar', 'baz']
+
+ >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(sql.Placeholder() * len(names)))
+ >>> print(q1.as_string(conn))
+ INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)
+
+ >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(map(sql.Placeholder, names)))
+ >>> print(q2.as_string(conn))
+ INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)
+
+ """
+
+ def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO):
+ super().__init__(name)
+ if not isinstance(name, str):
+ raise TypeError(f"expected string as name, got {name!r}")
+
+ if ")" in name:
+ raise ValueError(f"invalid name: {name!r}")
+
+ if type(format) is str:
+ format = PyFormat(format)
+ if not isinstance(format, PyFormat):
+ raise TypeError(
+ f"expected PyFormat as format, got {type(format).__name__!r}"
+ )
+
+ self._format: PyFormat = format
+
+ def __repr__(self) -> str:
+ parts = []
+ if self._obj:
+ parts.append(repr(self._obj))
+ if self._format is not PyFormat.AUTO:
+ parts.append(f"format={self._format.name}")
+
+ return f"{self.__class__.__name__}({', '.join(parts)})"
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ code = self._format.value
+ return f"%({self._obj}){code}" if self._obj else f"%{code}"
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ conn = context.connection if context else None
+ enc = conn_encoding(conn)
+ return self.as_string(context).encode(enc)
+
+
+# Literals
+NULL = SQL("NULL")
+DEFAULT = SQL("DEFAULT")
diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py
new file mode 100644
index 0000000..e13486e
--- /dev/null
+++ b/psycopg/psycopg/transaction.py
@@ -0,0 +1,290 @@
+"""
+Transaction context managers returned by Connection.transaction()
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from types import TracebackType
+from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING
+
+from . import pq
+from . import sql
+from . import errors as e
+from .abc import ConnectionType, PQGen
+
+if TYPE_CHECKING:
+ from typing import Any
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+
+IDLE = pq.TransactionStatus.IDLE
+
+OK = pq.ConnStatus.OK
+
+logger = logging.getLogger(__name__)
+
+
+class Rollback(Exception):
+ """
+ Exit the current `Transaction` context immediately and rollback any changes
+ made within this context.
+
+ If a transaction context is specified in the constructor, rollback
+ enclosing transactions contexts up to and including the one specified.
+ """
+
+ __module__ = "psycopg"
+
+ def __init__(
+ self,
+ transaction: Union["Transaction", "AsyncTransaction", None] = None,
+ ):
+ self.transaction = transaction
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__qualname__}({self.transaction!r})"
+
+
+class OutOfOrderTransactionNesting(e.ProgrammingError):
+ """Out-of-order transaction nesting detected"""
+
+
+class BaseTransaction(Generic[ConnectionType]):
+ def __init__(
+ self,
+ connection: ConnectionType,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ):
+ self._conn = connection
+ self.pgconn = self._conn.pgconn
+ self._savepoint_name = savepoint_name or ""
+ self.force_rollback = force_rollback
+ self._entered = self._exited = False
+ self._outer_transaction = False
+ self._stack_index = -1
+
+ @property
+ def savepoint_name(self) -> Optional[str]:
+ """
+ The name of the savepoint; `!None` if handling the main transaction.
+ """
+ # Yes, it may change on __enter__. No, I don't care, because the
+ # un-entered state is outside the public interface.
+ return self._savepoint_name
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self.pgconn)
+ if not self._entered:
+ status = "inactive"
+ elif not self._exited:
+ status = "active"
+ else:
+ status = "terminated"
+
+ sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
+ return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
+
+ def _enter_gen(self) -> PQGen[None]:
+ if self._entered:
+ raise TypeError("transaction blocks can be used only once")
+ self._entered = True
+
+ self._push_savepoint()
+ for command in self._get_enter_commands():
+ yield from self._conn._exec_command(command)
+
+ def _exit_gen(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> PQGen[bool]:
+ if not exc_val and not self.force_rollback:
+ yield from self._commit_gen()
+ return False
+ else:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ return (yield from self._rollback_gen(exc_val))
+ except OutOfOrderTransactionNesting:
+ # Clobber an exception happened in the block with the exception
+ # caused by out-of-order transaction detected, so make the
+ # behaviour consistent with _commit_gen and to make sure the
+ # user fixes this condition, which is unrelated from
+ # operational error that might arise in the block.
+ raise
+ except Exception as exc2:
+ logger.warning("error ignored in rollback of %s: %s", self, exc2)
+ return False
+
+ def _commit_gen(self) -> PQGen[None]:
+ ex = self._pop_savepoint("commit")
+ self._exited = True
+ if ex:
+ raise ex
+
+ for command in self._get_commit_commands():
+ yield from self._conn._exec_command(command)
+
+ def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
+ if isinstance(exc_val, Rollback):
+ logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True)
+
+ ex = self._pop_savepoint("rollback")
+ self._exited = True
+ if ex:
+ raise ex
+
+ for command in self._get_rollback_commands():
+ yield from self._conn._exec_command(command)
+
+ if isinstance(exc_val, Rollback):
+ if not exc_val.transaction or exc_val.transaction is self:
+ return True # Swallow the exception
+
+ return False
+
+ def _get_enter_commands(self) -> Iterator[bytes]:
+ if self._outer_transaction:
+ yield self._conn._get_tx_start_command()
+
+ if self._savepoint_name:
+ yield (
+ sql.SQL("SAVEPOINT {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ def _get_commit_commands(self) -> Iterator[bytes]:
+ if self._savepoint_name and not self._outer_transaction:
+ yield (
+ sql.SQL("RELEASE {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ if self._outer_transaction:
+ assert not self._conn._num_transactions
+ yield b"COMMIT"
+
+ def _get_rollback_commands(self) -> Iterator[bytes]:
+ if self._savepoint_name and not self._outer_transaction:
+ yield (
+ sql.SQL("ROLLBACK TO {n}")
+ .format(n=sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+ yield (
+ sql.SQL("RELEASE {n}")
+ .format(n=sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ if self._outer_transaction:
+ assert not self._conn._num_transactions
+ yield b"ROLLBACK"
+
+ # Also clear the prepared statements cache.
+ if self._conn._prepared.clear():
+ yield from self._conn._prepared.get_maintenance_commands()
+
+ def _push_savepoint(self) -> None:
+ """
+ Push the transaction on the connection transactions stack.
+
+ Also set the internal state of the object and verify consistency.
+ """
+ self._outer_transaction = self.pgconn.transaction_status == IDLE
+ if self._outer_transaction:
+ # outer transaction: if no name it's only a begin, else
+ # there will be an additional savepoint
+ assert not self._conn._num_transactions
+ else:
+ # inner transaction: it always has a name
+ if not self._savepoint_name:
+ self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"
+
+ self._stack_index = self._conn._num_transactions
+ self._conn._num_transactions += 1
+
+ def _pop_savepoint(self, action: str) -> Optional[Exception]:
+ """
+ Pop the transaction from the connection transactions stack.
+
+ Also verify the state consistency.
+ """
+ self._conn._num_transactions -= 1
+ if self._conn._num_transactions == self._stack_index:
+ return None
+
+ return OutOfOrderTransactionNesting(
+ f"transaction {action} at the wrong nesting level: {self}"
+ )
+
+
+class Transaction(BaseTransaction["Connection[Any]"]):
+ """
+ Returned by `Connection.transaction()` to handle a transaction block.
+ """
+
+ __module__ = "psycopg"
+
+ _Self = TypeVar("_Self", bound="Transaction")
+
+ @property
+ def connection(self) -> "Connection[Any]":
+ """The connection the object is managing."""
+ return self._conn
+
+ def __enter__(self: _Self) -> _Self:
+ with self._conn.lock:
+ self._conn.wait(self._enter_gen())
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> bool:
+ if self.pgconn.status == OK:
+ with self._conn.lock:
+ return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
+ else:
+ return False
+
+
+class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
+ """
+ Returned by `AsyncConnection.transaction()` to handle a transaction block.
+ """
+
+ __module__ = "psycopg"
+
+ _Self = TypeVar("_Self", bound="AsyncTransaction")
+
+ @property
+ def connection(self) -> "AsyncConnection[Any]":
+ return self._conn
+
+ async def __aenter__(self: _Self) -> _Self:
+ async with self._conn.lock:
+ await self._conn.wait(self._enter_gen())
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> bool:
+ if self.pgconn.status == OK:
+ async with self._conn.lock:
+ return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
+ else:
+ return False
diff --git a/psycopg/psycopg/types/__init__.py b/psycopg/psycopg/types/__init__.py
new file mode 100644
index 0000000..bdddf05
--- /dev/null
+++ b/psycopg/psycopg/types/__init__.py
@@ -0,0 +1,11 @@
+"""
+psycopg types package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from .. import _typeinfo
+
+# Exposed here
+TypeInfo = _typeinfo.TypeInfo
+TypesRegistry = _typeinfo.TypesRegistry
diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py
new file mode 100644
index 0000000..e35c5e7
--- /dev/null
+++ b/psycopg/psycopg/types/array.py
@@ -0,0 +1,464 @@
+"""
+Adapters for arrays
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
+
+from .. import pq
+from .. import errors as e
+from .. import postgres
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._compat import cache, prod
+from .._struct import pack_len, unpack_len
+from .._cmodule import _psycopg
+from ..postgres import TEXT_OID, INVALID_OID
+from .._typeinfo import TypeInfo
+
+_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
+_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
+_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
+_struct_dim = struct.Struct("!II") # dim, lower bound
+_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
+_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)
+
+TEXT_ARRAY_OID = postgres.types["text"].array_oid
+
+PY_TEXT = PyFormat.TEXT
+PQ_BINARY = pq.Format.BINARY
+
+
+class BaseListDumper(RecursiveDumper):
+ element_oid = 0
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ if cls is NoneType:
+ cls = list
+
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ if self.element_oid and context:
+ sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
+ self.sub_dumper = sdclass(NoneType, context)
+
+ def _find_list_element(self, L: List[Any], format: PyFormat) -> Any:
+ """
+ Find the first non-null element of an eventually nested list
+ """
+ items = list(self._flatiter(L, set()))
+ types = {type(item): item for item in items}
+ if not types:
+ return None
+
+ if len(types) == 1:
+ t, v = types.popitem()
+ else:
+ # More than one type in the list. It might be still good, as long
+ # as they dump with the same oid (e.g. IPv4Network, IPv6Network).
+ dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
+ oids = set(d.oid for d in dumpers)
+ if len(oids) == 1:
+ t, v = types.popitem()
+ else:
+ raise e.DataError(
+ "cannot dump lists of mixed types;"
+ f" got: {', '.join(sorted(t.__name__ for t in types))}"
+ )
+
+ # Checking for precise type. If the type is a subclass (e.g. Int4)
+ # we assume the user knows what type they are passing.
+ if t is not int:
+ return v
+
+ # If we got an int, let's see what is the biggest one in order to
+ # choose the smallest OID and allow Postgres to do the right cast.
+ imax: int = max(items)
+ imin: int = min(items)
+ if imin >= 0:
+ return imax
+ else:
+ return max(imax, -imin - 1)
+
+ def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
+ if id(L) in seen:
+ raise e.DataError("cannot dump a recursive list")
+
+ seen.add(id(L))
+
+ for item in L:
+ if type(item) is list:
+ yield from self._flatiter(item, seen)
+ elif item is not None:
+ yield item
+
+ return None
+
+ def _get_base_type_info(self, base_oid: int) -> TypeInfo:
+ """
+ Return info about the base type.
+
+ Return text info as fallback.
+ """
+ if base_oid:
+ info = self._tx.adapters.types.get(base_oid)
+ if info:
+ return info
+
+ return self._tx.adapters.types["text"]
+
+
+class ListDumper(BaseListDumper):
+
+ delimiter = b","
+
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ if self.oid:
+ return self.cls
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return self.cls
+
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.get_key(item, format))
+
+ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+ # If we have an oid we don't need to upgrade
+ if self.oid:
+ return self
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ # Empty lists can only be dumped as text if the type is unknown.
+ return self
+
+ sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+
+ # We consider an array of unknowns as unknown, so we can dump empty
+ # lists or lists containing only None elements.
+ if sd.oid != INVALID_OID:
+ info = self._get_base_type_info(sd.oid)
+ dumper.oid = info.array_oid or TEXT_ARRAY_OID
+ dumper.delimiter = info.delimiter.encode()
+ else:
+ dumper.oid = INVALID_OID
+
+ return dumper
+
+ # Double quotes and backslashes embedded in element values will be
+ # backslash-escaped.
+ _re_esc = re.compile(rb'(["\\])')
+
+ def dump(self, obj: List[Any]) -> bytes:
+ tokens: List[Buffer] = []
+ needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
+
+ def dump_list(obj: List[Any]) -> None:
+ if not obj:
+ tokens.append(b"{}")
+ return
+
+ tokens.append(b"{")
+ for item in obj:
+ if isinstance(item, list):
+ dump_list(item)
+ elif item is not None:
+ ad = self._dump_item(item)
+ if needs_quotes(ad):
+ if not isinstance(ad, bytes):
+ ad = bytes(ad)
+ ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"'
+ tokens.append(ad)
+ else:
+ tokens.append(b"NULL")
+
+ tokens.append(self.delimiter)
+
+ tokens[-1] = b"}"
+
+ dump_list(obj)
+
+ return b"".join(tokens)
+
+ def _dump_item(self, item: Any) -> Buffer:
+ if self.sub_dumper:
+ return self.sub_dumper.dump(item)
+ else:
+ return self._tx.get_dumper(item, PY_TEXT).dump(item)
+
+
+@cache
+def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
+ """Return a regexp to recognise when a value needs quotes
+
+ from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
+
+ The array output routine will put double quotes around element values if
+ they are empty strings, contain curly braces, delimiter characters,
+ double quotes, backslashes, or white space, or match the word NULL.
+ """
+ return re.compile(
+ rb"""(?xi)
+ ^$ # the empty string
+ | ["{}%s\\\s] # or a char to escape
+ | ^null$ # or the word NULL
+ """
+ % delimiter
+ )
+
+
+class ListBinaryDumper(BaseListDumper):
+
+ format = pq.Format.BINARY
+
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ if self.oid:
+ return self.cls
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return (self.cls,)
+
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.get_key(item, format))
+
+ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+ # If we have an oid we don't need to upgrade
+ if self.oid:
+ return self
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return ListDumper(self.cls, self._tx)
+
+ sd = self._tx.get_dumper(item, format.from_pq(self.format))
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ info = self._get_base_type_info(sd.oid)
+ dumper.oid = info.array_oid or TEXT_ARRAY_OID
+
+ return dumper
+
+ def dump(self, obj: List[Any]) -> bytes:
+ # Postgres won't take unknown for element oid: fall back on text
+ sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
+
+ if not obj:
+ return _pack_head(0, 0, sub_oid)
+
+ data: List[Buffer] = [b"", b""] # placeholders to avoid a resize
+ dims: List[int] = []
+ hasnull = 0
+
+ def calc_dims(L: List[Any]) -> None:
+ if isinstance(L, self.cls):
+ if not L:
+ raise e.DataError("lists cannot contain empty lists")
+ dims.append(len(L))
+ calc_dims(L[0])
+
+ calc_dims(obj)
+
+ def dump_list(L: List[Any], dim: int) -> None:
+ nonlocal hasnull
+ if len(L) != dims[dim]:
+ raise e.DataError("nested lists have inconsistent lengths")
+
+ if dim == len(dims) - 1:
+ for item in L:
+ if item is not None:
+ # If we get here, the sub_dumper must have been set
+ ad = self.sub_dumper.dump(item) # type: ignore[union-attr]
+ data.append(pack_len(len(ad)))
+ data.append(ad)
+ else:
+ hasnull = 1
+ data.append(b"\xff\xff\xff\xff")
+ else:
+ for item in L:
+ if not isinstance(item, self.cls):
+ raise e.DataError("nested lists have inconsistent depths")
+ dump_list(item, dim + 1) # type: ignore
+
+ dump_list(obj, 0)
+
+ data[0] = _pack_head(len(dims), hasnull, sub_oid)
+ data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
+ return b"".join(data)
+
+
+class ArrayLoader(RecursiveLoader):
+
+ delimiter = b","
+ base_oid: int
+
+ def load(self, data: Buffer) -> List[Any]:
+ loader = self._tx.get_loader(self.base_oid, self.format)
+ return _load_text(data, loader, self.delimiter)
+
+
+class ArrayBinaryLoader(RecursiveLoader):
+
+ format = pq.Format.BINARY
+
+ def load(self, data: Buffer) -> List[Any]:
+ return _load_binary(data, self._tx)
+
+
+def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ if not info.array_oid:
+ raise ValueError(f"the type info {info} doesn't describe an array")
+
+ base: Type[Any]
+ adapters = context.adapters if context else postgres.adapters
+
+ base = getattr(_psycopg, "ArrayLoader", ArrayLoader)
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "base_oid": info.oid,
+ "delimiter": info.delimiter.encode(),
+ }
+ loader = type(name, (base,), attribs)
+ adapters.register_loader(info.array_oid, loader)
+
+ loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader)
+ adapters.register_loader(info.array_oid, loader)
+
+ base = ListDumper
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "oid": info.array_oid,
+ "element_oid": info.oid,
+ "delimiter": info.delimiter.encode(),
+ }
+ dumper = type(name, (base,), attribs)
+ adapters.register_dumper(None, dumper)
+
+ base = ListBinaryDumper
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "oid": info.array_oid,
+ "element_oid": info.oid,
+ }
+ dumper = type(name, (base,), attribs)
+ adapters.register_dumper(None, dumper)
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ # The text dumper is more flexible as it can handle lists of mixed type,
+ # so register it later.
+ context.adapters.register_dumper(list, ListBinaryDumper)
+ context.adapters.register_dumper(list, ListDumper)
+
+
+def register_all_arrays(context: AdaptContext) -> None:
+ """
+ Associate the array oid of all the types in Loader.globals.
+
+ This function is designed to be called once at import time, after having
+ registered all the base loaders.
+ """
+ for t in context.adapters.types:
+ if t.array_oid:
+ t.register(context)
+
+
+def _load_text(
+ data: Buffer,
+ loader: Loader,
+ delimiter: bytes = b",",
+ __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
+) -> List[Any]:
+ rv = None
+ stack: List[Any] = []
+ a: List[Any] = []
+ rv = a
+ load = loader.load
+
+ # Remove the dimensions information prefix (``[...]=``)
+ if data and data[0] == b"["[0]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ idx = data.find(b"=")
+ if idx == -1:
+ raise e.DataError("malformed array: no '=' after dimension information")
+ data = data[idx + 1 :]
+
+ re_parse = _get_array_parse_regexp(delimiter)
+ for m in re_parse.finditer(data):
+ t = m.group(1)
+ if t == b"{":
+ if stack:
+ stack[-1].append(a)
+ stack.append(a)
+ a = []
+
+ elif t == b"}":
+ if not stack:
+ raise e.DataError("malformed array: unexpected '}'")
+ rv = stack.pop()
+
+ else:
+ if not stack:
+ wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
+ raise e.DataError(f"malformed array: unexpected '{wat}'")
+ if t == b"NULL":
+ v = None
+ else:
+ if t.startswith(b'"'):
+ t = __re_unescape.sub(rb"\1", t[1:-1])
+ v = load(t)
+
+ stack[-1].append(v)
+
+ assert rv is not None
+ return rv
+
+
+@cache
+def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
+ """
+ Return a regexp to tokenize an array representation into item and brackets
+ """
+ return re.compile(
+ rb"""(?xi)
+ ( [{}] # open or closed bracket
+ | " (?: [^"\\] | \\. )* " # or a quoted string
+ | [^"{}%s\\]+ # or an unquoted non-empty string
+ ) ,?
+ """
+ % delimiter
+ )
+
+
+def _load_binary(data: Buffer, tx: Transformer) -> List[Any]:
+ ndims, hasnull, oid = _unpack_head(data)
+ load = tx.get_loader(oid, PQ_BINARY).load
+
+ if not ndims:
+ return []
+
+ p = 12 + 8 * ndims
+ dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
+ nelems = prod(dims)
+
+ out: List[Any] = [None] * nelems
+ for i in range(nelems):
+ size = unpack_len(data, p)[0]
+ p += 4
+ if size == -1:
+ continue
+ out[i] = load(data[p : p + size])
+ p += size
+
+ # fon ndims > 1 we have to aggregate the array into sub-arrays
+ for dim in dims[-1:0:-1]:
+ out = [out[i : i + dim] for i in range(0, len(out), dim)]
+
+ return out
diff --git a/psycopg/psycopg/types/bool.py b/psycopg/psycopg/types/bool.py
new file mode 100644
index 0000000..db7e181
--- /dev/null
+++ b/psycopg/psycopg/types/bool.py
@@ -0,0 +1,51 @@
+"""
+Adapters for booleans.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+
+class BoolDumper(Dumper):
+
+ oid = postgres.types["bool"].oid
+
+ def dump(self, obj: bool) -> bytes:
+ return b"t" if obj else b"f"
+
+ def quote(self, obj: bool) -> bytes:
+ return b"true" if obj else b"false"
+
+
+class BoolBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["bool"].oid
+
+ def dump(self, obj: bool) -> bytes:
+ return b"\x01" if obj else b"\x00"
+
+
+class BoolLoader(Loader):
+ def load(self, data: Buffer) -> bool:
+ return data == b"t"
+
+
+class BoolBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> bool:
+ return data != b"\x00"
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(bool, BoolDumper)
+ adapters.register_dumper(bool, BoolBinaryDumper)
+ adapters.register_loader("bool", BoolLoader)
+ adapters.register_loader("bool", BoolBinaryLoader)
diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py
new file mode 100644
index 0000000..1c609c3
--- /dev/null
+++ b/psycopg/psycopg/types/composite.py
@@ -0,0 +1,290 @@
+"""
+Support for composite types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from collections import namedtuple
+from typing import Any, Callable, cast, Iterator, List, Optional
+from typing import Sequence, Tuple, Type
+
+from .. import pq
+from .. import postgres
+from ..abc import AdaptContext, Buffer
+from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader
+from .._struct import pack_len, unpack_len
+from ..postgres import TEXT_OID
+from .._typeinfo import CompositeInfo as CompositeInfo # exported here
+from .._encodings import _as_python_identifier
+
+_struct_oidlen = struct.Struct("!Ii")
+_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
+_unpack_oidlen = cast(
+ Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
+)
+
+
+class SequenceDumper(RecursiveDumper):
+ def _dump_sequence(
+ self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
+ ) -> bytes:
+ if not obj:
+ return start + end
+
+ parts: List[Buffer] = [start]
+
+ for item in obj:
+ if item is None:
+ parts.append(sep)
+ continue
+
+ dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
+ ad = dumper.dump(item)
+ if not ad:
+ ad = b'""'
+ elif self._re_needs_quotes.search(ad):
+ ad = b'"' + self._re_esc.sub(rb"\1\1", ad) + b'"'
+
+ parts.append(ad)
+ parts.append(sep)
+
+ parts[-1] = end
+
+ return b"".join(parts)
+
+ _re_needs_quotes = re.compile(rb'[",\\\s()]')
+ _re_esc = re.compile(rb"([\\\"])")
+
+
+class TupleDumper(SequenceDumper):
+
+ # Should be this, but it doesn't work
+ # oid = postgres_types["record"].oid
+
+ def dump(self, obj: Tuple[Any, ...]) -> bytes:
+ return self._dump_sequence(obj, b"(", b")", b",")
+
+
+class TupleBinaryDumper(RecursiveDumper):
+
+ format = pq.Format.BINARY
+
+ # Subclasses must set an info
+ info: CompositeInfo
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ nfields = len(self.info.field_types)
+ self._tx.set_dumper_types(self.info.field_types, self.format)
+ self._formats = (PyFormat.from_pq(self.format),) * nfields
+
+ def dump(self, obj: Tuple[Any, ...]) -> bytearray:
+ out = bytearray(pack_len(len(obj)))
+ adapted = self._tx.dump_sequence(obj, self._formats)
+ for i in range(len(obj)):
+ b = adapted[i]
+ oid = self.info.field_types[i]
+ if b is not None:
+ out += _pack_oidlen(oid, len(b))
+ out += b
+ else:
+ out += _pack_oidlen(oid, -1)
+
+ return out
+
+
+class BaseCompositeLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer(context)
+
+ def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]:
+ """
+ Split a non-empty representation of a composite type into components.
+
+ Terminators shouldn't be used in `!data` (so that both record and range
+ representations can be parsed).
+ """
+ for m in self._re_tokenize.finditer(data):
+ if m.group(1):
+ yield None
+ elif m.group(2) is not None:
+ yield self._re_undouble.sub(rb"\1", m.group(2))
+ else:
+ yield m.group(3)
+
+ # If the final group ended in `,` there is a final NULL in the record
+ # that the regexp couldn't parse.
+ if m and m.group().endswith(b","):
+ yield None
+
+ _re_tokenize = re.compile(
+ rb"""(?x)
+ (,) # an empty token, representing NULL
+ | " ((?: [^"] | "")*) " ,? # or a quoted string
+ | ([^",)]+) ,? # or an unquoted string
+ """
+ )
+
+ _re_undouble = re.compile(rb'(["\\])\1')
+
+
+class RecordLoader(BaseCompositeLoader):
+ def load(self, data: Buffer) -> Tuple[Any, ...]:
+ if data == b"()":
+ return ()
+
+ cast = self._tx.get_loader(TEXT_OID, self.format).load
+ return tuple(
+ cast(token) if token is not None else None
+ for token in self._parse_record(data[1:-1])
+ )
+
+
+class RecordBinaryLoader(Loader):
+ format = pq.Format.BINARY
+ _types_set = False
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer(context)
+
+ def load(self, data: Buffer) -> Tuple[Any, ...]:
+ if not self._types_set:
+ self._config_types(data)
+ self._types_set = True
+
+ return self._tx.load_sequence(
+ tuple(
+ data[offset : offset + length] if length != -1 else None
+ for _, offset, length in self._walk_record(data)
+ )
+ )
+
+ def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]:
+ """
+ Yield a sequence of (oid, offset, length) for the content of the record
+ """
+ nfields = unpack_len(data, 0)[0]
+ i = 4
+ for _ in range(nfields):
+ oid, length = _unpack_oidlen(data, i)
+ yield oid, i + 8, length
+ i += (8 + length) if length > 0 else 8
+
+ def _config_types(self, data: Buffer) -> None:
+ oids = [r[0] for r in self._walk_record(data)]
+ self._tx.set_loader_types(oids, self.format)
+
+
+class CompositeLoader(RecordLoader):
+
+ factory: Callable[..., Any]
+ fields_types: List[int]
+ _types_set = False
+
+ def load(self, data: Buffer) -> Any:
+ if not self._types_set:
+ self._config_types(data)
+ self._types_set = True
+
+ if data == b"()":
+ return type(self).factory()
+
+ return type(self).factory(
+ *self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
+ )
+
+ def _config_types(self, data: Buffer) -> None:
+ self._tx.set_loader_types(self.fields_types, self.format)
+
+
+class CompositeBinaryLoader(RecordBinaryLoader):
+
+ format = pq.Format.BINARY
+ factory: Callable[..., Any]
+
+ def load(self, data: Buffer) -> Any:
+ r = super().load(data)
+ return type(self).factory(*r)
+
+
+def register_composite(
+ info: CompositeInfo,
+ context: Optional[AdaptContext] = None,
+ factory: Optional[Callable[..., Any]] = None,
+) -> None:
+ """Register the adapters to load and dump a composite type.
+
+ :param info: The object with the information about the composite to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+ :param factory: Callable to convert the sequence of attributes read from
+ the composite into a Python object.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested composite available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ if not factory:
+ factory = namedtuple( # type: ignore
+ _as_python_identifier(info.name),
+ [_as_python_identifier(n) for n in info.field_names],
+ )
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[BaseCompositeLoader] = type(
+ f"{info.name.title()}Loader",
+ (CompositeLoader,),
+ {
+ "factory": factory,
+ "fields_types": info.field_types,
+ },
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ loader = type(
+ f"{info.name.title()}BinaryLoader",
+ (CompositeBinaryLoader,),
+ {"factory": factory},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # If the factory is a type, create and register dumpers for it
+ if isinstance(factory, type):
+ dumper = type(
+ f"{info.name.title()}BinaryDumper",
+ (TupleBinaryDumper,),
+ {"oid": info.oid, "info": info},
+ )
+ adapters.register_dumper(factory, dumper)
+
+ # Default to the text dumper because it is more flexible
+ dumper = type(f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid})
+ adapters.register_dumper(factory, dumper)
+
+ info.python_type = factory
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(tuple, TupleDumper)
+ adapters.register_loader("record", RecordLoader)
+ adapters.register_loader("record", RecordBinaryLoader)
diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py
new file mode 100644
index 0000000..f0dfe83
--- /dev/null
+++ b/psycopg/psycopg/types/datetime.py
@@ -0,0 +1,754 @@
+"""
+Adapters for date/time types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from datetime import date, datetime, time, timedelta, timezone
+from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format
+from .._tz import get_tzinfo
+from ..abc import AdaptContext, DumperKey
+from ..adapt import Buffer, Dumper, Loader, PyFormat
+from ..errors import InterfaceError, DataError
+from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8
+
+if TYPE_CHECKING:
+ from ..connection import BaseConnection
+
+_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset
+_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
+_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack)
+
+_struct_interval = struct.Struct("!qii") # microseconds, days, months
+_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
+_unpack_interval = cast(
+ Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack
+)
+
+utc = timezone.utc
+_pg_date_epoch_days = date(2000, 1, 1).toordinal()
+_pg_datetime_epoch = datetime(2000, 1, 1)
+_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc)
+_py_date_min_days = date.min.toordinal()
+
+
+class DateDumper(Dumper):
+
+ oid = postgres.types["date"].oid
+
+ def dump(self, obj: date) -> bytes:
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ return str(obj).encode()
+
+
+class DateBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["date"].oid
+
+ def dump(self, obj: date) -> bytes:
+ days = obj.toordinal() - _pg_date_epoch_days
+ return pack_int4(days)
+
+
+class _BaseTimeDumper(Dumper):
+ def get_key(self, obj: time, format: PyFormat) -> DumperKey:
+ # Use (cls,) to report the need to upgrade to a dumper for timetz (the
+ # Frankenstein of the data types).
+ if not obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ raise NotImplementedError
+
+
+class _BaseTimeTextDumper(_BaseTimeDumper):
+ def dump(self, obj: time) -> bytes:
+ return str(obj).encode()
+
+
+class TimeDumper(_BaseTimeTextDumper):
+
+ oid = postgres.types["time"].oid
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzDumper(self.cls)
+
+
+class TimeTzDumper(_BaseTimeTextDumper):
+
+ oid = postgres.types["timetz"].oid
+
+
+class TimeBinaryDumper(_BaseTimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["time"].oid
+
+ def dump(self, obj: time) -> bytes:
+ us = obj.microsecond + 1_000_000 * (
+ obj.second + 60 * (obj.minute + 60 * obj.hour)
+ )
+ return pack_int8(us)
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzBinaryDumper(self.cls)
+
+
+class TimeTzBinaryDumper(_BaseTimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timetz"].oid
+
+ def dump(self, obj: time) -> bytes:
+ us = obj.microsecond + 1_000_000 * (
+ obj.second + 60 * (obj.minute + 60 * obj.hour)
+ )
+ off = obj.utcoffset()
+ assert off is not None
+ return _pack_timetz(us, -int(off.total_seconds()))
+
+
+class _BaseDatetimeDumper(Dumper):
+ def get_key(self, obj: datetime, format: PyFormat) -> DumperKey:
+ # Use (cls,) to report the need to upgrade (downgrade, actually) to a
+ # dumper for naive timestamp.
+ if obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ raise NotImplementedError
+
+
+class _BaseDatetimeTextDumper(_BaseDatetimeDumper):
+ def dump(self, obj: datetime) -> bytes:
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ return str(obj).encode()
+
+
+class DatetimeDumper(_BaseDatetimeTextDumper):
+
+ oid = postgres.types["timestamptz"].oid
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzDumper(self.cls)
+
+
+class DatetimeNoTzDumper(_BaseDatetimeTextDumper):
+
+ oid = postgres.types["timestamp"].oid
+
+
+class DatetimeBinaryDumper(_BaseDatetimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timestamptz"].oid
+
+ def dump(self, obj: datetime) -> bytes:
+ delta = obj - _pg_datetimetz_epoch
+ micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
+ return pack_int8(micros)
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzBinaryDumper(self.cls)
+
+
+class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timestamp"].oid
+
+ def dump(self, obj: datetime) -> bytes:
+ delta = obj - _pg_datetime_epoch
+ micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
+ return pack_int8(micros)
+
+
+class TimedeltaDumper(Dumper):
+
+ oid = postgres.types["interval"].oid
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ if self.connection:
+ if (
+ self.connection.pgconn.parameter_status(b"IntervalStyle")
+ == b"sql_standard"
+ ):
+ setattr(self, "dump", self._dump_sql)
+
+ def dump(self, obj: timedelta) -> bytes:
+ # The comma is parsed ok by PostgreSQL but it's not documented
+ # and it seems brittle to rely on it. CRDB doesn't consume it well.
+ return str(obj).encode().replace(b",", b"")
+
+ def _dump_sql(self, obj: timedelta) -> bytes:
+ # sql_standard format needs explicit signs
+ # otherwise -1 day 1 sec will mean -1 sec
+ return b"%+d day %+d second %+d microsecond" % (
+ obj.days,
+ obj.seconds,
+ obj.microseconds,
+ )
+
+
+class TimedeltaBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["interval"].oid
+
+ def dump(self, obj: timedelta) -> bytes:
+ micros = 1_000_000 * obj.seconds + obj.microseconds
+ return _pack_interval(micros, obj.days, 0)
+
+
+class DateLoader(Loader):
+
+ _ORDER_YMD = 0
+ _ORDER_DMY = 1
+ _ORDER_MDY = 2
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ ds = _get_datestyle(self.connection)
+ if ds.startswith(b"I"): # ISO
+ self._order = self._ORDER_YMD
+ elif ds.startswith(b"G"): # German
+ self._order = self._ORDER_DMY
+ elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres
+ self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
+ else:
+ raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ def load(self, data: Buffer) -> date:
+ if self._order == self._ORDER_YMD:
+ ye = data[:4]
+ mo = data[5:7]
+ da = data[8:]
+ elif self._order == self._ORDER_DMY:
+ da = data[:2]
+ mo = data[3:5]
+ ye = data[6:]
+ else:
+ mo = data[:2]
+ da = data[3:5]
+ ye = data[6:]
+
+ try:
+ return date(int(ye), int(mo), int(da))
+ except ValueError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ if s == "infinity" or (s and len(s.split()[0]) > 10):
+ raise DataError(f"date too large (after year 10K): {s!r}") from None
+ elif s == "-infinity" or "BC" in s:
+ raise DataError(f"date too small (before year 1): {s!r}") from None
+ else:
+ raise DataError(f"can't parse date {s!r}: {ex}") from None
+
+
+class DateBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> date:
+ days = unpack_int4(data)[0] + _pg_date_epoch_days
+ try:
+ return date.fromordinal(days)
+ except (ValueError, OverflowError):
+ if days < _py_date_min_days:
+ raise DataError("date too small (before year 1)") from None
+ else:
+ raise DataError("date too large (after year 10K)") from None
+
+
+class TimeLoader(Loader):
+
+ _re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?")
+
+ def load(self, data: Buffer) -> time:
+ m = self._re_format.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse time {s!r}")
+
+ ho, mi, se, fr = m.groups()
+
+ # Pad the fraction of second to get micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ try:
+ return time(int(ho), int(mi), int(se), us)
+ except ValueError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse time {s!r}: {e}") from None
+
+
+class TimeBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> time:
+ val = unpack_int8(data)[0]
+ val, us = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+ try:
+ return time(h, m, s, us)
+ except ValueError:
+ raise DataError(f"time not supported by Python: hour={h}") from None
+
+
+class TimetzLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) : (\d+) : (\d+) (?: \. (\d+) )? # Time and micros
+ ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
+ $
+ """
+ )
+
+ def load(self, data: Buffer) -> time:
+ m = self._re_format.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse timetz {s!r}")
+
+ ho, mi, se, fr, sgn, oh, om, os = m.groups()
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ # Calculate timezone
+ off = 60 * 60 * int(oh)
+ if om:
+ off += 60 * int(om)
+ if os:
+ off += int(os)
+ tz = timezone(timedelta(0, off if sgn == b"+" else -off))
+
+ try:
+ return time(int(ho), int(mi), int(se), us, tz)
+ except ValueError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse timetz {s!r}: {e}") from None
+
+
+class TimetzBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> time:
+ val, off = _unpack_timetz(data)
+
+ val, us = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+
+ try:
+ return time(h, m, s, us, timezone(timedelta(seconds=-off)))
+ except ValueError:
+ raise DataError(f"time not supported by Python: hour={h}") from None
+
+
+class TimestampLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
+ (?: T | [^a-z0-9] ) # Separator, including T
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ $
+ """
+ )
+ _re_format_pg = re.compile(
+ rb"""(?ix)
+ ^
+ [a-z]+ [^a-z0-9] # DoW, separator
+ (\d+|[a-z]+) [^a-z0-9] # Month or day
+ (\d+|[a-z]+) [^a-z0-9] # Month or day
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ [^a-z0-9] (\d+) # Year
+ $
+ """
+ )
+
+ _ORDER_YMD = 0
+ _ORDER_DMY = 1
+ _ORDER_MDY = 2
+ _ORDER_PGDM = 3
+ _ORDER_PGMD = 4
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+
+ ds = _get_datestyle(self.connection)
+ if ds.startswith(b"I"): # ISO
+ self._order = self._ORDER_YMD
+ elif ds.startswith(b"G"): # German
+ self._order = self._ORDER_DMY
+ elif ds.startswith(b"S"): # SQL
+ self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
+ elif ds.startswith(b"P"): # Postgres
+ self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
+ self._re_format = self._re_format_pg
+ else:
+ raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ def load(self, data: Buffer) -> datetime:
+ m = self._re_format.match(data)
+ if not m:
+ raise _get_timestamp_load_error(self.connection, data) from None
+
+ if self._order == self._ORDER_YMD:
+ ye, mo, da, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ elif self._order == self._ORDER_DMY:
+ da, mo, ye, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ elif self._order == self._ORDER_MDY:
+ mo, da, ye, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ else:
+ if self._order == self._ORDER_PGDM:
+ da, mo, ho, mi, se, fr, ye = m.groups()
+ else:
+ mo, da, ho, mi, se, fr, ye = m.groups()
+ try:
+ imo = _month_abbr[mo]
+ except KeyError:
+ s = mo.decode("utf8", "replace")
+ raise DataError(f"can't parse month: {s!r}") from None
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ try:
+ return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us)
+ except ValueError as ex:
+ raise _get_timestamp_load_error(self.connection, data, ex) from None
+
+
+class TimestampBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> datetime:
+ micros = unpack_int8(data)[0]
+ try:
+ return _pg_datetime_epoch + timedelta(microseconds=micros)
+ except OverflowError:
+ if micros <= 0:
+ raise DataError("timestamp too small (before year 1)") from None
+ else:
+ raise DataError("timestamp too large (after year 10K)") from None
+
+
+class TimestamptzLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
+ (?: T | [^a-z0-9] ) # Separator, including T
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
+ $
+ """
+ )
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
+
+ ds = _get_datestyle(self.connection)
+ if not ds.startswith(b"I"): # not ISO
+ setattr(self, "load", self._load_notimpl)
+
+ def load(self, data: Buffer) -> datetime:
+ m = self._re_format.match(data)
+ if not m:
+ raise _get_timestamp_load_error(self.connection, data) from None
+
+ ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups()
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ # Calculate timezone offset
+ soff = 60 * 60 * int(oh)
+ if om:
+ soff += 60 * int(om)
+ if os:
+ soff += int(os)
+ tzoff = timedelta(0, soff if sgn == b"+" else -soff)
+
+ # The return value is a datetime with the timezone of the connection
+ # (in order to be consistent with the binary loader, which is the only
+ # thing it can return). So create a temporary datetime object, in utc,
+ # shift it by the offset parsed from the timestamp, and then move it to
+ # the connection timezone.
+ dt = None
+ ex: Exception
+ try:
+ dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc)
+ return (dt - tzoff).astimezone(self._timezone)
+ except OverflowError as e:
+ # If we have created the temporary 'dt' it means that we have a
+ # datetime close to max, the shift pushed it past max, overflowing.
+ # In this case return the datetime in a fixed offset timezone.
+ if dt is not None:
+ return dt.replace(tzinfo=timezone(tzoff))
+ else:
+ ex = e
+ except ValueError as e:
+ ex = e
+
+ raise _get_timestamp_load_error(self.connection, data, ex) from None
+
+ def _load_notimpl(self, data: Buffer) -> datetime:
+ s = bytes(data).decode("utf8", "replace")
+ ds = _get_datestyle(self.connection).decode("ascii")
+ raise NotImplementedError(
+ f"can't parse timestamptz with DateStyle {ds!r}: {s!r}"
+ )
+
+
+class TimestamptzBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
+
+ def load(self, data: Buffer) -> datetime:
+ micros = unpack_int8(data)[0]
+ try:
+ ts = _pg_datetimetz_epoch + timedelta(microseconds=micros)
+ return ts.astimezone(self._timezone)
+ except OverflowError:
+ # If we were asked about a timestamp which would overflow in UTC,
+ # but not in the desired timezone (e.g. datetime.max at Chicago
+ # timezone) we can still save the day by shifting the value by the
+ # timezone offset and then replacing the timezone.
+ if self._timezone:
+ utcoff = self._timezone.utcoffset(
+ datetime.min if micros < 0 else datetime.max
+ )
+ if utcoff:
+ usoff = 1_000_000 * int(utcoff.total_seconds())
+ try:
+ ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
+ except OverflowError:
+ pass # will raise downstream
+ else:
+ return ts.replace(tzinfo=self._timezone)
+
+ if micros <= 0:
+ raise DataError("timestamp too small (before year 1)") from None
+ else:
+ raise DataError("timestamp too large (after year 10K)") from None
+
+
+class IntervalLoader(Loader):
+
+ _re_interval = re.compile(
+ rb"""
+ (?: ([-+]?\d+) \s+ years? \s* )? # Years
+ (?: ([-+]?\d+) \s+ mons? \s* )? # Months
+ (?: ([-+]?\d+) \s+ days? \s* )? # Days
+ (?: ([-+])? (\d+) : (\d+) : (\d+ (?:\.\d+)?) # Time
+ )?
+ """,
+ re.VERBOSE,
+ )
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ if self.connection:
+ ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
+ if ints != b"postgres":
+ setattr(self, "load", self._load_notimpl)
+
+ def load(self, data: Buffer) -> timedelta:
+ m = self._re_interval.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse interval {s!r}")
+
+ ye, mo, da, sgn, ho, mi, se = m.groups()
+ days = 0
+ seconds = 0.0
+
+ if ye:
+ days += 365 * int(ye)
+ if mo:
+ days += 30 * int(mo)
+ if da:
+ days += int(da)
+
+ if ho:
+ seconds = 3600 * int(ho) + 60 * int(mi) + float(se)
+ if sgn == b"-":
+ seconds = -seconds
+
+ try:
+ return timedelta(days=days, seconds=seconds)
+ except OverflowError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse interval {s!r}: {e}") from None
+
+ def _load_notimpl(self, data: Buffer) -> timedelta:
+ s = bytes(data).decode("utf8", "replace")
+ ints = (
+ self.connection
+ and self.connection.pgconn.parameter_status(b"IntervalStyle")
+ or b"unknown"
+ ).decode("utf8", "replace")
+ raise NotImplementedError(
+ f"can't parse interval with IntervalStyle {ints}: {s!r}"
+ )
+
+
+class IntervalBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> timedelta:
+ micros, days, months = _unpack_interval(data)
+ if months > 0:
+ years, months = divmod(months, 12)
+ days = days + 30 * months + 365 * years
+ elif months < 0:
+ years, months = divmod(-months, 12)
+ days = days - 30 * months - 365 * years
+
+ try:
+ return timedelta(days=days, microseconds=micros)
+ except OverflowError as e:
+ raise DataError(f"can't parse interval: {e}") from None
+
+
+def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
+ if conn:
+ ds = conn.pgconn.parameter_status(b"DateStyle")
+ if ds:
+ return ds
+
+ return b"ISO, DMY"
+
+
+def _get_timestamp_load_error(
+ conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None
+) -> Exception:
+ s = bytes(data).decode("utf8", "replace")
+
+ def is_overflow(s: str) -> bool:
+ if not s:
+ return False
+
+ ds = _get_datestyle(conn)
+ if not ds.startswith(b"P"): # Postgres
+ return len(s.split()[0]) > 10 # date is first token
+ else:
+ return len(s.split()[-1]) > 4 # year is last token
+
+ if s == "-infinity" or s.endswith("BC"):
+ return DataError("timestamp too small (before year 1): {s!r}")
+ elif s == "infinity" or is_overflow(s):
+ return DataError(f"timestamp too large (after year 10K): {s!r}")
+ else:
+ return DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}")
+
+
+_month_abbr = {
+ n: i
+ for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1)
+}
+
+# Pad to get microseconds from a fraction of seconds
+_uspad = [0, 100_000, 10_000, 1_000, 100, 10, 1]
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("datetime.date", DateDumper)
+ adapters.register_dumper("datetime.date", DateBinaryDumper)
+
+ # first register dumpers for 'timetz' oid, then the proper ones on time type.
+ adapters.register_dumper("datetime.time", TimeTzDumper)
+ adapters.register_dumper("datetime.time", TimeTzBinaryDumper)
+ adapters.register_dumper("datetime.time", TimeDumper)
+ adapters.register_dumper("datetime.time", TimeBinaryDumper)
+
+ # first register dumpers for 'timestamp' oid, then the proper ones
+ # on the datetime type.
+ adapters.register_dumper("datetime.datetime", DatetimeNoTzDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeNoTzBinaryDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeBinaryDumper)
+
+ adapters.register_dumper("datetime.timedelta", TimedeltaDumper)
+ adapters.register_dumper("datetime.timedelta", TimedeltaBinaryDumper)
+
+ adapters.register_loader("date", DateLoader)
+ adapters.register_loader("date", DateBinaryLoader)
+ adapters.register_loader("time", TimeLoader)
+ adapters.register_loader("time", TimeBinaryLoader)
+ adapters.register_loader("timetz", TimetzLoader)
+ adapters.register_loader("timetz", TimetzBinaryLoader)
+ adapters.register_loader("timestamp", TimestampLoader)
+ adapters.register_loader("timestamp", TimestampBinaryLoader)
+ adapters.register_loader("timestamptz", TimestamptzLoader)
+ adapters.register_loader("timestamptz", TimestamptzBinaryLoader)
+ adapters.register_loader("interval", IntervalLoader)
+ adapters.register_loader("interval", IntervalBinaryLoader)
diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py
new file mode 100644
index 0000000..d3c7387
--- /dev/null
+++ b/psycopg/psycopg/types/enum.py
@@ -0,0 +1,177 @@
+"""
+Adapters for the enum type.
+"""
+from enum import Enum
+from typing import Any, Dict, Generic, Optional, Mapping, Sequence
+from typing import Tuple, Type, TypeVar, Union, cast
+from typing_extensions import TypeAlias
+
+from .. import postgres
+from .. import errors as e
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+from .._encodings import conn_encoding
+from .._typeinfo import EnumInfo as EnumInfo # exported here
+
+E = TypeVar("E", bound=Enum)
+
+EnumDumpMap: TypeAlias = Dict[E, bytes]
+EnumLoadMap: TypeAlias = Dict[bytes, E]
+EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None]
+
+
+class _BaseEnumLoader(Loader, Generic[E]):
+ """
+ Loader for a specific Enum class
+ """
+
+ enum: Type[E]
+ _load_map: EnumLoadMap[E]
+
+ def load(self, data: Buffer) -> E:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+
+ try:
+ return self._load_map[data]
+ except KeyError:
+ enc = conn_encoding(self.connection)
+ label = data.decode(enc, "replace")
+ raise e.DataError(
+ f"bad member for enum {self.enum.__qualname__}: {label!r}"
+ )
+
+
+class _BaseEnumDumper(Dumper, Generic[E]):
+ """
+ Dumper for a specific Enum class
+ """
+
+ enum: Type[E]
+ _dump_map: EnumDumpMap[E]
+
+ def dump(self, value: E) -> Buffer:
+ return self._dump_map[value]
+
+
+class EnumDumper(Dumper):
+ """
+ Dumper for a generic Enum class
+ """
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._encoding = conn_encoding(self.connection)
+
+ def dump(self, value: E) -> Buffer:
+ return value.name.encode(self._encoding)
+
+
+class EnumBinaryDumper(EnumDumper):
+ format = Format.BINARY
+
+
+def register_enum(
+ info: EnumInfo,
+ context: Optional[AdaptContext] = None,
+ enum: Optional[Type[E]] = None,
+ *,
+ mapping: EnumMapping[E] = None,
+) -> None:
+ """Register the adapters to load and dump a enum type.
+
+ :param info: The object with the information about the enum to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+ :param enum: Python enum type matching to the PostgreSQL one. If `!None`,
+ a new enum will be generated and exposed as `EnumInfo.enum`.
+ :param mapping: Override the mapping between `!enum` members and `!info`
+ labels.
+ """
+
+ if not info:
+ raise TypeError("no info passed. Is the requested enum available?")
+
+ if enum is None:
+ enum = cast(Type[E], Enum(info.name.title(), info.labels, module=__name__))
+
+ info.enum = enum
+ adapters = context.adapters if context else postgres.adapters
+ info.register(context)
+
+ load_map = _make_load_map(info, enum, mapping, context)
+ attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map}
+
+ name = f"{info.name.title()}Loader"
+ loader = type(name, (_BaseEnumLoader,), attribs)
+ adapters.register_loader(info.oid, loader)
+
+ name = f"{info.name.title()}BinaryLoader"
+ loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY})
+ adapters.register_loader(info.oid, loader)
+
+ dump_map = _make_dump_map(info, enum, mapping, context)
+ attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map}
+
+ name = f"{enum.__name__}Dumper"
+ dumper = type(name, (_BaseEnumDumper,), attribs)
+ adapters.register_dumper(info.enum, dumper)
+
+ name = f"{enum.__name__}BinaryDumper"
+ dumper = type(name, (_BaseEnumDumper,), {**attribs, "format": Format.BINARY})
+ adapters.register_dumper(info.enum, dumper)
+
+
+def _make_load_map(
+ info: EnumInfo,
+ enum: Type[E],
+ mapping: EnumMapping[E],
+ context: Optional[AdaptContext],
+) -> EnumLoadMap[E]:
+ enc = conn_encoding(context.connection if context else None)
+ rv: EnumLoadMap[E] = {}
+ for label in info.labels:
+ try:
+ member = enum[label]
+ except KeyError:
+ # tolerate a missing enum, assuming it won't be used. If it is we
+ # will get a DataError on fetch.
+ pass
+ else:
+ rv[label.encode(enc)] = member
+
+ if mapping:
+ if isinstance(mapping, Mapping):
+ mapping = list(mapping.items())
+
+ for member, label in mapping:
+ rv[label.encode(enc)] = member
+
+ return rv
+
+
+def _make_dump_map(
+ info: EnumInfo,
+ enum: Type[E],
+ mapping: EnumMapping[E],
+ context: Optional[AdaptContext],
+) -> EnumDumpMap[E]:
+ enc = conn_encoding(context.connection if context else None)
+ rv: EnumDumpMap[E] = {}
+ for member in enum:
+ rv[member] = member.name.encode(enc)
+
+ if mapping:
+ if isinstance(mapping, Mapping):
+ mapping = list(mapping.items())
+
+ for member, label in mapping:
+ rv[member] = label.encode(enc)
+
+ return rv
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(Enum, EnumBinaryDumper)
+ context.adapters.register_dumper(Enum, EnumDumper)
diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py
new file mode 100644
index 0000000..e1ab1d5
--- /dev/null
+++ b/psycopg/psycopg/types/hstore.py
@@ -0,0 +1,131 @@
+"""
+Dict to hstore adaptation
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+from typing import Dict, List, Optional
+from typing_extensions import TypeAlias
+
+from .. import errors as e
+from .. import postgres
+from ..abc import Buffer, AdaptContext
+from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
+from ..postgres import TEXT_OID
+from .._typeinfo import TypeInfo
+
+_re_escape = re.compile(r'(["\\])')
+_re_unescape = re.compile(r"\\(.)")
+
+_re_hstore = re.compile(
+ r"""
+ # hstore key:
+ # a string of normal or escaped chars
+ "((?: [^"\\] | \\. )*)"
+ \s*=>\s* # hstore value
+ (?:
+ NULL # the value can be null - not caught
+ # or a quoted string like the key
+ | "((?: [^"\\] | \\. )*)"
+ )
+ (?:\s*,\s*|$) # pairs separated by comma or end of string.
+""",
+ re.VERBOSE,
+)
+
+
+Hstore: TypeAlias = Dict[str, Optional[str]]
+
+
+class BaseHstoreDumper(RecursiveDumper):
+ def dump(self, obj: Hstore) -> Buffer:
+ if not obj:
+ return b""
+
+ tokens: List[str] = []
+
+ def add_token(s: str) -> None:
+ tokens.append('"')
+ tokens.append(_re_escape.sub(r"\\\1", s))
+ tokens.append('"')
+
+ for k, v in obj.items():
+
+ if not isinstance(k, str):
+ raise e.DataError("hstore keys can only be strings")
+ add_token(k)
+
+ tokens.append("=>")
+
+ if v is None:
+ tokens.append("NULL")
+ elif not isinstance(v, str):
+ raise e.DataError("hstore keys can only be strings")
+ else:
+ add_token(v)
+
+ tokens.append(",")
+
+ del tokens[-1]
+ data = "".join(tokens)
+ dumper = self._tx.get_dumper(data, PyFormat.TEXT)
+ return dumper.dump(data)
+
+
+class HstoreLoader(RecursiveLoader):
+ def load(self, data: Buffer) -> Hstore:
+ loader = self._tx.get_loader(TEXT_OID, self.format)
+ s: str = loader.load(data)
+
+ rv: Hstore = {}
+ start = 0
+ for m in _re_hstore.finditer(s):
+ if m is None or m.start() != start:
+ raise e.DataError(f"error parsing hstore pair at char {start}")
+ k = _re_unescape.sub(r"\1", m.group(1))
+ v = m.group(2)
+ if v is not None:
+ v = _re_unescape.sub(r"\1", v)
+
+ rv[k] = v
+ start = m.end()
+
+ if start < len(s):
+ raise e.DataError(f"error parsing hstore: unparsed data after char {start}")
+
+ return rv
+
+
+def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register the adapters to load and dump hstore.
+
+ :param info: The object with the information about the hstore type.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the 'hstore' extension loaded?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # Generate and register a customized text dumper
+ class HstoreDumper(BaseHstoreDumper):
+ oid = info.oid
+
+ adapters.register_dumper(dict, HstoreDumper)
+
+ # register the text loader on the oid
+ adapters.register_loader(info.oid, HstoreLoader)
diff --git a/psycopg/psycopg/types/json.py b/psycopg/psycopg/types/json.py
new file mode 100644
index 0000000..a80e0e4
--- /dev/null
+++ b/psycopg/psycopg/types/json.py
@@ -0,0 +1,232 @@
+"""
+Adapers for JSON types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import json
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+
+from .. import abc
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap
+from ..errors import DataError
+
+JsonDumpsFunction = Callable[[Any], str]
+JsonLoadsFunction = Callable[[Union[str, bytes]], Any]
+
+
+def set_json_dumps(
+ dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None
+) -> None:
+ """
+ Set the JSON serialisation function to store JSON objects in the database.
+
+ :param dumps: The dump function to use.
+ :type dumps: `!Callable[[Any], str]`
+ :param context: Where to use the `!dumps` function. If not specified, use it
+ globally.
+ :type context: `~psycopg.Connection` or `~psycopg.Cursor`
+
+ By default dumping JSON uses the builtin `json.dumps`. You can override
+ it to use a different JSON library or to use customised arguments.
+
+ If the `Json` wrapper specified a `!dumps` function, use it in precedence
+ of the one set by this function.
+ """
+ if context is None:
+ # If changing load function globally, just change the default on the
+ # global class
+ _JsonDumper._dumps = dumps
+ else:
+ adapters = context.adapters
+
+ # If the scope is smaller than global, create subclassess and register
+ # them in the appropriate scope.
+ grid = [
+ (Json, PyFormat.BINARY),
+ (Json, PyFormat.TEXT),
+ (Jsonb, PyFormat.BINARY),
+ (Jsonb, PyFormat.TEXT),
+ ]
+ dumper: Type[_JsonDumper]
+ for wrapper, format in grid:
+ base = _get_current_dumper(adapters, wrapper, format)
+ name = base.__name__
+ if not base.__name__.startswith("Custom"):
+ name = f"Custom{name}"
+ dumper = type(name, (base,), {"_dumps": dumps})
+ adapters.register_dumper(wrapper, dumper)
+
+
+def set_json_loads(
+ loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None
+) -> None:
+ """
+ Set the JSON parsing function to fetch JSON objects from the database.
+
+ :param loads: The load function to use.
+ :type loads: `!Callable[[bytes], Any]`
+ :param context: Where to use the `!loads` function. If not specified, use
+ it globally.
+ :type context: `~psycopg.Connection` or `~psycopg.Cursor`
+
+ By default loading JSON uses the builtin `json.loads`. You can override
+ it to use a different JSON library or to use customised arguments.
+ """
+ if context is None:
+ # If changing load function globally, just change the default on the
+ # global class
+ _JsonLoader._loads = loads
+ else:
+ # If the scope is smaller than global, create subclassess and register
+ # them in the appropriate scope.
+ grid = [
+ ("json", JsonLoader),
+ ("json", JsonBinaryLoader),
+ ("jsonb", JsonbLoader),
+ ("jsonb", JsonbBinaryLoader),
+ ]
+ loader: Type[_JsonLoader]
+ for tname, base in grid:
+ loader = type(f"Custom{base.__name__}", (base,), {"_loads": loads})
+ context.adapters.register_loader(tname, loader)
+
+
+class _JsonWrapper:
+ __slots__ = ("obj", "dumps")
+
+ def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None):
+ self.obj = obj
+ self.dumps = dumps
+
+ def __repr__(self) -> str:
+ sobj = repr(self.obj)
+ if len(sobj) > 40:
+ sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
+ return f"{self.__class__.__name__}({sobj})"
+
+
+class Json(_JsonWrapper):
+ __slots__ = ()
+
+
+class Jsonb(_JsonWrapper):
+ __slots__ = ()
+
+
+class _JsonDumper(Dumper):
+
+ # The globally used JSON dumps() function. It can be changed globally (by
+ # set_json_dumps) or by a subclass.
+ _dumps: JsonDumpsFunction = json.dumps
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ super().__init__(cls, context)
+ self.dumps = self.__class__._dumps
+
+ def dump(self, obj: _JsonWrapper) -> bytes:
+ dumps = obj.dumps or self.dumps
+ return dumps(obj.obj).encode()
+
+
+class JsonDumper(_JsonDumper):
+
+ oid = postgres.types["json"].oid
+
+
+class JsonBinaryDumper(_JsonDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["json"].oid
+
+
+class JsonbDumper(_JsonDumper):
+
+ oid = postgres.types["jsonb"].oid
+
+
+class JsonbBinaryDumper(_JsonDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["jsonb"].oid
+
+ def dump(self, obj: _JsonWrapper) -> bytes:
+ dumps = obj.dumps or self.dumps
+ return b"\x01" + dumps(obj.obj).encode()
+
+
+class _JsonLoader(Loader):
+
+ # The globally used JSON loads() function. It can be changed globally (by
+ # set_json_loads) or by a subclass.
+ _loads: JsonLoadsFunction = json.loads
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ super().__init__(oid, context)
+ self.loads = self.__class__._loads
+
+ def load(self, data: Buffer) -> Any:
+ # json.loads() cannot work on memoryview.
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return self.loads(data)
+
+
+class JsonLoader(_JsonLoader):
+ pass
+
+
+class JsonbLoader(_JsonLoader):
+ pass
+
+
+class JsonBinaryLoader(_JsonLoader):
+ format = Format.BINARY
+
+
+class JsonbBinaryLoader(_JsonLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Any:
+ if data and data[0] != 1:
+ raise DataError("unknown jsonb binary format: {data[0]}")
+ data = data[1:]
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return self.loads(data)
+
+
+def _get_current_dumper(
+ adapters: AdaptersMap, cls: type, format: PyFormat
+) -> Type[abc.Dumper]:
+ try:
+ return adapters.get_dumper(cls, format)
+ except e.ProgrammingError:
+ return _default_dumpers[cls, format]
+
+
+_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = {
+ (Json, PyFormat.BINARY): JsonBinaryDumper,
+ (Json, PyFormat.TEXT): JsonDumper,
+ (Jsonb, PyFormat.BINARY): JsonbBinaryDumper,
+ (Jsonb, PyFormat.TEXT): JsonDumper,
+}
+
+
+def register_default_adapters(context: abc.AdaptContext) -> None:
+ adapters = context.adapters
+
+ # Currently json binary format is nothing different than text, maybe with
+ # an extra memcopy we can avoid.
+ adapters.register_dumper(Json, JsonBinaryDumper)
+ adapters.register_dumper(Json, JsonDumper)
+ adapters.register_dumper(Jsonb, JsonbBinaryDumper)
+ adapters.register_dumper(Jsonb, JsonbDumper)
+ adapters.register_loader("json", JsonLoader)
+ adapters.register_loader("jsonb", JsonbLoader)
+ adapters.register_loader("json", JsonBinaryLoader)
+ adapters.register_loader("jsonb", JsonbBinaryLoader)
diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py
new file mode 100644
index 0000000..3eaa7f1
--- /dev/null
+++ b/psycopg/psycopg/types/multirange.py
@@ -0,0 +1,514 @@
+"""
+Support for multirange types adaptation.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from decimal import Decimal
+from typing import Any, Generic, List, Iterable
+from typing import MutableSequence, Optional, Type, Union, overload
+from datetime import date, datetime
+
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._struct import pack_len, unpack_len
+from ..postgres import INVALID_OID, TEXT_OID
+from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here
+
+from .range import Range, T, load_range_text, load_range_binary
+from .range import dump_range_text, dump_range_binary, fail_dump
+
+
+class Multirange(MutableSequence[Range[T]]):
+ """Python representation for a PostgreSQL multirange type.
+
+ :param items: Sequence of ranges to initialise the object.
+ """
+
+ def __init__(self, items: Iterable[Range[T]] = ()):
+ self._ranges: List[Range[T]] = list(map(self._check_type, items))
+
+ def _check_type(self, item: Any) -> Range[Any]:
+ if not isinstance(item, Range):
+ raise TypeError(
+ f"Multirange is a sequence of Range, got {type(item).__name__}"
+ )
+ return item
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self._ranges!r})"
+
+ def __str__(self) -> str:
+ return f"{{{', '.join(map(str, self._ranges))}}}"
+
+ @overload
+ def __getitem__(self, index: int) -> Range[T]:
+ ...
+
+ @overload
+ def __getitem__(self, index: slice) -> "Multirange[T]":
+ ...
+
+ def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
+ if isinstance(index, int):
+ return self._ranges[index]
+ else:
+ return Multirange(self._ranges[index])
+
+ def __len__(self) -> int:
+ return len(self._ranges)
+
+ @overload
+ def __setitem__(self, index: int, value: Range[T]) -> None:
+ ...
+
+ @overload
+ def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
+ ...
+
+ def __setitem__(
+ self,
+ index: Union[int, slice],
+ value: Union[Range[T], Iterable[Range[T]]],
+ ) -> None:
+ if isinstance(index, int):
+ self._check_type(value)
+ self._ranges[index] = self._check_type(value)
+ elif not isinstance(value, Iterable):
+ raise TypeError("can only assign an iterable")
+ else:
+ value = map(self._check_type, value)
+ self._ranges[index] = value
+
+ def __delitem__(self, index: Union[int, slice]) -> None:
+ del self._ranges[index]
+
+ def insert(self, index: int, value: Range[T]) -> None:
+ self._ranges.insert(index, self._check_type(value))
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return False
+ return self._ranges == other._ranges
+
+ # Order is arbitrary but consistent
+
+ def __lt__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return NotImplemented
+ return self._ranges < other._ranges
+
+ def __le__(self, other: Any) -> bool:
+ return self == other or self < other # type: ignore
+
+ def __gt__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return NotImplemented
+ return self._ranges > other._ranges
+
+ def __ge__(self, other: Any) -> bool:
+ return self == other or self > other # type: ignore
+
+
+# Subclasses to specify a specific subtype. Usually not needed
+
+
+class Int4Multirange(Multirange[int]):
+ pass
+
+
+class Int8Multirange(Multirange[int]):
+ pass
+
+
+class NumericMultirange(Multirange[Decimal]):
+ pass
+
+
+class DateMultirange(Multirange[date]):
+ pass
+
+
+class TimestampMultirange(Multirange[datetime]):
+ pass
+
+
+class TimestamptzMultirange(Multirange[datetime]):
+ pass
+
+
+class BaseMultirangeDumper(RecursiveDumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ self._adapt_format = PyFormat.from_pq(self.format)
+
+ def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Multirange:
+ return self.cls
+
+ item = self._get_item(obj)
+ if item is not None:
+ sd = self._tx.get_dumper(item, self._adapt_format)
+ return (self.cls, sd.get_key(item, format))
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Multirange:
+ return self
+
+ item = self._get_item(obj)
+ if item is None:
+ return MultirangeDumper(self.cls)
+
+ dumper: BaseMultirangeDumper
+ if type(item) is int:
+ # postgres won't cast int4range -> int8range so we must use
+ # text format and unknown oid here
+ sd = self._tx.get_dumper(item, PyFormat.TEXT)
+ dumper = MultirangeDumper(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ dumper.oid = INVALID_OID
+ return dumper
+
+ sd = self._tx.get_dumper(item, format)
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ if sd.oid == INVALID_OID and isinstance(item, str):
+ # Work around the normal mapping where text is dumped as unknown
+ dumper.oid = self._get_multirange_oid(TEXT_OID)
+ else:
+ dumper.oid = self._get_multirange_oid(sd.oid)
+
+ return dumper
+
+ def _get_item(self, obj: Multirange[Any]) -> Any:
+ """
+ Return a member representative of the multirange
+ """
+ for r in obj:
+ if r.lower is not None:
+ return r.lower
+ if r.upper is not None:
+ return r.upper
+ return None
+
+ def _get_multirange_oid(self, sub_oid: int) -> int:
+ """
+ Return the oid of the range from the oid of its elements.
+ """
+ info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
+ return info.oid if info else INVALID_OID
+
+
+class MultirangeDumper(BaseMultirangeDumper):
+ """
+ Dumper for multirange types.
+
+ The dumper can upgrade to one specific for a different range type.
+ """
+
+ def dump(self, obj: Multirange[Any]) -> Buffer:
+ if not obj:
+ return b"{}"
+
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ out: List[Buffer] = [b"{"]
+ for r in obj:
+ out.append(dump_range_text(r, dump))
+ out.append(b",")
+ out[-1] = b"}"
+ return b"".join(out)
+
+
+class MultirangeBinaryDumper(BaseMultirangeDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Multirange[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ out: List[Buffer] = [pack_len(len(obj))]
+ for r in obj:
+ data = dump_range_binary(r, dump)
+ out.append(pack_len(len(data)))
+ out.append(data)
+ return b"".join(out)
+
+
+class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
+
+ subtype_oid: int
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
+
+
+class MultirangeLoader(BaseMultirangeLoader[T]):
+ def load(self, data: Buffer) -> Multirange[T]:
+ if not data or data[0] != _START_INT:
+ raise e.DataError(
+ "malformed multirange starting with"
+ f" {bytes(data[:1]).decode('utf8', 'replace')}"
+ )
+
+ out = Multirange[T]()
+ if data == b"{}":
+ return out
+
+ pos = 1
+ data = data[pos:]
+ try:
+ while True:
+ r, pos = load_range_text(data, self._load)
+ out.append(r)
+
+ sep = data[pos] # can raise IndexError
+ if sep == _SEP_INT:
+ data = data[pos + 1 :]
+ continue
+ elif sep == _END_INT:
+ if len(data) == pos + 1:
+ return out
+ else:
+ raise e.DataError(
+ "malformed multirange: data after closing brace"
+ )
+ else:
+ raise e.DataError(
+ f"malformed multirange: found unexpected {chr(sep)}"
+ )
+
+ except IndexError:
+ raise e.DataError("malformed multirange: separator missing")
+
+ return out
+
+
+_SEP_INT = ord(",")
+_START_INT = ord("{")
+_END_INT = ord("}")
+
+
+class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Multirange[T]:
+ nelems = unpack_len(data, 0)[0]
+ pos = 4
+ out = Multirange[T]()
+ for i in range(nelems):
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ out.append(load_range_binary(data[pos : pos + length], self._load))
+ pos += length
+
+ if pos != len(data):
+ raise e.DataError("unexpected trailing data in multirange")
+
+ return out
+
+
+def register_multirange(
+ info: MultirangeInfo, context: Optional[AdaptContext] = None
+) -> None:
+ """Register the adapters to load and dump a multirange type.
+
+ :param info: The object with the information about the range to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ Register loaders so that loading data of this type will result in a `Range`
+ with bounds parsed as the right subtype.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested multirange available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[MultirangeLoader[Any]] = type(
+ f"{info.name.title()}Loader",
+ (MultirangeLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ bloader: Type[MultirangeBinaryLoader[Any]] = type(
+ f"{info.name.title()}BinaryLoader",
+ (MultirangeBinaryLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, bloader)
+
+
+# Text dumpers for builtin multirange types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeDumper(MultirangeDumper):
+ oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeDumper(MultirangeDumper):
+ oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["tstzmultirange"].oid
+
+
+# Binary dumpers for builtin multirange types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["tstzmultirange"].oid
+
+
+# Text loaders for builtin multirange types
+
+
+class Int4MultirangeLoader(MultirangeLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeLoader(MultirangeLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeLoader(MultirangeLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeLoader(MultirangeLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeLoader(MultirangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+# Binary loaders for builtin multirange types
+
+
+class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Multirange, MultirangeBinaryDumper)
+ adapters.register_dumper(Multirange, MultirangeDumper)
+ adapters.register_dumper(Int4Multirange, Int4MultirangeDumper)
+ adapters.register_dumper(Int8Multirange, Int8MultirangeDumper)
+ adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
+ adapters.register_dumper(DateMultirange, DateMultirangeDumper)
+ adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
+ adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper)
+ adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
+ adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
+ adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
+ adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
+ adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper)
+ adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper)
+ adapters.register_loader("int4multirange", Int4MultirangeLoader)
+ adapters.register_loader("int8multirange", Int8MultirangeLoader)
+ adapters.register_loader("nummultirange", NumericMultirangeLoader)
+ adapters.register_loader("datemultirange", DateMultirangeLoader)
+ adapters.register_loader("tsmultirange", TimestampMultirangeLoader)
+ adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader)
+ adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader)
+ adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader)
+ adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
+ adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
+ adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
+ adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader)
diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py
new file mode 100644
index 0000000..2f2c05b
--- /dev/null
+++ b/psycopg/psycopg/types/net.py
@@ -0,0 +1,206 @@
+"""
+Adapters for network types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Callable, Optional, Type, Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+if TYPE_CHECKING:
+ import ipaddress
+
+Address: TypeAlias = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"]
+Interface: TypeAlias = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"]
+Network: TypeAlias = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"]
+
+# These objects will be imported lazily
+ip_address: Callable[[str], Address] = None # type: ignore[assignment]
+ip_interface: Callable[[str], Interface] = None # type: ignore[assignment]
+ip_network: Callable[[str], Network] = None # type: ignore[assignment]
+IPv4Address: "Type[ipaddress.IPv4Address]" = None # type: ignore[assignment]
+IPv6Address: "Type[ipaddress.IPv6Address]" = None # type: ignore[assignment]
+IPv4Interface: "Type[ipaddress.IPv4Interface]" = None # type: ignore[assignment]
+IPv6Interface: "Type[ipaddress.IPv6Interface]" = None # type: ignore[assignment]
+IPv4Network: "Type[ipaddress.IPv4Network]" = None # type: ignore[assignment]
+IPv6Network: "Type[ipaddress.IPv6Network]" = None # type: ignore[assignment]
+
+PGSQL_AF_INET = 2
+PGSQL_AF_INET6 = 3
+IPV4_PREFIXLEN = 32
+IPV6_PREFIXLEN = 128
+
+
+class _LazyIpaddress:
+ def _ensure_module(self) -> None:
+ global ip_address, ip_interface, ip_network
+ global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface
+ global IPv4Network, IPv6Network
+
+ if ip_address is None:
+ from ipaddress import ip_address, ip_interface, ip_network
+ from ipaddress import IPv4Address, IPv6Address
+ from ipaddress import IPv4Interface, IPv6Interface
+ from ipaddress import IPv4Network, IPv6Network
+
+
+class InterfaceDumper(Dumper):
+
+ oid = postgres.types["inet"].oid
+
+ def dump(self, obj: Interface) -> bytes:
+ return str(obj).encode()
+
+
+class NetworkDumper(Dumper):
+
+ oid = postgres.types["cidr"].oid
+
+ def dump(self, obj: Network) -> bytes:
+ return str(obj).encode()
+
+
+class _AIBinaryDumper(Dumper):
+ format = Format.BINARY
+ oid = postgres.types["inet"].oid
+
+
+class AddressBinaryDumper(_AIBinaryDumper):
+ def dump(self, obj: Address) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.max_prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class InterfaceBinaryDumper(_AIBinaryDumper):
+ def dump(self, obj: Interface) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.network.prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress):
+ """Either an address or an interface to inet
+
+ Used when looking up by oid.
+ """
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._ensure_module()
+
+ def dump(self, obj: Union[Address, Interface]) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ if isinstance(obj, (IPv4Interface, IPv6Interface)):
+ prefixlen = obj.network.prefixlen
+ else:
+ prefixlen = obj.max_prefixlen
+
+ head = bytes((family, prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class NetworkBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["cidr"].oid
+
+ def dump(self, obj: Network) -> bytes:
+ packed = obj.network_address.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.prefixlen, 1, len(packed)))
+ return head + packed
+
+
+class _LazyIpaddressLoader(Loader, _LazyIpaddress):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._ensure_module()
+
+
+class InetLoader(_LazyIpaddressLoader):
+ def load(self, data: Buffer) -> Union[Address, Interface]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ if b"/" in data:
+ return ip_interface(data.decode())
+ else:
+ return ip_address(data.decode())
+
+
+class InetBinaryLoader(_LazyIpaddressLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Union[Address, Interface]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ prefix = data[1]
+ packed = data[4:]
+ if data[0] == PGSQL_AF_INET:
+ if prefix == IPV4_PREFIXLEN:
+ return IPv4Address(packed)
+ else:
+ return IPv4Interface((packed, prefix))
+ else:
+ if prefix == IPV6_PREFIXLEN:
+ return IPv6Address(packed)
+ else:
+ return IPv6Interface((packed, prefix))
+
+
+class CidrLoader(_LazyIpaddressLoader):
+ def load(self, data: Buffer) -> Network:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ return ip_network(data.decode())
+
+
+class CidrBinaryLoader(_LazyIpaddressLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Network:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ prefix = data[1]
+ packed = data[4:]
+ if data[0] == PGSQL_AF_INET:
+ return IPv4Network((packed, prefix))
+ else:
+ return IPv6Network((packed, prefix))
+
+ return ip_network(data.decode())
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("ipaddress.IPv4Address", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper)
+ adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper)
+ adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper)
+ adapters.register_dumper(None, InetBinaryDumper)
+ adapters.register_loader("inet", InetLoader)
+ adapters.register_loader("inet", InetBinaryLoader)
+ adapters.register_loader("cidr", CidrLoader)
+ adapters.register_loader("cidr", CidrBinaryLoader)
diff --git a/psycopg/psycopg/types/none.py b/psycopg/psycopg/types/none.py
new file mode 100644
index 0000000..2ab857c
--- /dev/null
+++ b/psycopg/psycopg/types/none.py
@@ -0,0 +1,25 @@
+"""
+Adapters for None.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from ..abc import AdaptContext, NoneType
+from ..adapt import Dumper
+
+
+class NoneDumper(Dumper):
+ """
+ Not a complete dumper as it doesn't implement dump(), but it implements
+ quote(), so it can be used in sql composition.
+ """
+
+ def dump(self, obj: None) -> bytes:
+ raise NotImplementedError("NULL is passed to Postgres in other ways")
+
+ def quote(self, obj: None) -> bytes:
+ return b"NULL"
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(NoneType, NoneDumper)
diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py
new file mode 100644
index 0000000..1bd9329
--- /dev/null
+++ b/psycopg/psycopg/types/numeric.py
@@ -0,0 +1,515 @@
+"""
+Adapers for numeric types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import struct
+from math import log
+from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
+from decimal import Decimal, DefaultContext, Context
+
+from .. import postgres
+from .. import errors as e
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader, PyFormat
+from .._struct import pack_int2, pack_uint2, unpack_int2
+from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
+from .._struct import pack_int8, unpack_int8
+from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
+
+# Exposed here
+from .._wrappers import (
+ Int2 as Int2,
+ Int4 as Int4,
+ Int8 as Int8,
+ IntNumeric as IntNumeric,
+ Oid as Oid,
+ Float4 as Float4,
+ Float8 as Float8,
+)
+
+
+class _IntDumper(Dumper):
+ def dump(self, obj: Any) -> Buffer:
+ t = type(obj)
+ if t is not int:
+ # Convert to int in order to dump IntEnum correctly
+ if issubclass(t, int):
+ obj = int(obj)
+ else:
+ raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
+
+ return str(obj).encode()
+
+ def quote(self, obj: Any) -> Buffer:
+ value = self.dump(obj)
+ return value if obj >= 0 else b" " + value
+
+
+class _SpecialValuesDumper(Dumper):
+
+ _special: Dict[bytes, bytes] = {}
+
+ def dump(self, obj: Any) -> bytes:
+ return str(obj).encode()
+
+ def quote(self, obj: Any) -> bytes:
+ value = self.dump(obj)
+
+ if value in self._special:
+ return self._special[value]
+
+ return value if obj >= 0 else b" " + value
+
+
+class FloatDumper(_SpecialValuesDumper):
+
+ oid = postgres.types["float8"].oid
+
+ _special = {
+ b"inf": b"'Infinity'::float8",
+ b"-inf": b"'-Infinity'::float8",
+ b"nan": b"'NaN'::float8",
+ }
+
+
+class Float4Dumper(FloatDumper):
+ oid = postgres.types["float4"].oid
+
+
+class FloatBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["float8"].oid
+
+ def dump(self, obj: float) -> bytes:
+ return pack_float8(obj)
+
+
+class Float4BinaryDumper(FloatBinaryDumper):
+
+ oid = postgres.types["float4"].oid
+
+ def dump(self, obj: float) -> bytes:
+ return pack_float4(obj)
+
+
+class DecimalDumper(_SpecialValuesDumper):
+
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Decimal) -> bytes:
+ if obj.is_nan():
+ # cover NaN and sNaN
+ return b"NaN"
+ else:
+ return str(obj).encode()
+
+ _special = {
+ b"Infinity": b"'Infinity'::numeric",
+ b"-Infinity": b"'-Infinity'::numeric",
+ b"NaN": b"'NaN'::numeric",
+ }
+
+
+class Int2Dumper(_IntDumper):
+ oid = postgres.types["int2"].oid
+
+
+class Int4Dumper(_IntDumper):
+ oid = postgres.types["int4"].oid
+
+
+class Int8Dumper(_IntDumper):
+ oid = postgres.types["int8"].oid
+
+
+class IntNumericDumper(_IntDumper):
+ oid = postgres.types["numeric"].oid
+
+
+class OidDumper(_IntDumper):
+ oid = postgres.types["oid"].oid
+
+
+class IntDumper(Dumper):
+ def dump(self, obj: Any) -> bytes:
+ raise TypeError(
+ f"{type(self).__name__} is a dispatcher to other dumpers:"
+ " dump() is not supposed to be called"
+ )
+
+ def get_key(self, obj: int, format: PyFormat) -> type:
+ return self.upgrade(obj, format).cls
+
+ _int2_dumper = Int2Dumper(Int2)
+ _int4_dumper = Int4Dumper(Int4)
+ _int8_dumper = Int8Dumper(Int8)
+ _int_numeric_dumper = IntNumericDumper(IntNumeric)
+
+ def upgrade(self, obj: int, format: PyFormat) -> Dumper:
+ if -(2**31) <= obj < 2**31:
+ if -(2**15) <= obj < 2**15:
+ return self._int2_dumper
+ else:
+ return self._int4_dumper
+ else:
+ if -(2**63) <= obj < 2**63:
+ return self._int8_dumper
+ else:
+ return self._int_numeric_dumper
+
+
+class Int2BinaryDumper(Int2Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int2(obj)
+
+
+class Int4BinaryDumper(Int4Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int4(obj)
+
+
+class Int8BinaryDumper(Int8Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int8(obj)
+
+
+# Ratio between number of bits required to store a number and number of pg
+# decimal digits required.
+BIT_PER_PGDIGIT = log(2) / log(10_000)
+
+
+class IntNumericBinaryDumper(IntNumericDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> Buffer:
+ return dump_int_to_numeric_binary(obj)
+
+
+class OidBinaryDumper(OidDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_uint4(obj)
+
+
+class IntBinaryDumper(IntDumper):
+
+ format = Format.BINARY
+
+ _int2_dumper = Int2BinaryDumper(Int2)
+ _int4_dumper = Int4BinaryDumper(Int4)
+ _int8_dumper = Int8BinaryDumper(Int8)
+ _int_numeric_dumper = IntNumericBinaryDumper(IntNumeric)
+
+
+class IntLoader(Loader):
+ def load(self, data: Buffer) -> int:
+ # it supports bytes directly
+ return int(data)
+
+
+class Int2BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int2(data)[0]
+
+
+class Int4BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int4(data)[0]
+
+
+class Int8BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int8(data)[0]
+
+
+class OidBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_uint4(data)[0]
+
+
+class FloatLoader(Loader):
+ def load(self, data: Buffer) -> float:
+ # it supports bytes directly
+ return float(data)
+
+
+class Float4BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> float:
+ return unpack_float4(data)[0]
+
+
+class Float8BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> float:
+ return unpack_float8(data)[0]
+
+
+class NumericLoader(Loader):
+ def load(self, data: Buffer) -> Decimal:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return Decimal(data.decode())
+
+
+DEC_DIGITS = 4 # decimal digits per Postgres "digit"
+NUMERIC_POS = 0x0000
+NUMERIC_NEG = 0x4000
+NUMERIC_NAN = 0xC000
+NUMERIC_PINF = 0xD000
+NUMERIC_NINF = 0xF000
+
+_decimal_special = {
+ NUMERIC_NAN: Decimal("NaN"),
+ NUMERIC_PINF: Decimal("Infinity"),
+ NUMERIC_NINF: Decimal("-Infinity"),
+}
+
+
+class _ContextMap(DefaultDict[int, Context]):
+ """
+ Cache for decimal contexts to use when the precision requires it.
+
+ Note: if the default context is used (prec=28) you can get an invalid
+ operation or a rounding to 0:
+
+ - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
+ - Decimal(1000).shift(25) = Decimal('0')
+ - Decimal(1000).shift(30) raises InvalidOperation
+ """
+
+ def __missing__(self, key: int) -> Context:
+ val = Context(prec=key)
+ self[key] = val
+ return val
+
+
+_contexts = _ContextMap()
+for i in range(DefaultContext.prec):
+ _contexts[i] = DefaultContext
+
+_unpack_numeric_head = cast(
+ Callable[[Buffer], Tuple[int, int, int, int]],
+ struct.Struct("!HhHH").unpack_from,
+)
+_pack_numeric_head = cast(
+ Callable[[int, int, int, int], bytes],
+ struct.Struct("!HhHH").pack,
+)
+
+
+class NumericBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Decimal:
+ ndigits, weight, sign, dscale = _unpack_numeric_head(data)
+ if sign == NUMERIC_POS or sign == NUMERIC_NEG:
+ val = 0
+ for i in range(8, len(data), 2):
+ val = val * 10_000 + data[i] * 0x100 + data[i + 1]
+
+ shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
+ ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale]
+ return (
+ Decimal(val if sign == NUMERIC_POS else -val)
+ .scaleb(-dscale, ctx)
+ .shift(shift, ctx)
+ )
+ else:
+ try:
+ return _decimal_special[sign]
+ except KeyError:
+ raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None
+
+
+NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
+NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0)
+NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0)
+
+
+class DecimalBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Decimal) -> Buffer:
+ return dump_decimal_to_numeric_binary(obj)
+
+
+class NumericDumper(DecimalDumper):
+ def dump(self, obj: Union[Decimal, int]) -> bytes:
+ if isinstance(obj, int):
+ return str(obj).encode()
+ else:
+ return super().dump(obj)
+
+
+class NumericBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Union[Decimal, int]) -> Buffer:
+ if isinstance(obj, int):
+ return dump_int_to_numeric_binary(obj)
+ else:
+ return dump_decimal_to_numeric_binary(obj)
+
+
+def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
+ sign, digits, exp = obj.as_tuple()
+ if exp == "n" or exp == "N": # type: ignore[comparison-overlap]
+ return NUMERIC_NAN_BIN
+ elif exp == "F": # type: ignore[comparison-overlap]
+ return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
+
+ # Weights of py digits into a pg digit according to their positions.
+ # Starting with an index wi != 0 is equivalent to prepending 0's to
+ # the digits tuple, but without really changing it.
+ weights = (1000, 100, 10, 1)
+ wi = 0
+
+ ndigits = nzdigits = len(digits)
+
+ # Find the last nonzero digit
+ while nzdigits > 0 and digits[nzdigits - 1] == 0:
+ nzdigits -= 1
+
+ if exp <= 0:
+ dscale = -exp
+ else:
+ dscale = 0
+ # align the py digits to the pg digits if there's some py exponent
+ ndigits += exp % DEC_DIGITS
+
+ if not nzdigits:
+ return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)
+
+ # Equivalent of 0-padding left to align the py digits to the pg digits
+ # but without changing the digits tuple.
+ mod = (ndigits - dscale) % DEC_DIGITS
+ if mod:
+ wi = DEC_DIGITS - mod
+ ndigits += wi
+
+ tmp = nzdigits + wi
+ out = bytearray(
+ _pack_numeric_head(
+ tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1), # ndigits
+ (ndigits + exp) // DEC_DIGITS - 1, # weight
+ NUMERIC_NEG if sign else NUMERIC_POS, # sign
+ dscale,
+ )
+ )
+
+ pgdigit = 0
+ for i in range(nzdigits):
+ pgdigit += weights[wi] * digits[i]
+ wi += 1
+ if wi >= DEC_DIGITS:
+ out += pack_uint2(pgdigit)
+ pgdigit = wi = 0
+
+ if pgdigit:
+ out += pack_uint2(pgdigit)
+
+ return out
+
+
+def dump_int_to_numeric_binary(obj: int) -> bytearray:
+ ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1
+ out = bytearray(b"\x00\x00" * (ndigits + 4))
+ if obj < 0:
+ sign = NUMERIC_NEG
+ obj = -obj
+ else:
+ sign = NUMERIC_POS
+
+ out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0)
+ i = 8 + (ndigits - 1) * 2
+ while obj:
+ rem = obj % 10_000
+ obj //= 10_000
+ out[i : i + 2] = pack_uint2(rem)
+ i -= 2
+
+ return out
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(int, IntDumper)
+ adapters.register_dumper(int, IntBinaryDumper)
+ adapters.register_dumper(float, FloatDumper)
+ adapters.register_dumper(float, FloatBinaryDumper)
+ adapters.register_dumper(Int2, Int2Dumper)
+ adapters.register_dumper(Int4, Int4Dumper)
+ adapters.register_dumper(Int8, Int8Dumper)
+ adapters.register_dumper(IntNumeric, IntNumericDumper)
+ adapters.register_dumper(Oid, OidDumper)
+
+ # The binary dumper is currently some 30% slower, so default to text
+ # (see tests/scripts/testdec.py for a rough benchmark)
+ # Also, must be after IntNumericDumper
+ adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper)
+ adapters.register_dumper("decimal.Decimal", DecimalDumper)
+
+ # Used only by oid, can take both int and Decimal as input
+ adapters.register_dumper(None, NumericBinaryDumper)
+ adapters.register_dumper(None, NumericDumper)
+
+ adapters.register_dumper(Float4, Float4Dumper)
+ adapters.register_dumper(Float8, FloatDumper)
+ adapters.register_dumper(Int2, Int2BinaryDumper)
+ adapters.register_dumper(Int4, Int4BinaryDumper)
+ adapters.register_dumper(Int8, Int8BinaryDumper)
+ adapters.register_dumper(Oid, OidBinaryDumper)
+ adapters.register_dumper(Float4, Float4BinaryDumper)
+ adapters.register_dumper(Float8, FloatBinaryDumper)
+ adapters.register_loader("int2", IntLoader)
+ adapters.register_loader("int4", IntLoader)
+ adapters.register_loader("int8", IntLoader)
+ adapters.register_loader("oid", IntLoader)
+ adapters.register_loader("int2", Int2BinaryLoader)
+ adapters.register_loader("int4", Int4BinaryLoader)
+ adapters.register_loader("int8", Int8BinaryLoader)
+ adapters.register_loader("oid", OidBinaryLoader)
+ adapters.register_loader("float4", FloatLoader)
+ adapters.register_loader("float8", FloatLoader)
+ adapters.register_loader("float4", Float4BinaryLoader)
+ adapters.register_loader("float8", Float8BinaryLoader)
+ adapters.register_loader("numeric", NumericLoader)
+ adapters.register_loader("numeric", NumericBinaryLoader)
diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py
new file mode 100644
index 0000000..c418480
--- /dev/null
+++ b/psycopg/psycopg/types/range.py
@@ -0,0 +1,700 @@
+"""
+Support for range types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple
+from typing import cast
+from decimal import Decimal
+from datetime import date, datetime
+
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._struct import pack_len, unpack_len
+from ..postgres import INVALID_OID, TEXT_OID
+from .._typeinfo import RangeInfo as RangeInfo # exported here
+
+RANGE_EMPTY = 0x01 # range is empty
+RANGE_LB_INC = 0x02 # lower bound is inclusive
+RANGE_UB_INC = 0x04 # upper bound is inclusive
+RANGE_LB_INF = 0x08 # lower bound is -infinity
+RANGE_UB_INF = 0x10 # upper bound is +infinity
+
+_EMPTY_HEAD = bytes([RANGE_EMPTY])
+
+T = TypeVar("T")
+
+
+class Range(Generic[T]):
+ """Python representation for a PostgreSQL range type.
+
+ :param lower: lower bound for the range. `!None` means unbound
+ :param upper: upper bound for the range. `!None` means unbound
+ :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``,
+ representing whether the lower or upper bounds are included
+ :param empty: if `!True`, the range is empty
+
+ """
+
+ __slots__ = ("_lower", "_upper", "_bounds")
+
+ def __init__(
+ self,
+ lower: Optional[T] = None,
+ upper: Optional[T] = None,
+ bounds: str = "[)",
+ empty: bool = False,
+ ):
+ if not empty:
+ if bounds not in ("[)", "(]", "()", "[]"):
+ raise ValueError("bound flags not valid: %r" % bounds)
+
+ self._lower = lower
+ self._upper = upper
+
+ # Make bounds consistent with infs
+ if lower is None and bounds[0] == "[":
+ bounds = "(" + bounds[1]
+ if upper is None and bounds[1] == "]":
+ bounds = bounds[0] + ")"
+
+ self._bounds = bounds
+ else:
+ self._lower = self._upper = None
+ self._bounds = ""
+
+ def __repr__(self) -> str:
+ if self._bounds:
+ args = f"{self._lower!r}, {self._upper!r}, {self._bounds!r}"
+ else:
+ args = "empty=True"
+
+ return f"{self.__class__.__name__}({args})"
+
+ def __str__(self) -> str:
+ if not self._bounds:
+ return "empty"
+
+ items = [
+ self._bounds[0],
+ str(self._lower),
+ ", ",
+ str(self._upper),
+ self._bounds[1],
+ ]
+ return "".join(items)
+
+ @property
+ def lower(self) -> Optional[T]:
+ """The lower bound of the range. `!None` if empty or unbound."""
+ return self._lower
+
+ @property
+ def upper(self) -> Optional[T]:
+ """The upper bound of the range. `!None` if empty or unbound."""
+ return self._upper
+
+ @property
+ def bounds(self) -> str:
+ """The bounds string (two characters from '[', '(', ']', ')')."""
+ return self._bounds
+
+ @property
+ def isempty(self) -> bool:
+ """`!True` if the range is empty."""
+ return not self._bounds
+
+ @property
+ def lower_inf(self) -> bool:
+ """`!True` if the range doesn't have a lower bound."""
+ if not self._bounds:
+ return False
+ return self._lower is None
+
+ @property
+ def upper_inf(self) -> bool:
+ """`!True` if the range doesn't have an upper bound."""
+ if not self._bounds:
+ return False
+ return self._upper is None
+
+ @property
+ def lower_inc(self) -> bool:
+ """`!True` if the lower bound is included in the range."""
+ if not self._bounds or self._lower is None:
+ return False
+ return self._bounds[0] == "["
+
+ @property
+ def upper_inc(self) -> bool:
+ """`!True` if the upper bound is included in the range."""
+ if not self._bounds or self._upper is None:
+ return False
+ return self._bounds[1] == "]"
+
+ def __contains__(self, x: T) -> bool:
+ if not self._bounds:
+ return False
+
+ if self._lower is not None:
+ if self._bounds[0] == "[":
+ # It doesn't seem that Python has an ABC for ordered types.
+ if x < self._lower: # type: ignore[operator]
+ return False
+ else:
+ if x <= self._lower: # type: ignore[operator]
+ return False
+
+ if self._upper is not None:
+ if self._bounds[1] == "]":
+ if x > self._upper: # type: ignore[operator]
+ return False
+ else:
+ if x >= self._upper: # type: ignore[operator]
+ return False
+
+ return True
+
+ def __bool__(self) -> bool:
+ return bool(self._bounds)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Range):
+ return False
+ return (
+ self._lower == other._lower
+ and self._upper == other._upper
+ and self._bounds == other._bounds
+ )
+
+ def __hash__(self) -> int:
+ return hash((self._lower, self._upper, self._bounds))
+
+ # as the postgres docs describe for the server-side stuff,
+ # ordering is rather arbitrary, but will remain stable
+ # and consistent.
+
+ def __lt__(self, other: Any) -> bool:
+ if not isinstance(other, Range):
+ return NotImplemented
+ for attr in ("_lower", "_upper", "_bounds"):
+ self_value = getattr(self, attr)
+ other_value = getattr(other, attr)
+ if self_value == other_value:
+ pass
+ elif self_value is None:
+ return True
+ elif other_value is None:
+ return False
+ else:
+ return cast(bool, self_value < other_value)
+ return False
+
+ def __le__(self, other: Any) -> bool:
+ return self == other or self < other # type: ignore
+
+ def __gt__(self, other: Any) -> bool:
+ if isinstance(other, Range):
+ return other < self
+ else:
+ return NotImplemented
+
+ def __ge__(self, other: Any) -> bool:
+ return self == other or self > other # type: ignore
+
+ def __getstate__(self) -> Dict[str, Any]:
+ return {
+ slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
+ }
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ for slot, value in state.items():
+ setattr(self, slot, value)
+
+
+# Subclasses to specify a specific subtype. Usually not needed: only needed
+# in binary copy, where switching to text is not an option.
+
+
+class Int4Range(Range[int]):
+ pass
+
+
+class Int8Range(Range[int]):
+ pass
+
+
+class NumericRange(Range[Decimal]):
+ pass
+
+
+class DateRange(Range[date]):
+ pass
+
+
+class TimestampRange(Range[datetime]):
+ pass
+
+
+class TimestamptzRange(Range[datetime]):
+ pass
+
+
+class BaseRangeDumper(RecursiveDumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ self._adapt_format = PyFormat.from_pq(self.format)
+
+ def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey:
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Range:
+ return self.cls
+
+ item = self._get_item(obj)
+ if item is not None:
+ sd = self._tx.get_dumper(item, self._adapt_format)
+ return (self.cls, sd.get_key(item, format))
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper":
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Range:
+ return self
+
+ item = self._get_item(obj)
+ if item is None:
+ return RangeDumper(self.cls)
+
+ dumper: BaseRangeDumper
+ if type(item) is int:
+ # postgres won't cast int4range -> int8range so we must use
+ # text format and unknown oid here
+ sd = self._tx.get_dumper(item, PyFormat.TEXT)
+ dumper = RangeDumper(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ dumper.oid = INVALID_OID
+ return dumper
+
+ sd = self._tx.get_dumper(item, format)
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ if sd.oid == INVALID_OID and isinstance(item, str):
+ # Work around the normal mapping where text is dumped as unknown
+ dumper.oid = self._get_range_oid(TEXT_OID)
+ else:
+ dumper.oid = self._get_range_oid(sd.oid)
+
+ return dumper
+
+ def _get_item(self, obj: Range[Any]) -> Any:
+ """
+ Return a member representative of the range
+ """
+ rv = obj.lower
+ return rv if rv is not None else obj.upper
+
+ def _get_range_oid(self, sub_oid: int) -> int:
+ """
+ Return the oid of the range from the oid of its elements.
+ """
+ info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid)
+ return info.oid if info else INVALID_OID
+
+
+class RangeDumper(BaseRangeDumper):
+ """
+ Dumper for range types.
+
+ The dumper can upgrade to one specific for a different range type.
+ """
+
+ def dump(self, obj: Range[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ return dump_range_text(obj, dump)
+
+
+def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
+ if obj.isempty:
+ return b"empty"
+
+ parts: List[Buffer] = [b"[" if obj.lower_inc else b"("]
+
+ def dump_item(item: Any) -> Buffer:
+ ad = dump(item)
+ if not ad:
+ return b'""'
+ elif _re_needs_quotes.search(ad):
+ return b'"' + _re_esc.sub(rb"\1\1", ad) + b'"'
+ else:
+ return ad
+
+ if obj.lower is not None:
+ parts.append(dump_item(obj.lower))
+
+ parts.append(b",")
+
+ if obj.upper is not None:
+ parts.append(dump_item(obj.upper))
+
+ parts.append(b"]" if obj.upper_inc else b")")
+
+ return b"".join(parts)
+
+
+_re_needs_quotes = re.compile(rb'[",\\\s()\[\]]')
+_re_esc = re.compile(rb"([\\\"])")
+
+
+class RangeBinaryDumper(BaseRangeDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Range[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ return dump_range_binary(obj, dump)
+
+
+def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
+ if not obj:
+ return _EMPTY_HEAD
+
+ out = bytearray([0]) # will replace the head later
+
+ head = 0
+ if obj.lower_inc:
+ head |= RANGE_LB_INC
+ if obj.upper_inc:
+ head |= RANGE_UB_INC
+
+ if obj.lower is not None:
+ data = dump(obj.lower)
+ out += pack_len(len(data))
+ out += data
+ else:
+ head |= RANGE_LB_INF
+
+ if obj.upper is not None:
+ data = dump(obj.upper)
+ out += pack_len(len(data))
+ out += data
+ else:
+ head |= RANGE_UB_INF
+
+ out[0] = head
+ return out
+
+
+def fail_dump(obj: Any) -> Buffer:
+ raise e.InternalError("trying to dump a range element without information")
+
+
+class BaseRangeLoader(RecursiveLoader, Generic[T]):
+ """Generic loader for a range.
+
+ Subclasses must specify the oid of the subtype and the class to load.
+ """
+
+ subtype_oid: int
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
+
+
+class RangeLoader(BaseRangeLoader[T]):
+ def load(self, data: Buffer) -> Range[T]:
+ return load_range_text(data, self._load)[0]
+
+
+def load_range_text(
+ data: Buffer, load: Callable[[Buffer], Any]
+) -> Tuple[Range[Any], int]:
+ if data == b"empty":
+ return Range(empty=True), 5
+
+ m = _re_range.match(data)
+ if m is None:
+ raise e.DataError(
+ f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
+ )
+
+ lower = None
+ item = m.group(3)
+ if item is None:
+ item = m.group(2)
+ if item is not None:
+ lower = load(_re_undouble.sub(rb"\1", item))
+ else:
+ lower = load(item)
+
+ upper = None
+ item = m.group(5)
+ if item is None:
+ item = m.group(4)
+ if item is not None:
+ upper = load(_re_undouble.sub(rb"\1", item))
+ else:
+ upper = load(item)
+
+ bounds = (m.group(1) + m.group(6)).decode()
+
+ return Range(lower, upper, bounds), m.end()
+
+
+_re_range = re.compile(
+ rb"""
+ ( \(|\[ ) # lower bound flag
+ (?: # lower bound:
+ " ( (?: [^"] | "")* ) " # - a quoted string
+ | ( [^",]+ ) # - or an unquoted string
+ )? # - or empty (not caught)
+ ,
+ (?: # upper bound:
+ " ( (?: [^"] | "")* ) " # - a quoted string
+ | ( [^"\)\]]+ ) # - or an unquoted string
+ )? # - or empty (not caught)
+ ( \)|\] ) # upper bound flag
+ """,
+ re.VERBOSE,
+)
+
+_re_undouble = re.compile(rb'(["\\])\1')
+
+
+class RangeBinaryLoader(BaseRangeLoader[T]):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Range[T]:
+ return load_range_binary(data, self._load)
+
+
+def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]:
+ head = data[0]
+ if head & RANGE_EMPTY:
+ return Range(empty=True)
+
+ lb = "[" if head & RANGE_LB_INC else "("
+ ub = "]" if head & RANGE_UB_INC else ")"
+
+ pos = 1 # after the head
+ if head & RANGE_LB_INF:
+ min = None
+ else:
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ min = load(data[pos : pos + length])
+ pos += length
+
+ if head & RANGE_UB_INF:
+ max = None
+ else:
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ max = load(data[pos : pos + length])
+ pos += length
+
+ return Range(min, max, lb + ub)
+
+
+def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register the adapters to load and dump a range type.
+
+ :param info: The object with the information about the range to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ Register loaders so that loading data of this type will result in a `Range`
+ with bounds parsed as the right subtype.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested range available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[RangeLoader[Any]] = type(
+ f"{info.name.title()}Loader",
+ (RangeLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ bloader: Type[RangeBinaryLoader[Any]] = type(
+ f"{info.name.title()}BinaryLoader",
+ (RangeBinaryLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, bloader)
+
+
+# Text dumpers for builtin range types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4RangeDumper(RangeDumper):
+ oid = postgres.types["int4range"].oid
+
+
+class Int8RangeDumper(RangeDumper):
+ oid = postgres.types["int8range"].oid
+
+
+class NumericRangeDumper(RangeDumper):
+ oid = postgres.types["numrange"].oid
+
+
+class DateRangeDumper(RangeDumper):
+ oid = postgres.types["daterange"].oid
+
+
+class TimestampRangeDumper(RangeDumper):
+ oid = postgres.types["tsrange"].oid
+
+
+class TimestamptzRangeDumper(RangeDumper):
+ oid = postgres.types["tstzrange"].oid
+
+
+# Binary dumpers for builtin range types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4RangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["int4range"].oid
+
+
+class Int8RangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["int8range"].oid
+
+
+class NumericRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["numrange"].oid
+
+
+class DateRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["daterange"].oid
+
+
+class TimestampRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["tsrange"].oid
+
+
+class TimestamptzRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["tstzrange"].oid
+
+
+# Text loaders for builtin range types
+
+
+class Int4RangeLoader(RangeLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8RangeLoader(RangeLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericRangeLoader(RangeLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateRangeLoader(RangeLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampRangeLoader(RangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZRangeLoader(RangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+# Binary loaders for builtin range types
+
+
+class Int4RangeBinaryLoader(RangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8RangeBinaryLoader(RangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateRangeBinaryLoader(RangeBinaryLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Range, RangeBinaryDumper)
+ adapters.register_dumper(Range, RangeDumper)
+ adapters.register_dumper(Int4Range, Int4RangeDumper)
+ adapters.register_dumper(Int8Range, Int8RangeDumper)
+ adapters.register_dumper(NumericRange, NumericRangeDumper)
+ adapters.register_dumper(DateRange, DateRangeDumper)
+ adapters.register_dumper(TimestampRange, TimestampRangeDumper)
+ adapters.register_dumper(TimestamptzRange, TimestamptzRangeDumper)
+ adapters.register_dumper(Int4Range, Int4RangeBinaryDumper)
+ adapters.register_dumper(Int8Range, Int8RangeBinaryDumper)
+ adapters.register_dumper(NumericRange, NumericRangeBinaryDumper)
+ adapters.register_dumper(DateRange, DateRangeBinaryDumper)
+ adapters.register_dumper(TimestampRange, TimestampRangeBinaryDumper)
+ adapters.register_dumper(TimestamptzRange, TimestamptzRangeBinaryDumper)
+ adapters.register_loader("int4range", Int4RangeLoader)
+ adapters.register_loader("int8range", Int8RangeLoader)
+ adapters.register_loader("numrange", NumericRangeLoader)
+ adapters.register_loader("daterange", DateRangeLoader)
+ adapters.register_loader("tsrange", TimestampRangeLoader)
+ adapters.register_loader("tstzrange", TimestampTZRangeLoader)
+ adapters.register_loader("int4range", Int4RangeBinaryLoader)
+ adapters.register_loader("int8range", Int8RangeBinaryLoader)
+ adapters.register_loader("numrange", NumericRangeBinaryLoader)
+ adapters.register_loader("daterange", DateRangeBinaryLoader)
+ adapters.register_loader("tsrange", TimestampRangeBinaryLoader)
+ adapters.register_loader("tstzrange", TimestampTZRangeBinaryLoader)
diff --git a/psycopg/psycopg/types/shapely.py b/psycopg/psycopg/types/shapely.py
new file mode 100644
index 0000000..e99f256
--- /dev/null
+++ b/psycopg/psycopg/types/shapely.py
@@ -0,0 +1,75 @@
+"""
+Adapters for PostGIS geometries
+"""
+
+from typing import Optional
+
+from .. import postgres
+from ..abc import AdaptContext, Buffer
+from ..adapt import Dumper, Loader
+from ..pq import Format
+from .._typeinfo import TypeInfo
+
+
+try:
+ from shapely.wkb import loads, dumps
+ from shapely.geometry.base import BaseGeometry
+
+except ImportError:
+ raise ImportError(
+ "The module psycopg.types.shapely requires the package 'Shapely'"
+ " to be installed"
+ )
+
+
+class GeometryBinaryLoader(Loader):
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> "BaseGeometry":
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return loads(data)
+
+
+class GeometryLoader(Loader):
+ def load(self, data: Buffer) -> "BaseGeometry":
+ # it's a hex string in binary
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return loads(data.decode(), hex=True)
+
+
+class BaseGeometryBinaryDumper(Dumper):
+ format = Format.BINARY
+
+ def dump(self, obj: "BaseGeometry") -> bytes:
+ return dumps(obj) # type: ignore
+
+
+class BaseGeometryDumper(Dumper):
+ def dump(self, obj: "BaseGeometry") -> bytes:
+ return dumps(obj, hex=True).encode() # type: ignore
+
+
+def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register Shapely dumper and loaders."""
+
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the 'postgis' extension loaded?")
+
+ info.register(context)
+ adapters = context.adapters if context else postgres.adapters
+
+ class GeometryDumper(BaseGeometryDumper):
+ oid = info.oid
+
+ class GeometryBinaryDumper(BaseGeometryBinaryDumper):
+ oid = info.oid
+
+ adapters.register_loader(info.oid, GeometryBinaryLoader)
+ adapters.register_loader(info.oid, GeometryLoader)
+ # Default binary dump
+ adapters.register_dumper(BaseGeometry, GeometryDumper)
+ adapters.register_dumper(BaseGeometry, GeometryBinaryDumper)
diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py
new file mode 100644
index 0000000..cd5360d
--- /dev/null
+++ b/psycopg/psycopg/types/string.py
@@ -0,0 +1,239 @@
+"""
+Adapters for textual types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Optional, Union, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format, Escaping
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+from ..errors import DataError
+from .._encodings import conn_encoding
+
+if TYPE_CHECKING:
+ from ..pq.abc import Escaping as EscapingProto
+
+
+class _BaseStrDumper(Dumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ enc = conn_encoding(self.connection)
+ self._encoding = enc if enc != "ascii" else "utf-8"
+
+
+class _StrBinaryDumper(_BaseStrDumper):
+ """
+ Base class to dump a Python strings to a Postgres text type, in binary format.
+
+ Subclasses shall specify the oids of real types (text, varchar, name...).
+ """
+
+ format = Format.BINARY
+
+ def dump(self, obj: str) -> bytes:
+ # the server will raise DataError subclass if the string contains 0x00
+ return obj.encode(self._encoding)
+
+
+class _StrDumper(_BaseStrDumper):
+ """
+ Base class to dump a Python strings to a Postgres text type, in text format.
+
+ Subclasses shall specify the oids of real types (text, varchar, name...).
+ """
+
+ def dump(self, obj: str) -> bytes:
+ if "\x00" in obj:
+ raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes")
+ else:
+ return obj.encode(self._encoding)
+
+
+# The next are concrete dumpers, each one specifying the oid they dump to.
+
+
+class StrBinaryDumper(_StrBinaryDumper):
+
+ oid = postgres.types["text"].oid
+
+
+class StrBinaryDumperVarchar(_StrBinaryDumper):
+
+ oid = postgres.types["varchar"].oid
+
+
+class StrBinaryDumperName(_StrBinaryDumper):
+
+ oid = postgres.types["name"].oid
+
+
+class StrDumper(_StrDumper):
+ """
+ Dumper for strings in text format to the text oid.
+
+ Note that this dumper is not used by default because the type is too strict
+ and PostgreSQL would require an explicit casts to everything that is not a
+ text field. However it is useful where the unknown oid is ambiguous and the
+ text oid is required, for instance with variadic functions.
+ """
+
+ oid = postgres.types["text"].oid
+
+
+class StrDumperVarchar(_StrDumper):
+
+ oid = postgres.types["varchar"].oid
+
+
+class StrDumperName(_StrDumper):
+
+ oid = postgres.types["name"].oid
+
+
+class StrDumperUnknown(_StrDumper):
+ """
+ Dumper for strings in text format to the unknown oid.
+
+ This dumper is the default dumper for strings and allows to use Python
+ strings to represent almost every data type. In a few places, however, the
+ unknown oid is not accepted (for instance in variadic functions such as
+ 'concat()'). In that case either a cast on the placeholder ('%s::text') or
+ the StrTextDumper should be used.
+ """
+
+ pass
+
+
+class TextLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ enc = conn_encoding(self.connection)
+ self._encoding = enc if enc != "ascii" else ""
+
+ def load(self, data: Buffer) -> Union[bytes, str]:
+ if self._encoding:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return data.decode(self._encoding)
+ else:
+ # return bytes for SQL_ASCII db
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return data
+
+
+class TextBinaryLoader(TextLoader):
+
+ format = Format.BINARY
+
+
+class BytesDumper(Dumper):
+
+ oid = postgres.types["bytea"].oid
+ _qprefix = b""
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._esc = Escaping(self.connection.pgconn if self.connection else None)
+
+ def dump(self, obj: Buffer) -> Buffer:
+ return self._esc.escape_bytea(obj)
+
+ def quote(self, obj: Buffer) -> bytes:
+ escaped = self.dump(obj)
+
+ # We cannot use the base quoting because escape_bytea already returns
+ # the quotes content. if scs is off it will escape the backslashes in
+ # the format, otherwise it won't, but it doesn't tell us what quotes to
+ # use.
+ if self.connection:
+ if not self._qprefix:
+ scs = self.connection.pgconn.parameter_status(
+ b"standard_conforming_strings"
+ )
+ self._qprefix = b"'" if scs == b"on" else b" E'"
+
+ return self._qprefix + escaped + b"'"
+
+ # We don't have a connection, so someone is using us to generate a file
+ # to use off-line or something like that. PQescapeBytea, like its
+ # string counterpart, is not predictable whether it will escape
+ # backslashes.
+ rv: bytes = b" E'" + escaped + b"'"
+ if self._esc.escape_bytea(b"\x00") == b"\\000":
+ rv = rv.replace(b"\\", b"\\\\")
+ return rv
+
+
+class BytesBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["bytea"].oid
+
+ def dump(self, obj: Buffer) -> Buffer:
+ return obj
+
+
+class ByteaLoader(Loader):
+
+ _escaping: "EscapingProto"
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ if not hasattr(self.__class__, "_escaping"):
+ self.__class__._escaping = Escaping()
+
+ def load(self, data: Buffer) -> bytes:
+ return self._escaping.unescape_bytea(data)
+
+
+class ByteaBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Buffer:
+ return data
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+
+ # NOTE: the order the dumpers are registered is relevant. The last one
+ # registered becomes the default for each type. Usually, binary is the
+ # default dumper. For text we use the text dumper as default because it
+ # plays the role of unknown, and it can be cast automatically to other
+ # types. However, before that, we register dumper with 'text', 'varchar',
+ # 'name' oids, which will be used when a text dumper is looked up by oid.
+ adapters.register_dumper(str, StrBinaryDumperName)
+ adapters.register_dumper(str, StrBinaryDumperVarchar)
+ adapters.register_dumper(str, StrBinaryDumper)
+ adapters.register_dumper(str, StrDumperName)
+ adapters.register_dumper(str, StrDumperVarchar)
+ adapters.register_dumper(str, StrDumper)
+ adapters.register_dumper(str, StrDumperUnknown)
+
+ adapters.register_loader(postgres.INVALID_OID, TextLoader)
+ adapters.register_loader("bpchar", TextLoader)
+ adapters.register_loader("name", TextLoader)
+ adapters.register_loader("text", TextLoader)
+ adapters.register_loader("varchar", TextLoader)
+ adapters.register_loader('"char"', TextLoader)
+ adapters.register_loader("bpchar", TextBinaryLoader)
+ adapters.register_loader("name", TextBinaryLoader)
+ adapters.register_loader("text", TextBinaryLoader)
+ adapters.register_loader("varchar", TextBinaryLoader)
+ adapters.register_loader('"char"', TextBinaryLoader)
+
+ adapters.register_dumper(bytes, BytesDumper)
+ adapters.register_dumper(bytearray, BytesDumper)
+ adapters.register_dumper(memoryview, BytesDumper)
+ adapters.register_dumper(bytes, BytesBinaryDumper)
+ adapters.register_dumper(bytearray, BytesBinaryDumper)
+ adapters.register_dumper(memoryview, BytesBinaryDumper)
+
+ adapters.register_loader("bytea", ByteaLoader)
+ adapters.register_loader(postgres.INVALID_OID, ByteaBinaryLoader)
+ adapters.register_loader("bytea", ByteaBinaryLoader)
diff --git a/psycopg/psycopg/types/uuid.py b/psycopg/psycopg/types/uuid.py
new file mode 100644
index 0000000..f92354c
--- /dev/null
+++ b/psycopg/psycopg/types/uuid.py
@@ -0,0 +1,65 @@
+"""
+Adapters for the UUID type.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Callable, Optional, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+if TYPE_CHECKING:
+ import uuid
+
+# Importing the uuid module is slow, so import it only on request.
+UUID: Callable[..., "uuid.UUID"] = None # type: ignore[assignment]
+
+
+class UUIDDumper(Dumper):
+
+ oid = postgres.types["uuid"].oid
+
+ def dump(self, obj: "uuid.UUID") -> bytes:
+ return obj.hex.encode()
+
+
+class UUIDBinaryDumper(UUIDDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: "uuid.UUID") -> bytes:
+ return obj.bytes
+
+
+class UUIDLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ global UUID
+ if UUID is None:
+ from uuid import UUID
+
+ def load(self, data: Buffer) -> "uuid.UUID":
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return UUID(data.decode())
+
+
+class UUIDBinaryLoader(UUIDLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> "uuid.UUID":
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return UUID(bytes=data)
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("uuid.UUID", UUIDDumper)
+ adapters.register_dumper("uuid.UUID", UUIDBinaryDumper)
+ adapters.register_loader("uuid", UUIDLoader)
+ adapters.register_loader("uuid", UUIDBinaryLoader)
diff --git a/psycopg/psycopg/version.py b/psycopg/psycopg/version.py
new file mode 100644
index 0000000..a98bc35
--- /dev/null
+++ b/psycopg/psycopg/version.py
@@ -0,0 +1,14 @@
+"""
+psycopg distribution version file.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Use a versioning scheme as defined in
+# https://www.python.org/dev/peps/pep-0440/
+
+# STOP AND READ! if you change:
+__version__ = "3.1.7"
+# also change:
+# - `docs/news.rst` to declare this as the current version or an unreleased one
+# - `psycopg_c/psycopg_c/version.py` to the same version.
diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py
new file mode 100644
index 0000000..7abfc58
--- /dev/null
+++ b/psycopg/psycopg/waiting.py
@@ -0,0 +1,331 @@
+"""
+Code concerned with waiting in different contexts (blocking, async, etc).
+
+These functions are designed to consume the generators returned by the
+`generators` module function and to return their final value.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+
+import os
+import select
+import selectors
+from typing import Dict, Optional
+from asyncio import get_event_loop, wait_for, Event, TimeoutError
+from selectors import DefaultSelector
+
+from . import errors as e
+from .abc import RV, PQGen, PQGenConn, WaitFunc
+from ._enums import Wait as Wait, Ready as Ready # re-exported
+from ._cmodule import _psycopg
+
+WAIT_R = Wait.R
+WAIT_W = Wait.W
+WAIT_RW = Wait.RW
+READY_R = Ready.R
+READY_W = Ready.W
+READY_RW = Ready.RW
+
+
+def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using the best strategy available.
+
+ :param gen: a generator performing database operations and yielding
+ `Ready` values when it would block.
+ :param fileno: the file descriptor to wait on.
+ :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+ to allow Ctrl-C.
+ :type timeout: float
+ :return: whatever `!gen` returns on completion.
+
+ Consume `!gen`, scheduling `fileno` for completion when it is reported to
+ block. Once ready again send the ready state back to `!gen`.
+ """
+ try:
+ s = next(gen)
+ with DefaultSelector() as sel:
+ while True:
+ sel.register(fileno, s)
+ rlist = None
+ while not rlist:
+ rlist = sel.select(timeout=timeout)
+ sel.unregister(fileno)
+ # note: this line should require a cast, but mypy doesn't complain
+ ready: Ready = rlist[0][1]
+ assert s & ready
+ s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a connection generator using the best strategy available.
+
+ :param gen: a generator performing database operations and yielding
+ (fd, `Ready`) pairs when it would block.
+ :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+ to allow Ctrl-C. If zero or None, wait indefinitely.
+ :type timeout: float
+ :return: whatever `!gen` returns on completion.
+
+ Behave like in `wait()`, but take the fileno to wait from the generator
+ itself, which might change during processing.
+ """
+ try:
+ fileno, s = next(gen)
+ if not timeout:
+ timeout = None
+ with DefaultSelector() as sel:
+ while True:
+ sel.register(fileno, s)
+ rlist = sel.select(timeout=timeout)
+ sel.unregister(fileno)
+ if not rlist:
+ raise e.ConnectionTimeout("connection timeout expired")
+ ready: Ready = rlist[0][1] # type: ignore[assignment]
+ fileno, s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
+ """
+ Coroutine waiting for a generator to complete.
+
+ :param gen: a generator performing database operations and yielding
+ `Ready` values when it would block.
+ :param fileno: the file descriptor to wait on.
+ :return: whatever `!gen` returns on completion.
+
+ Behave like in `wait()`, but exposing an `asyncio` interface.
+ """
+ # Use an event to block and restart after the fd state changes.
+ # Not sure this is the best implementation but it's a start.
+ ev = Event()
+ loop = get_event_loop()
+ ready: Ready
+ s: Wait
+
+ def wakeup(state: Ready) -> None:
+ nonlocal ready
+ ready |= state # type: ignore[assignment]
+ ev.set()
+
+ try:
+ s = next(gen)
+ while True:
+ reader = s & WAIT_R
+ writer = s & WAIT_W
+ if not reader and not writer:
+ raise e.InternalError(f"bad poll status: {s}")
+ ev.clear()
+ ready = 0 # type: ignore[assignment]
+ if reader:
+ loop.add_reader(fileno, wakeup, READY_R)
+ if writer:
+ loop.add_writer(fileno, wakeup, READY_W)
+ try:
+ await ev.wait()
+ finally:
+ if reader:
+ loop.remove_reader(fileno)
+ if writer:
+ loop.remove_writer(fileno)
+ s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
+ """
+ Coroutine waiting for a connection generator to complete.
+
+ :param gen: a generator performing database operations and yielding
+ (fd, `Ready`) pairs when it would block.
+ :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+ to allow Ctrl-C. If zero or None, wait indefinitely.
+ :return: whatever `!gen` returns on completion.
+
+ Behave like in `wait()`, but take the fileno to wait from the generator
+ itself, which might change during processing.
+ """
+ # Use an event to block and restart after the fd state changes.
+ # Not sure this is the best implementation but it's a start.
+ ev = Event()
+ loop = get_event_loop()
+ ready: Ready
+ s: Wait
+
+ def wakeup(state: Ready) -> None:
+ nonlocal ready
+ ready = state
+ ev.set()
+
+ try:
+ fileno, s = next(gen)
+ if not timeout:
+ timeout = None
+ while True:
+ reader = s & WAIT_R
+ writer = s & WAIT_W
+ if not reader and not writer:
+ raise e.InternalError(f"bad poll status: {s}")
+ ev.clear()
+ ready = 0 # type: ignore[assignment]
+ if reader:
+ loop.add_reader(fileno, wakeup, READY_R)
+ if writer:
+ loop.add_writer(fileno, wakeup, READY_W)
+ try:
+ await wait_for(ev.wait(), timeout)
+ finally:
+ if reader:
+ loop.remove_reader(fileno)
+ if writer:
+ loop.remove_writer(fileno)
+ fileno, s = gen.send(ready)
+
+ except TimeoutError:
+ raise e.ConnectionTimeout("connection timeout expired")
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+# Specialised implementation of wait functions.
+
+
+def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using select where supported.
+ """
+ try:
+ s = next(gen)
+
+ empty = ()
+ fnlist = (fileno,)
+ while True:
+ rl, wl, xl = select.select(
+ fnlist if s & WAIT_R else empty,
+ fnlist if s & WAIT_W else empty,
+ fnlist,
+ timeout,
+ )
+ ready = 0
+ if rl:
+ ready = READY_R
+ if wl:
+ ready |= READY_W
+ if not ready:
+ continue
+ # assert s & ready
+ s = gen.send(ready) # type: ignore
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+poll_evmasks: Dict[Wait, int]
+
+if hasattr(selectors, "EpollSelector"):
+ poll_evmasks = {
+ WAIT_R: select.EPOLLONESHOT | select.EPOLLIN,
+ WAIT_W: select.EPOLLONESHOT | select.EPOLLOUT,
+ WAIT_RW: select.EPOLLONESHOT | select.EPOLLIN | select.EPOLLOUT,
+ }
+else:
+ poll_evmasks = {}
+
+
+def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using epoll where supported.
+
+ Parameters are like for `wait()`. If it is detected that the best selector
+ strategy is `epoll` then this function will be used instead of `wait`.
+
+ See also: https://linux.die.net/man/2/epoll_ctl
+ """
+ try:
+ s = next(gen)
+
+ if timeout is None or timeout < 0:
+ timeout = 0
+ else:
+ timeout = int(timeout * 1000.0)
+
+ with select.epoll() as epoll:
+ evmask = poll_evmasks[s]
+ epoll.register(fileno, evmask)
+ while True:
+ fileevs = None
+ while not fileevs:
+ fileevs = epoll.poll(timeout)
+ ev = fileevs[0][1]
+ ready = 0
+ if ev & ~select.EPOLLOUT:
+ ready = READY_R
+ if ev & ~select.EPOLLIN:
+ ready |= READY_W
+ # assert s & ready
+ s = gen.send(ready)
+ evmask = poll_evmasks[s]
+ epoll.modify(fileno, evmask)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+if _psycopg:
+ wait_c = _psycopg.wait_c
+
+
+# Choose the best wait strategy for the platform.
+#
+# the selectors objects have a generic interface but come with some overhead,
+# so we also offer more finely tuned implementations.
+
+wait: WaitFunc
+
+# Allow the user to choose a specific function for testing
+if "PSYCOPG_WAIT_FUNC" in os.environ:
+ fname = os.environ["PSYCOPG_WAIT_FUNC"]
+ if not fname.startswith("wait_") or fname not in globals():
+ raise ImportError(
+ "PSYCOPG_WAIT_FUNC should be the name of an available wait function;"
+ f" got {fname!r}"
+ )
+ wait = globals()[fname]
+
+elif _psycopg:
+ wait = wait_c
+
+elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None):
+ # On Windows, SelectSelector should be the default.
+ wait = wait_select
+
+elif selectors.DefaultSelector is getattr(selectors, "EpollSelector", None):
+ # NOTE: select seems more performing than epoll. It is admittedly unlikely
+ # that a platform has epoll but not select, so maybe we could kill
+ # wait_epoll altogether(). More testing to do.
+ wait = wait_select if hasattr(selectors, "SelectSelector") else wait_epoll
+
+elif selectors.DefaultSelector is getattr(selectors, "KqueueSelector", None):
+ # wait_select is faster than wait_selector, probably because of less overhead
+ wait = wait_select if hasattr(selectors, "SelectSelector") else wait_selector
+
+else:
+ wait = wait_selector
diff --git a/psycopg/pyproject.toml b/psycopg/pyproject.toml
new file mode 100644
index 0000000..21e410c
--- /dev/null
+++ b/psycopg/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=49.2.0", "wheel>=0.37"]
+build-backend = "setuptools.build_meta"
diff --git a/psycopg/setup.cfg b/psycopg/setup.cfg
new file mode 100644
index 0000000..fdcb612
--- /dev/null
+++ b/psycopg/setup.cfg
@@ -0,0 +1,47 @@
+[metadata]
+name = psycopg
+description = PostgreSQL database adapter for Python
+url = https://psycopg.org/psycopg3/
+author = Daniele Varrazzo
+author_email = daniele.varrazzo@gmail.com
+license = GNU Lesser General Public License v3 (LGPLv3)
+
+project_urls =
+ Homepage = https://psycopg.org/
+ Code = https://github.com/psycopg/psycopg
+ Issue Tracker = https://github.com/psycopg/psycopg/issues
+ Download = https://pypi.org/project/psycopg/
+
+classifiers =
+ Development Status :: 5 - Production/Stable
+ Intended Audience :: Developers
+ License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Operating System :: MacOS :: MacOS X
+ Operating System :: Microsoft :: Windows
+ Operating System :: POSIX
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Topic :: Database
+ Topic :: Database :: Front-Ends
+ Topic :: Software Development
+ Topic :: Software Development :: Libraries :: Python Modules
+
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+license_files = LICENSE.txt
+
+[options]
+python_requires = >= 3.7
+packages = find:
+zip_safe = False
+install_requires =
+ backports.zoneinfo >= 0.2.0; python_version < "3.9"
+ typing-extensions >= 4.1
+ tzdata; sys_platform == "win32"
+
+[options.package_data]
+psycopg = py.typed
diff --git a/psycopg/setup.py b/psycopg/setup.py
new file mode 100644
index 0000000..90d4380
--- /dev/null
+++ b/psycopg/setup.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+"""
+PostgreSQL database adapter for Python - pure Python package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+from setuptools import setup
+
+# Move to the directory of setup.py: executing this file from another location
+# (e.g. from the project root) will fail
+here = os.path.abspath(os.path.dirname(__file__))
+if os.path.abspath(os.getcwd()) != here:
+ os.chdir(here)
+
+# Only for release 3.1.7. Not building binary packages because Scaleway
+# has no runner available, but psycopg-binary 3.1.6 should work as well
+# as the only change is in rows.py.
+version = "3.1.7"
+ext_versions = ">= 3.1.6, <= 3.1.7"
+
+extras_require = {
+ # Install the C extension module (requires dev tools)
+ "c": [
+ f"psycopg-c {ext_versions}",
+ ],
+ # Install the stand-alone C extension module
+ "binary": [
+ f"psycopg-binary {ext_versions}",
+ ],
+ # Install the connection pool
+ "pool": [
+ "psycopg-pool",
+ ],
+ # Requirements to run the test suite
+ "test": [
+ "mypy >= 0.990",
+ "pproxy >= 2.7",
+ "pytest >= 6.2.5",
+ "pytest-asyncio >= 0.17",
+ "pytest-cov >= 3.0",
+ "pytest-randomly >= 3.10",
+ ],
+ # Requirements needed for development
+ "dev": [
+ "black >= 22.3.0",
+ "dnspython >= 2.1",
+ "flake8 >= 4.0",
+ "mypy >= 0.990",
+ "types-setuptools >= 57.4",
+ "wheel >= 0.37",
+ ],
+ # Requirements needed to build the documentation
+ "docs": [
+ "Sphinx >= 5.0",
+ "furo == 2022.6.21",
+ "sphinx-autobuild >= 2021.3.14",
+ "sphinx-autodoc-typehints >= 1.12",
+ ],
+}
+
+setup(
+ version=version,
+ extras_require=extras_require,
+)
diff --git a/psycopg_c/.flake8 b/psycopg_c/.flake8
new file mode 100644
index 0000000..2ae629c
--- /dev/null
+++ b/psycopg_c/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+ignore = W503, E203
diff --git a/psycopg_c/LICENSE.txt b/psycopg_c/LICENSE.txt
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/psycopg_c/LICENSE.txt
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/psycopg_c/README-binary.rst b/psycopg_c/README-binary.rst
new file mode 100644
index 0000000..9318d57
--- /dev/null
+++ b/psycopg_c/README-binary.rst
@@ -0,0 +1,29 @@
+Psycopg 3: PostgreSQL database adapter for Python - binary package
+==================================================================
+
+This distribution contains the precompiled optimization package
+``psycopg_binary``.
+
+You shouldn't install this package directly: use instead ::
+
+ pip install "psycopg[binary]"
+
+to install a version of the optimization package matching the ``psycopg``
+version installed.
+
+Installing this package requires pip >= 20.3 or newer installed.
+
+This package is not available for every platform: check out `Binary
+installation`__ in the documentation.
+
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+ #binary-installation
+
+Please read `the project readme`__ and `the installation documentation`__ for
+more details.
+
+.. __: https://github.com/psycopg/psycopg#readme
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+
+
+Copyright (C) 2020 The Psycopg Team
diff --git a/psycopg_c/README.rst b/psycopg_c/README.rst
new file mode 100644
index 0000000..de9ba93
--- /dev/null
+++ b/psycopg_c/README.rst
@@ -0,0 +1,33 @@
+Psycopg 3: PostgreSQL database adapter for Python - optimisation package
+========================================================================
+
+This distribution contains the optional optimization package ``psycopg_c``.
+
+You shouldn't install this package directly: use instead ::
+
+ pip install "psycopg[c]"
+
+to install a version of the optimization package matching the ``psycopg``
+version installed.
+
+Installing this package requires some prerequisites: check `Local
+installation`__ in the documentation. Without a C compiler and some library
+headers install *will fail*: this is not a bug.
+
+If you are unable to meet the prerequisite needed you might want to install
+``psycopg[binary]`` instead: look for `Binary installation`__ in the
+documentation.
+
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+ #local-installation
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+ #binary-installation
+
+Please read `the project readme`__ and `the installation documentation`__ for
+more details.
+
+.. __: https://github.com/psycopg/psycopg#readme
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+
+
+Copyright (C) 2020 The Psycopg Team
diff --git a/psycopg_c/psycopg_c/.gitignore b/psycopg_c/psycopg_c/.gitignore
new file mode 100644
index 0000000..36edb64
--- /dev/null
+++ b/psycopg_c/psycopg_c/.gitignore
@@ -0,0 +1,4 @@
+/*.so
+_psycopg.c
+pq.c
+*.html
diff --git a/psycopg_c/psycopg_c/__init__.py b/psycopg_c/psycopg_c/__init__.py
new file mode 100644
index 0000000..14db92b
--- /dev/null
+++ b/psycopg_c/psycopg_c/__init__.py
@@ -0,0 +1,14 @@
+"""
+psycopg -- PostgreSQL database adapter for Python -- C optimization package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+
+# This package shouldn't be imported before psycopg itself, or weird things
+# will happen
+if "psycopg" not in sys.modules:
+ raise ImportError("the psycopg package should be imported before psycopg_c")
+
+from .version import __version__ as __version__ # noqa
diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi
new file mode 100644
index 0000000..bd7c63d
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg.pyi
@@ -0,0 +1,84 @@
+"""
+Stub representaton of the public objects exposed by the _psycopg module.
+
+TODO: this should be generated by mypy's stubgen but it crashes with no
+information. Will submit a bug.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Iterable, List, Optional, Sequence, Tuple
+
+from psycopg import pq
+from psycopg import abc
+from psycopg.rows import Row, RowMaker
+from psycopg.adapt import AdaptersMap, PyFormat
+from psycopg.pq.abc import PGconn, PGresult
+from psycopg.connection import BaseConnection
+from psycopg._compat import Deque
+
+class Transformer(abc.AdaptContext):
+ types: Optional[Tuple[int, ...]]
+ formats: Optional[List[pq.Format]]
+ def __init__(self, context: Optional[abc.AdaptContext] = None): ...
+ @classmethod
+ def from_context(cls, context: Optional[abc.AdaptContext]) -> "Transformer": ...
+ @property
+ def connection(self) -> Optional[BaseConnection[Any]]: ...
+ @property
+ def encoding(self) -> str: ...
+ @property
+ def adapters(self) -> AdaptersMap: ...
+ @property
+ def pgresult(self) -> Optional[PGresult]: ...
+ def set_pgresult(
+ self,
+ result: Optional["PGresult"],
+ *,
+ set_loaders: bool = True,
+ format: Optional[pq.Format] = None,
+ ) -> None: ...
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ...
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ...
+ def dump_sequence(
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
+ ) -> Sequence[Optional[abc.Buffer]]: ...
+ def as_literal(self, obj: Any) -> bytes: ...
+ def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
+ def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ...
+ def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ...
+ def load_sequence(
+ self, record: Sequence[Optional[abc.Buffer]]
+ ) -> Tuple[Any, ...]: ...
+ def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ...
+
+# Generators
+def connect(conninfo: str) -> abc.PQGenConn[PGconn]: ...
+def execute(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
+def send(pgconn: PGconn) -> abc.PQGen[None]: ...
+def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
+def fetch(pgconn: PGconn) -> abc.PQGen[Optional[PGresult]]: ...
+def pipeline_communicate(
+ pgconn: PGconn, commands: Deque[abc.PipelineCommand]
+) -> abc.PQGen[List[List[PGresult]]]: ...
+def wait_c(
+ gen: abc.PQGen[abc.RV], fileno: int, timeout: Optional[float] = None
+) -> abc.RV: ...
+
+# Copy support
+def format_row_text(
+ row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None
+) -> bytearray: ...
+def format_row_binary(
+ row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None
+) -> bytearray: ...
+def parse_row_text(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ...
+def parse_row_binary(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ...
+
+# Arrays optimization
+def array_load_text(
+ data: abc.Buffer, loader: abc.Loader, delimiter: bytes = b","
+) -> List[Any]: ...
+def array_load_binary(data: abc.Buffer, tx: abc.Transformer) -> List[Any]: ...
+
+# vim: set syntax=python:
diff --git a/psycopg_c/psycopg_c/_psycopg.pyx b/psycopg_c/psycopg_c/_psycopg.pyx
new file mode 100644
index 0000000..9d2b8ba
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg.pyx
@@ -0,0 +1,48 @@
+"""
+psycopg_c._psycopg optimization module.
+
+The module contains optimized C code used in preference to Python code
+if a compiler is available.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg_c cimport pq
+from psycopg_c.pq cimport libpq
+from psycopg_c._psycopg cimport oids
+
+import logging
+
+from psycopg.pq import Format as _pq_Format
+from psycopg._enums import PyFormat as _py_Format
+
+logger = logging.getLogger("psycopg")
+
+PQ_TEXT = _pq_Format.TEXT
+PQ_BINARY = _pq_Format.BINARY
+
+PG_AUTO = _py_Format.AUTO
+PG_TEXT = _py_Format.TEXT
+PG_BINARY = _py_Format.BINARY
+
+
+cdef extern from *:
+ """
+#ifndef ARRAYSIZE
+#define ARRAYSIZE(a) ((sizeof(a) / sizeof(*(a))))
+#endif
+ """
+ int ARRAYSIZE(void *array)
+
+
+include "_psycopg/adapt.pyx"
+include "_psycopg/copy.pyx"
+include "_psycopg/generators.pyx"
+include "_psycopg/transform.pyx"
+include "_psycopg/waiting.pyx"
+
+include "types/array.pyx"
+include "types/datetime.pyx"
+include "types/numeric.pyx"
+include "types/bool.pyx"
+include "types/string.pyx"
diff --git a/psycopg_c/psycopg_c/_psycopg/__init__.pxd b/psycopg_c/psycopg_c/_psycopg/__init__.pxd
new file mode 100644
index 0000000..db22deb
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/__init__.pxd
@@ -0,0 +1,9 @@
+"""
+psycopg_c._psycopg cython module.
+
+This file is necessary to allow c-importing pxd files from this directory.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg_c._psycopg cimport oids
diff --git a/psycopg_c/psycopg_c/_psycopg/adapt.pyx b/psycopg_c/psycopg_c/_psycopg/adapt.pyx
new file mode 100644
index 0000000..a6d8e6a
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/adapt.pyx
@@ -0,0 +1,171 @@
+"""
+C implementation of the adaptation system.
+
+This module maps each Python adaptation function to a C adaptation function.
+Notice that C adaptation functions have a different signature because they can
+avoid making a memory copy, however this makes impossible to expose them to
+Python.
+
+This module exposes facilities to map the builtin adapters in python to
+equivalent C implementations.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any
+
+cimport cython
+
+from libc.string cimport memcpy, memchr
+from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
+from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_AS_STRING
+
+from psycopg_c.pq cimport _buffer_as_string_and_size, Escaping
+
+from psycopg import errors as e
+from psycopg.pq.misc import error_message
+
+
+@cython.freelist(8)
+cdef class CDumper:
+
+ cdef readonly object cls
+ cdef pq.PGconn _pgconn
+
+ oid = oids.INVALID_OID
+
+ def __cinit__(self, cls, context: Optional[AdaptContext] = None):
+ self.cls = cls
+ conn = context.connection if context is not None else None
+ self._pgconn = conn.pgconn if conn is not None else None
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ """Store the Postgres representation *obj* into *rv* at *offset*
+
+ Return the number of bytes written to rv or -1 on Python exception.
+
+ Subclasses must implement this method. The `dump()` implementation
+ transforms the result of this method to a bytearray so that it can be
+ returned to Python.
+
+ The function interface allows C code to use this method automatically
+ to create larger buffers, e.g. for copy, composite objects, etc.
+
+ Implementation note: as you will always need to make sure that rv
+ has enough space to include what you want to dump, `ensure_size()`
+ might probably come handy.
+ """
+ raise NotImplementedError()
+
+ def dump(self, obj):
+ """Return the Postgres representation of *obj* as Python array of bytes"""
+ cdef rv = PyByteArray_FromStringAndSize("", 0)
+ cdef Py_ssize_t length = self.cdump(obj, rv, 0)
+ PyByteArray_Resize(rv, length)
+ return rv
+
+ def quote(self, obj):
+ cdef char *ptr
+ cdef char *ptr_out
+ cdef Py_ssize_t length
+
+ value = self.dump(obj)
+
+ if self._pgconn is not None:
+ esc = Escaping(self._pgconn)
+ # escaping and quoting
+ return esc.escape_literal(value)
+
+ # This path is taken when quote is asked without a connection,
+ # usually it means by psycopg.sql.quote() or by
+ # 'Composible.as_string(None)'. Most often than not this is done by
+ # someone generating a SQL file to consume elsewhere.
+
+ rv = PyByteArray_FromStringAndSize("", 0)
+
+ # No quoting, only quote escaping, random bs escaping. See further.
+ esc = Escaping()
+ out = esc.escape_string(value)
+
+ _buffer_as_string_and_size(out, &ptr, &length)
+
+ if not memchr(ptr, b'\\', length):
+ # If the string has no backslash, the result is correct and we
+ # don't need to bother with standard_conforming_strings.
+ PyByteArray_Resize(rv, length + 2) # Must include the quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b"'"
+ memcpy(ptr_out + 1, ptr, length)
+ ptr_out[length + 1] = b"'"
+ return rv
+
+ # The libpq has a crazy behaviour: PQescapeString uses the last
+ # standard_conforming_strings setting seen on a connection. This
+ # means that backslashes might be escaped or might not.
+ #
+ # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+ # if scs is off, '\\' raises a warning and '\' is an error.
+ #
+ # Check what the libpq does, and if it doesn't escape the backslash
+ # let's do it on our own. Never mind the race condition.
+ PyByteArray_Resize(rv, length + 4) # Must include " E'...'" quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + 3, ptr, length)
+ ptr_out[length + 3] = b"'"
+
+ if esc.escape_string(b"\\") == b"\\":
+ rv = bytes(rv).replace(b"\\", b"\\\\")
+ return rv
+
+ cpdef object get_key(self, object obj, object format):
+ return self.cls
+
+ cpdef object upgrade(self, object obj, object format):
+ return self
+
+ @staticmethod
+ cdef char *ensure_size(bytearray ba, Py_ssize_t offset, Py_ssize_t size) except NULL:
+ """
+ Grow *ba*, if necessary, to contains at least *size* bytes after *offset*
+
+ Return the pointer in the bytearray at *offset*, i.e. the place where
+ you want to write *size* bytes.
+ """
+ cdef Py_ssize_t curr_size = PyByteArray_GET_SIZE(ba)
+ cdef Py_ssize_t new_size = offset + size
+ if curr_size < new_size:
+ PyByteArray_Resize(ba, new_size)
+
+ return PyByteArray_AS_STRING(ba) + offset
+
+
+@cython.freelist(8)
+cdef class CLoader:
+ cdef public libpq.Oid oid
+ cdef pq.PGconn _pgconn
+
+ def __cinit__(self, int oid, context: Optional[AdaptContext] = None):
+ self.oid = oid
+ conn = context.connection if context is not None else None
+ self._pgconn = conn.pgconn if conn is not None else None
+
+ cdef object cload(self, const char *data, size_t length):
+ raise NotImplementedError()
+
+ def load(self, object data) -> Any:
+ cdef char *ptr
+ cdef Py_ssize_t length
+ _buffer_as_string_and_size(data, &ptr, &length)
+ return self.cload(ptr, length)
+
+
+cdef class _CRecursiveLoader(CLoader):
+
+ cdef Transformer _tx
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+ self._tx = Transformer.from_context(context)
diff --git a/psycopg_c/psycopg_c/_psycopg/copy.pyx b/psycopg_c/psycopg_c/_psycopg/copy.pyx
new file mode 100644
index 0000000..b943095
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/copy.pyx
@@ -0,0 +1,340 @@
+"""
+C optimised functions for the copy system.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from libc.string cimport memcpy
+from libc.stdint cimport uint16_t, uint32_t, int32_t
+from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
+from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE
+from cpython.memoryview cimport PyMemoryView_FromObject
+
+from psycopg_c._psycopg cimport endian
+from psycopg_c.pq cimport ViewBuffer
+
+from psycopg import errors as e
+
+cdef int32_t _binary_null = -1
+
+
+def format_row_binary(
+ row: Sequence[Any], tx: Transformer, out: bytearray = None
+) -> bytearray:
+ """Convert a row of adapted data to the data to send for binary copy"""
+ cdef Py_ssize_t rowlen = len(row)
+ cdef uint16_t berowlen = endian.htobe16(<int16_t>rowlen)
+
+ cdef Py_ssize_t pos # offset in 'out' where to write
+ if out is None:
+ out = PyByteArray_FromStringAndSize("", 0)
+ pos = 0
+ else:
+ pos = PyByteArray_GET_SIZE(out)
+
+ # let's start from a nice chunk
+ # (larger than most fixed size; for variable ones, oh well, we'll resize it)
+ cdef char *target = CDumper.ensure_size(
+ out, pos, sizeof(berowlen) + 20 * rowlen)
+
+ # Write the number of fields as network-order 16 bits
+ memcpy(target, <void *>&berowlen, sizeof(berowlen))
+ pos += sizeof(berowlen)
+
+ cdef Py_ssize_t size
+ cdef uint32_t besize
+ cdef char *buf
+ cdef int i
+ cdef PyObject *fmt = <PyObject *>PG_BINARY
+ cdef PyObject *row_dumper
+
+ if not tx._row_dumpers:
+ tx._row_dumpers = PyList_New(rowlen)
+
+ dumpers = tx._row_dumpers
+
+ for i in range(rowlen):
+ item = row[i]
+ if item is None:
+ target = CDumper.ensure_size(out, pos, sizeof(_binary_null))
+ memcpy(target, <void *>&_binary_null, sizeof(_binary_null))
+ pos += sizeof(_binary_null)
+ continue
+
+ row_dumper = PyList_GET_ITEM(dumpers, i)
+ if not row_dumper:
+ row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+ Py_INCREF(<object>row_dumper)
+ PyList_SET_ITEM(dumpers, i, <object>row_dumper)
+
+ if (<RowDumper>row_dumper).cdumper is not None:
+ # A cdumper can resize if necessary and copy in place
+ size = (<RowDumper>row_dumper).cdumper.cdump(
+ item, out, pos + sizeof(besize))
+ # Also add the size of the item, before the item
+ besize = endian.htobe32(<int32_t>size)
+ target = PyByteArray_AS_STRING(out) # might have been moved by cdump
+ memcpy(target + pos, <void *>&besize, sizeof(besize))
+ else:
+ # A Python dumper, gotta call it and extract its juices
+ b = PyObject_CallFunctionObjArgs(
+ (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL)
+ _buffer_as_string_and_size(b, &buf, &size)
+ target = CDumper.ensure_size(out, pos, size + sizeof(besize))
+ besize = endian.htobe32(<int32_t>size)
+ memcpy(target, <void *>&besize, sizeof(besize))
+ memcpy(target + sizeof(besize), buf, size)
+
+ pos += size + sizeof(besize)
+
+ # Resize to the final size
+ PyByteArray_Resize(out, pos)
+ return out
+
+
+def format_row_text(
+ row: Sequence[Any], tx: Transformer, out: bytearray = None
+) -> bytearray:
+ cdef Py_ssize_t pos # offset in 'out' where to write
+ if out is None:
+ out = PyByteArray_FromStringAndSize("", 0)
+ pos = 0
+ else:
+ pos = PyByteArray_GET_SIZE(out)
+
+ cdef Py_ssize_t rowlen = len(row)
+
+ if rowlen == 0:
+ PyByteArray_Resize(out, pos + 1)
+ out[pos] = b"\n"
+ return out
+
+ cdef Py_ssize_t size, tmpsize
+ cdef char *buf
+ cdef int i, j
+ cdef unsigned char *target
+ cdef int nesc = 0
+ cdef int with_tab
+ cdef PyObject *fmt = <PyObject *>PG_TEXT
+ cdef PyObject *row_dumper
+
+ for i in range(rowlen):
+ # Include the tab before the data, so it gets included in the resizes
+ with_tab = i > 0
+
+ item = row[i]
+ if item is None:
+ if with_tab:
+ target = <unsigned char *>CDumper.ensure_size(out, pos, 3)
+ memcpy(target, b"\t\\N", 3)
+ pos += 3
+ else:
+ target = <unsigned char *>CDumper.ensure_size(out, pos, 2)
+ memcpy(target, b"\\N", 2)
+ pos += 2
+ continue
+
+ row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+ if (<RowDumper>row_dumper).cdumper is not None:
+ # A cdumper can resize if necessary and copy in place
+ size = (<RowDumper>row_dumper).cdumper.cdump(
+ item, out, pos + with_tab)
+ target = <unsigned char *>PyByteArray_AS_STRING(out) + pos
+ else:
+ # A Python dumper, gotta call it and extract its juices
+ b = PyObject_CallFunctionObjArgs(
+ (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL)
+ _buffer_as_string_and_size(b, &buf, &size)
+ target = <unsigned char *>CDumper.ensure_size(out, pos, size + with_tab)
+ memcpy(target + with_tab, buf, size)
+
+ # Prepend a tab to the data just written
+ if with_tab:
+ target[0] = b"\t"
+ target += 1
+ pos += 1
+
+ # Now from pos to pos + size there is a textual representation: it may
+ # contain chars to escape. Scan to find how many such chars there are.
+ for j in range(size):
+ if copy_escape_lut[target[j]]:
+ nesc += 1
+
+ # If there is any char to escape, walk backwards pushing the chars
+ # forward and interspersing backslashes.
+ if nesc > 0:
+ tmpsize = size + nesc
+ target = <unsigned char *>CDumper.ensure_size(out, pos, tmpsize)
+ for j in range(<int>size - 1, -1, -1):
+ if copy_escape_lut[target[j]]:
+ target[j + nesc] = copy_escape_lut[target[j]]
+ nesc -= 1
+ target[j + nesc] = b"\\"
+ if nesc <= 0:
+ break
+ else:
+ target[j + nesc] = target[j]
+ pos += tmpsize
+ else:
+ pos += size
+
+ # Resize to the final size, add the newline
+ PyByteArray_Resize(out, pos + 1)
+ out[pos] = b"\n"
+ return out
+
+
+def parse_row_binary(data, tx: Transformer) -> Tuple[Any, ...]:
+ cdef unsigned char *ptr
+ cdef Py_ssize_t bufsize
+ _buffer_as_string_and_size(data, <char **>&ptr, &bufsize)
+ cdef unsigned char *bufend = ptr + bufsize
+
+ cdef uint16_t benfields = (<uint16_t *>ptr)[0]
+ cdef int nfields = endian.be16toh(benfields)
+ ptr += sizeof(benfields)
+ cdef list row = PyList_New(nfields)
+
+ cdef int col
+ cdef int32_t belength
+ cdef Py_ssize_t length
+
+ for col in range(nfields):
+ memcpy(&belength, ptr, sizeof(belength))
+ ptr += sizeof(belength)
+ if belength == _binary_null:
+ field = None
+ else:
+ length = endian.be32toh(belength)
+ if ptr + length > bufend:
+ raise e.DataError("bad copy data: length exceeding data")
+ field = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(data, ptr, length))
+ ptr += length
+
+ Py_INCREF(field)
+ PyList_SET_ITEM(row, col, field)
+
+ return tx.load_sequence(row)
+
+
+def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]:
+ cdef unsigned char *fstart
+ cdef Py_ssize_t size
+ _buffer_as_string_and_size(data, <char **>&fstart, &size)
+
+ # politely assume that the number of fields will be what in the result
+ cdef int nfields = tx._nfields
+ cdef list row = PyList_New(nfields)
+
+ cdef unsigned char *fend
+ cdef unsigned char *rowend = fstart + size
+ cdef unsigned char *src
+ cdef unsigned char *tgt
+ cdef int col
+ cdef int num_bs
+
+ for col in range(nfields):
+ fend = fstart
+ num_bs = 0
+ # Scan to the end of the field, remember if you see any backslash
+ while fend[0] != b'\t' and fend[0] != b'\n' and fend < rowend:
+ if fend[0] == b'\\':
+ num_bs += 1
+ # skip the next char to avoid counting escaped backslashes twice
+ fend += 1
+ fend += 1
+
+ # Check if we stopped for the right reason
+ if fend >= rowend:
+ raise e.DataError("bad copy data: field delimiter not found")
+ elif fend[0] == b'\t' and col == nfields - 1:
+ raise e.DataError("bad copy data: got a tab at the end of the row")
+ elif fend[0] == b'\n' and col != nfields - 1:
+ raise e.DataError(
+ "bad copy format: got a newline before the end of the row")
+
+ # Is this a NULL?
+ if fend - fstart == 2 and fstart[0] == b'\\' and fstart[1] == b'N':
+ field = None
+
+ # Is this a field with no backslash?
+ elif num_bs == 0:
+ # Nothing to unescape: we don't need a copy
+ field = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(data, fstart, fend - fstart))
+
+ # This is a field containing backslashes
+ else:
+ # We need a copy of the buffer to unescape
+ field = PyByteArray_FromStringAndSize("", 0)
+ PyByteArray_Resize(field, fend - fstart - num_bs)
+ tgt = <unsigned char *>PyByteArray_AS_STRING(field)
+ src = fstart
+ while (src < fend):
+ if src[0] != b'\\':
+ tgt[0] = src[0]
+ else:
+ src += 1
+ tgt[0] = copy_unescape_lut[src[0]]
+ src += 1
+ tgt += 1
+
+ Py_INCREF(field)
+ PyList_SET_ITEM(row, col, field)
+
+ # Start of the field
+ fstart = fend + 1
+
+ # Convert the array of buffers into Python objects
+ return tx.load_sequence(row)
+
+
+cdef extern from *:
+ """
+/* handle chars to (un)escape in text copy representation */
+/* '\b', '\t', '\n', '\v', '\f', '\r', '\\' */
+
+/* Escaping chars */
+static const char copy_escape_lut[] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 98, 116, 110, 118, 102, 114, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 92, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+};
+
+/* Conversion of escaped to unescaped chars */
+static const char copy_unescape_lut[] = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
+ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
+ 96, 97, 8, 99, 100, 101, 12, 103, 104, 105, 106, 107, 108, 109, 10, 111,
+112, 113, 13, 115, 9, 117, 11, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
+144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
+160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
+176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
+192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
+208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
+224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
+240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255,
+};
+ """
+ const char[256] copy_escape_lut
+ const char[256] copy_unescape_lut
diff --git a/psycopg_c/psycopg_c/_psycopg/endian.pxd b/psycopg_c/psycopg_c/_psycopg/endian.pxd
new file mode 100644
index 0000000..44e7305
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/endian.pxd
@@ -0,0 +1,155 @@
+"""
+Access to endian conversion function
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from libc.stdint cimport uint16_t, uint32_t, uint64_t
+
+cdef extern from * nogil:
+ # from https://gist.github.com/panzi/6856583
+ # Improved in:
+ # https://github.com/linux-sunxi/sunxi-tools/blob/master/include/portable_endian.h
+ """
+// "License": Public Domain
+// I, Mathias Panzenböck, place this file hereby into the public domain. Use it at your own risk for whatever you like.
+// In case there are jurisdictions that don't support putting things in the public domain you can also consider it to
+// be "dual licensed" under the BSD, MIT and Apache licenses, if you want to. This code is trivial anyway. Consider it
+// an example on how to get the endian conversion functions on different platforms.
+
+#ifndef PORTABLE_ENDIAN_H__
+#define PORTABLE_ENDIAN_H__
+
+#if (defined(_WIN16) || defined(_WIN32) || defined(_WIN64)) && !defined(__WINDOWS__)
+
+# define __WINDOWS__
+
+#endif
+
+#if defined(__linux__) || defined(__CYGWIN__)
+
+# include <endian.h>
+
+#elif defined(__APPLE__)
+
+# include <libkern/OSByteOrder.h>
+
+# define htobe16(x) OSSwapHostToBigInt16(x)
+# define htole16(x) OSSwapHostToLittleInt16(x)
+# define be16toh(x) OSSwapBigToHostInt16(x)
+# define le16toh(x) OSSwapLittleToHostInt16(x)
+
+# define htobe32(x) OSSwapHostToBigInt32(x)
+# define htole32(x) OSSwapHostToLittleInt32(x)
+# define be32toh(x) OSSwapBigToHostInt32(x)
+# define le32toh(x) OSSwapLittleToHostInt32(x)
+
+# define htobe64(x) OSSwapHostToBigInt64(x)
+# define htole64(x) OSSwapHostToLittleInt64(x)
+# define be64toh(x) OSSwapBigToHostInt64(x)
+# define le64toh(x) OSSwapLittleToHostInt64(x)
+
+# define __BYTE_ORDER BYTE_ORDER
+# define __BIG_ENDIAN BIG_ENDIAN
+# define __LITTLE_ENDIAN LITTLE_ENDIAN
+# define __PDP_ENDIAN PDP_ENDIAN
+
+#elif defined(__OpenBSD__) || defined(__NetBSD__) || defined(__FreeBSD__) || defined(__DragonFly__)
+
+# include <sys/endian.h>
+
+/* For functions still missing, try to substitute 'historic' OpenBSD names */
+#ifndef be16toh
+# define be16toh(x) betoh16(x)
+#endif
+#ifndef le16toh
+# define le16toh(x) letoh16(x)
+#endif
+#ifndef be32toh
+# define be32toh(x) betoh32(x)
+#endif
+#ifndef le32toh
+# define le32toh(x) letoh32(x)
+#endif
+#ifndef be64toh
+# define be64toh(x) betoh64(x)
+#endif
+#ifndef le64toh
+# define le64toh(x) letoh64(x)
+#endif
+
+#elif defined(__WINDOWS__)
+
+# include <winsock2.h>
+# ifndef _MSC_VER
+# include <sys/param.h>
+# endif
+
+# if BYTE_ORDER == LITTLE_ENDIAN
+
+# define htobe16(x) htons(x)
+# define htole16(x) (x)
+# define be16toh(x) ntohs(x)
+# define le16toh(x) (x)
+
+# define htobe32(x) htonl(x)
+# define htole32(x) (x)
+# define be32toh(x) ntohl(x)
+# define le32toh(x) (x)
+
+# define htobe64(x) htonll(x)
+# define htole64(x) (x)
+# define be64toh(x) ntohll(x)
+# define le64toh(x) (x)
+
+# elif BYTE_ORDER == BIG_ENDIAN
+
+ /* that would be xbox 360 */
+# define htobe16(x) (x)
+# define htole16(x) __builtin_bswap16(x)
+# define be16toh(x) (x)
+# define le16toh(x) __builtin_bswap16(x)
+
+# define htobe32(x) (x)
+# define htole32(x) __builtin_bswap32(x)
+# define be32toh(x) (x)
+# define le32toh(x) __builtin_bswap32(x)
+
+# define htobe64(x) (x)
+# define htole64(x) __builtin_bswap64(x)
+# define be64toh(x) (x)
+# define le64toh(x) __builtin_bswap64(x)
+
+# else
+
+# error byte order not supported
+
+# endif
+
+# define __BYTE_ORDER BYTE_ORDER
+# define __BIG_ENDIAN BIG_ENDIAN
+# define __LITTLE_ENDIAN LITTLE_ENDIAN
+# define __PDP_ENDIAN PDP_ENDIAN
+
+#else
+
+# error platform not supported
+
+#endif
+
+#endif
+ """
+ cdef uint16_t htobe16(uint16_t host_16bits)
+ cdef uint16_t htole16(uint16_t host_16bits)
+ cdef uint16_t be16toh(uint16_t big_endian_16bits)
+ cdef uint16_t le16toh(uint16_t little_endian_16bits)
+
+ cdef uint32_t htobe32(uint32_t host_32bits)
+ cdef uint32_t htole32(uint32_t host_32bits)
+ cdef uint32_t be32toh(uint32_t big_endian_32bits)
+ cdef uint32_t le32toh(uint32_t little_endian_32bits)
+
+ cdef uint64_t htobe64(uint64_t host_64bits)
+ cdef uint64_t htole64(uint64_t host_64bits)
+ cdef uint64_t be64toh(uint64_t big_endian_64bits)
+ cdef uint64_t le64toh(uint64_t little_endian_64bits)
diff --git a/psycopg_c/psycopg_c/_psycopg/generators.pyx b/psycopg_c/psycopg_c/_psycopg/generators.pyx
new file mode 100644
index 0000000..9ce9e54
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/generators.pyx
@@ -0,0 +1,276 @@
+"""
+C implementation of generators for the communication protocols with the libpq
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from cpython.object cimport PyObject_CallFunctionObjArgs
+
+from typing import List
+
+from psycopg import errors as e
+from psycopg.pq import abc, error_message
+from psycopg.abc import PipelineCommand, PQGen
+from psycopg._enums import Wait, Ready
+from psycopg._compat import Deque
+from psycopg._encodings import conninfo_encoding
+
+cdef object WAIT_W = Wait.W
+cdef object WAIT_R = Wait.R
+cdef object WAIT_RW = Wait.RW
+cdef object PY_READY_R = Ready.R
+cdef object PY_READY_W = Ready.W
+cdef object PY_READY_RW = Ready.RW
+cdef int READY_R = Ready.R
+cdef int READY_W = Ready.W
+cdef int READY_RW = Ready.RW
+
+def connect(conninfo: str) -> PQGenConn[abc.PGconn]:
+ """
+ Generator to create a database connection without blocking.
+
+ """
+ cdef pq.PGconn conn = pq.PGconn.connect_start(conninfo.encode())
+ cdef libpq.PGconn *pgconn_ptr = conn._pgconn_ptr
+ cdef int conn_status = libpq.PQstatus(pgconn_ptr)
+ cdef int poll_status
+
+ while True:
+ if conn_status == libpq.CONNECTION_BAD:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection is bad: {error_message(conn, encoding=encoding)}",
+ pgconn=conn
+ )
+
+ with nogil:
+ poll_status = libpq.PQconnectPoll(pgconn_ptr)
+
+ if poll_status == libpq.PGRES_POLLING_OK:
+ break
+ elif poll_status == libpq.PGRES_POLLING_READING:
+ yield (libpq.PQsocket(pgconn_ptr), WAIT_R)
+ elif poll_status == libpq.PGRES_POLLING_WRITING:
+ yield (libpq.PQsocket(pgconn_ptr), WAIT_W)
+ elif poll_status == libpq.PGRES_POLLING_FAILED:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection failed: {error_message(conn, encoding=encoding)}",
+ pgconn=conn
+ )
+ else:
+ raise e.InternalError(
+ f"unexpected poll status: {poll_status}", pgconn=conn
+ )
+
+ conn.nonblocking = 1
+ return conn
+
+
+def execute(pq.PGconn pgconn) -> PQGen[List[abc.PGresult]]:
+ """
+ Generator sending a query and returning results without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ yield from send(pgconn)
+ rv = yield from fetch_many(pgconn)
+ return rv
+
+
+def send(pq.PGconn pgconn) -> PQGen[None]:
+ """
+ Generator to send a query to the server without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ After this generator has finished you may want to cycle using `fetch()`
+ to retrieve the results available.
+ """
+ cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
+ cdef int status
+ cdef int cires
+
+ while True:
+ if pgconn.flush() == 0:
+ break
+
+ status = yield WAIT_RW
+ if status & READY_R:
+ with nogil:
+ # This call may read notifies which will be saved in the
+ # PGconn buffer and passed to Python later.
+ cires = libpq.PQconsumeInput(pgconn_ptr)
+ if 1 != cires:
+ raise e.OperationalError(
+ f"consuming input failed: {error_message(pgconn)}")
+
+
+def fetch_many(pq.PGconn pgconn) -> PQGen[List[PGresult]]:
+ """
+ Generator retrieving results from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ cdef list results = []
+ cdef int status
+ cdef pq.PGresult result
+ cdef libpq.PGresult *pgres
+
+ while True:
+ result = yield from fetch(pgconn)
+ if result is None:
+ break
+ results.append(result)
+ pgres = result._pgresult_ptr
+
+ status = libpq.PQresultStatus(pgres)
+ if (
+ status == libpq.PGRES_COPY_IN
+ or status == libpq.PGRES_COPY_OUT
+ or status == libpq.PGRES_COPY_BOTH
+ ):
+ # After entering copy mode the libpq will create a phony result
+ # for every request so let's break the endless loop.
+ break
+
+ if status == libpq.PGRES_PIPELINE_SYNC:
+ # PIPELINE_SYNC is not followed by a NULL, but we return it alone
+ # similarly to other result sets.
+ break
+
+ return results
+
+
+def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]:
+ """
+ Generator retrieving a single result from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return a result from the database (whether success or error).
+ """
+ cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
+ cdef int cires, ibres
+ cdef libpq.PGresult *pgres
+
+ with nogil:
+ ibres = libpq.PQisBusy(pgconn_ptr)
+ if ibres:
+ yield WAIT_R
+ while True:
+ with nogil:
+ cires = libpq.PQconsumeInput(pgconn_ptr)
+ if cires == 1:
+ ibres = libpq.PQisBusy(pgconn_ptr)
+
+ if 1 != cires:
+ raise e.OperationalError(
+ f"consuming input failed: {error_message(pgconn)}")
+ if not ibres:
+ break
+ yield WAIT_R
+
+ _consume_notifies(pgconn)
+
+ with nogil:
+ pgres = libpq.PQgetResult(pgconn_ptr)
+ if pgres is NULL:
+ return None
+ return pq.PGresult._from_ptr(pgres)
+
+
+def pipeline_communicate(
+ pq.PGconn pgconn, commands: Deque[PipelineCommand]
+) -> PQGen[List[List[PGresult]]]:
+ """Generator to send queries from a connection in pipeline mode while also
+ receiving results.
+
+ Return a list results, including single PIPELINE_SYNC elements.
+ """
+ cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
+ cdef int cires
+ cdef int status
+ cdef int ready
+ cdef libpq.PGresult *pgres
+ cdef list res = []
+ cdef list results = []
+ cdef pq.PGresult r
+
+ while True:
+ ready = yield WAIT_RW
+
+ if ready & READY_R:
+ with nogil:
+ cires = libpq.PQconsumeInput(pgconn_ptr)
+ if 1 != cires:
+ raise e.OperationalError(
+ f"consuming input failed: {error_message(pgconn)}")
+
+ _consume_notifies(pgconn)
+
+ res: List[PGresult] = []
+ while True:
+ with nogil:
+ ibres = libpq.PQisBusy(pgconn_ptr)
+ if ibres:
+ break
+ pgres = libpq.PQgetResult(pgconn_ptr)
+
+ if pgres is NULL:
+ if not res:
+ break
+ results.append(res)
+ res = []
+ else:
+ status = libpq.PQresultStatus(pgres)
+ r = pq.PGresult._from_ptr(pgres)
+ if status == libpq.PGRES_PIPELINE_SYNC:
+ results.append([r])
+ break
+ else:
+ res.append(r)
+
+ if ready & READY_W:
+ pgconn.flush()
+ if not commands:
+ break
+ commands.popleft()()
+
+ return results
+
+
+cdef int _consume_notifies(pq.PGconn pgconn) except -1:
+ cdef object notify_handler = pgconn.notify_handler
+ cdef libpq.PGconn *pgconn_ptr
+ cdef libpq.PGnotify *notify
+
+ if notify_handler is not None:
+ while True:
+ pynotify = pgconn.notifies()
+ if pynotify is None:
+ break
+ PyObject_CallFunctionObjArgs(
+ notify_handler, <PyObject *>pynotify, NULL
+ )
+ else:
+ pgconn_ptr = pgconn._pgconn_ptr
+ while True:
+ notify = libpq.PQnotifies(pgconn_ptr)
+ if notify is NULL:
+ break
+ libpq.PQfreemem(notify)
+
+ return 0
diff --git a/psycopg_c/psycopg_c/_psycopg/oids.pxd b/psycopg_c/psycopg_c/_psycopg/oids.pxd
new file mode 100644
index 0000000..2a864c4
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/oids.pxd
@@ -0,0 +1,92 @@
+"""
+Constants to refer to OIDS in C
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Use tools/update_oids.py to update this data.
+
+cdef enum:
+ INVALID_OID = 0
+
+ # autogenerated: start
+
+ # Generated from PostgreSQL 15.0
+
+ ACLITEM_OID = 1033
+ BIT_OID = 1560
+ BOOL_OID = 16
+ BOX_OID = 603
+ BPCHAR_OID = 1042
+ BYTEA_OID = 17
+ CHAR_OID = 18
+ CID_OID = 29
+ CIDR_OID = 650
+ CIRCLE_OID = 718
+ DATE_OID = 1082
+ DATEMULTIRANGE_OID = 4535
+ DATERANGE_OID = 3912
+ FLOAT4_OID = 700
+ FLOAT8_OID = 701
+ GTSVECTOR_OID = 3642
+ INET_OID = 869
+ INT2_OID = 21
+ INT2VECTOR_OID = 22
+ INT4_OID = 23
+ INT4MULTIRANGE_OID = 4451
+ INT4RANGE_OID = 3904
+ INT8_OID = 20
+ INT8MULTIRANGE_OID = 4536
+ INT8RANGE_OID = 3926
+ INTERVAL_OID = 1186
+ JSON_OID = 114
+ JSONB_OID = 3802
+ JSONPATH_OID = 4072
+ LINE_OID = 628
+ LSEG_OID = 601
+ MACADDR_OID = 829
+ MACADDR8_OID = 774
+ MONEY_OID = 790
+ NAME_OID = 19
+ NUMERIC_OID = 1700
+ NUMMULTIRANGE_OID = 4532
+ NUMRANGE_OID = 3906
+ OID_OID = 26
+ OIDVECTOR_OID = 30
+ PATH_OID = 602
+ PG_LSN_OID = 3220
+ POINT_OID = 600
+ POLYGON_OID = 604
+ RECORD_OID = 2249
+ REFCURSOR_OID = 1790
+ REGCLASS_OID = 2205
+ REGCOLLATION_OID = 4191
+ REGCONFIG_OID = 3734
+ REGDICTIONARY_OID = 3769
+ REGNAMESPACE_OID = 4089
+ REGOPER_OID = 2203
+ REGOPERATOR_OID = 2204
+ REGPROC_OID = 24
+ REGPROCEDURE_OID = 2202
+ REGROLE_OID = 4096
+ REGTYPE_OID = 2206
+ TEXT_OID = 25
+ TID_OID = 27
+ TIME_OID = 1083
+ TIMESTAMP_OID = 1114
+ TIMESTAMPTZ_OID = 1184
+ TIMETZ_OID = 1266
+ TSMULTIRANGE_OID = 4533
+ TSQUERY_OID = 3615
+ TSRANGE_OID = 3908
+ TSTZMULTIRANGE_OID = 4534
+ TSTZRANGE_OID = 3910
+ TSVECTOR_OID = 3614
+ TXID_SNAPSHOT_OID = 2970
+ UUID_OID = 2950
+ VARBIT_OID = 1562
+ VARCHAR_OID = 1043
+ XID_OID = 28
+ XID8_OID = 5069
+ XML_OID = 142
+ # autogenerated: end
diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx
new file mode 100644
index 0000000..fc69725
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx
@@ -0,0 +1,640 @@
+"""
+Helper object to transform values between Python and PostgreSQL
+
+Cython implementation: can access to lower level C features without creating
+too many temporary Python objects and performing less memory copying.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+from cpython.ref cimport Py_INCREF, Py_DECREF
+from cpython.set cimport PySet_Add, PySet_Contains
+from cpython.dict cimport PyDict_GetItem, PyDict_SetItem
+from cpython.list cimport (
+ PyList_New, PyList_CheckExact,
+ PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE)
+from cpython.bytes cimport PyBytes_AS_STRING
+from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
+from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
+
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+from psycopg import errors as e
+from psycopg.pq import Format as PqFormat
+from psycopg.rows import Row, RowMaker
+from psycopg._encodings import pgconn_encoding
+
+NoneType = type(None)
+
+# internal structure: you are not supposed to know this. But it's worth some
+# 10% of the innermost loop, so I'm willing to ask for forgiveness later...
+
+ctypedef struct PGresAttValue:
+ int len
+ char *value
+
+ctypedef struct pg_result_int:
+ # NOTE: it would be advised that we don't know this structure's content
+ int ntups
+ int numAttributes
+ libpq.PGresAttDesc *attDescs
+ PGresAttValue **tuples
+ # ...more members, which we ignore
+
+
+@cython.freelist(16)
+cdef class RowLoader:
+ cdef CLoader cloader
+ cdef object pyloader
+ cdef object loadfunc
+
+
+@cython.freelist(16)
+cdef class RowDumper:
+ cdef CDumper cdumper
+ cdef object pydumper
+ cdef object dumpfunc
+ cdef object oid
+ cdef object format
+
+
+cdef class Transformer:
+ """
+ An object that can adapt efficiently between Python and PostgreSQL.
+
+ The life cycle of the object is the query, so it is assumed that attributes
+ such as the server version or the connection encoding will not change. The
+ object have its state so adapting several values of the same type can be
+ optimised.
+
+ """
+
+ cdef readonly object connection
+ cdef readonly object adapters
+ cdef readonly object types
+ cdef readonly object formats
+ cdef str _encoding
+ cdef int _none_oid
+
+ # mapping class -> Dumper instance (auto, text, binary)
+ cdef dict _auto_dumpers
+ cdef dict _text_dumpers
+ cdef dict _binary_dumpers
+
+ # mapping oid -> Loader instance (text, binary)
+ cdef dict _text_loaders
+ cdef dict _binary_loaders
+
+ # mapping oid -> Dumper instance (text, binary)
+ cdef dict _oid_text_dumpers
+ cdef dict _oid_binary_dumpers
+
+ cdef pq.PGresult _pgresult
+ cdef int _nfields, _ntuples
+ cdef list _row_dumpers
+ cdef list _row_loaders
+
+ cdef dict _oid_types
+
+ def __cinit__(self, context: Optional["AdaptContext"] = None):
+ if context is not None:
+ self.adapters = context.adapters
+ self.connection = context.connection
+ else:
+ from psycopg import postgres
+ self.adapters = postgres.adapters
+ self.connection = None
+
+ self.types = self.formats = None
+ self._none_oid = -1
+
+ @classmethod
+ def from_context(cls, context: Optional["AdaptContext"]):
+ """
+ Return a Transformer from an AdaptContext.
+
+ If the context is a Transformer instance, just return it.
+ """
+ return _tx_from_context(context)
+
+ @property
+ def encoding(self) -> str:
+ if not self._encoding:
+ conn = self.connection
+ self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
+ return self._encoding
+
+ @property
+ def pgresult(self) -> Optional[PGresult]:
+ return self._pgresult
+
+ cpdef set_pgresult(
+ self,
+ pq.PGresult result,
+ object set_loaders = True,
+ object format = None
+ ):
+ self._pgresult = result
+
+ if result is None:
+ self._nfields = self._ntuples = 0
+ if set_loaders:
+ self._row_loaders = []
+ return
+
+ cdef libpq.PGresult *res = self._pgresult._pgresult_ptr
+ self._nfields = libpq.PQnfields(res)
+ self._ntuples = libpq.PQntuples(res)
+
+ if not set_loaders:
+ return
+
+ if not self._nfields:
+ self._row_loaders = []
+ return
+
+ if format is None:
+ format = libpq.PQfformat(res, 0)
+
+ cdef list loaders = PyList_New(self._nfields)
+ cdef PyObject *row_loader
+ cdef object oid
+
+ cdef int i
+ for i in range(self._nfields):
+ oid = libpq.PQftype(res, i)
+ row_loader = self._c_get_loader(<PyObject *>oid, <PyObject *>format)
+ Py_INCREF(<object>row_loader)
+ PyList_SET_ITEM(loaders, i, <object>row_loader)
+
+ self._row_loaders = loaders
+
+ def set_dumper_types(self, types: Sequence[int], format: Format) -> None:
+ cdef Py_ssize_t ntypes = len(types)
+ dumpers = PyList_New(ntypes)
+ cdef int i
+ for i in range(ntypes):
+ oid = types[i]
+ dumper_ptr = self.get_dumper_by_oid(
+ <PyObject *>oid, <PyObject *>format)
+ Py_INCREF(<object>dumper_ptr)
+ PyList_SET_ITEM(dumpers, i, <object>dumper_ptr)
+
+ self._row_dumpers = dumpers
+ self.types = tuple(types)
+ self.formats = [format] * ntypes
+
+ def set_loader_types(self, types: Sequence[int], format: Format) -> None:
+ self._c_loader_types(len(types), types, format)
+
+ cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, object format):
+ cdef list loaders = PyList_New(ntypes)
+
+ # these are used more as Python object than C
+ cdef PyObject *oid
+ cdef PyObject *row_loader
+ for i in range(ntypes):
+ oid = PyList_GET_ITEM(types, i)
+ row_loader = self._c_get_loader(oid, <PyObject *>format)
+ Py_INCREF(<object>row_loader)
+ PyList_SET_ITEM(loaders, i, <object>row_loader)
+
+ self._row_loaders = loaders
+
+ cpdef as_literal(self, obj):
+ cdef PyObject *row_dumper = self.get_row_dumper(
+ <PyObject *>obj, <PyObject *>PG_TEXT)
+
+ if (<RowDumper>row_dumper).cdumper is not None:
+ dumper = (<RowDumper>row_dumper).cdumper
+ else:
+ dumper = (<RowDumper>row_dumper).pydumper
+
+ rv = dumper.quote(obj)
+ oid = dumper.oid
+ # If the result is quoted and the oid not unknown or text,
+ # add an explicit type cast.
+ # Check the last char because the first one might be 'E'.
+ if oid and oid != oids.TEXT_OID and rv and rv[-1] == 39:
+ if self._oid_types is None:
+ self._oid_types = {}
+ type_ptr = PyDict_GetItem(<object>self._oid_types, oid)
+ if type_ptr == NULL:
+ type_sql = b""
+ ti = self.adapters.types.get(oid)
+ if ti is not None:
+ if oid < 8192:
+ # builtin: prefer "timestamptz" to "timestamp with time zone"
+ type_sql = ti.name.encode(self.encoding)
+ else:
+ type_sql = ti.regtype.encode(self.encoding)
+ if oid == ti.array_oid:
+ type_sql += b"[]"
+
+ type_ptr = <PyObject *>type_sql
+ PyDict_SetItem(<object>self._oid_types, oid, type_sql)
+
+ if <object>type_ptr:
+ rv = b"%s::%s" % (rv, <object>type_ptr)
+
+ return rv
+
+ def get_dumper(self, obj, format) -> "Dumper":
+ cdef PyObject *row_dumper = self.get_row_dumper(
+ <PyObject *>obj, <PyObject *>format)
+ return (<RowDumper>row_dumper).pydumper
+
+ cdef PyObject *get_row_dumper(self, PyObject *obj, PyObject *fmt) except NULL:
+ """
+ Return a borrowed reference to the RowDumper for the given obj/fmt.
+ """
+ # Fast path: return a Dumper class already instantiated from the same type
+ cdef PyObject *cache
+ cdef PyObject *ptr
+ cdef PyObject *ptr1
+ cdef RowDumper row_dumper
+
+ # Normally, the type of the object dictates how to dump it
+ key = type(<object>obj)
+
+ # Establish where would the dumper be cached
+ bfmt = PyUnicode_AsUTF8String(<object>fmt)
+ cdef char cfmt = PyBytes_AS_STRING(bfmt)[0]
+ if cfmt == b's':
+ if self._auto_dumpers is None:
+ self._auto_dumpers = {}
+ cache = <PyObject *>self._auto_dumpers
+ elif cfmt == b'b':
+ if self._binary_dumpers is None:
+ self._binary_dumpers = {}
+ cache = <PyObject *>self._binary_dumpers
+ elif cfmt == b't':
+ if self._text_dumpers is None:
+ self._text_dumpers = {}
+ cache = <PyObject *>self._text_dumpers
+ else:
+ raise ValueError(
+ f"format should be a psycopg.adapt.Format, not {<object>fmt}")
+
+ # Reuse an existing Dumper class for objects of the same type
+ ptr = PyDict_GetItem(<object>cache, key)
+ if ptr == NULL:
+ dcls = PyObject_CallFunctionObjArgs(
+ self.adapters.get_dumper, <PyObject *>key, fmt, NULL)
+ dumper = PyObject_CallFunctionObjArgs(
+ dcls, <PyObject *>key, <PyObject *>self, NULL)
+
+ row_dumper = _as_row_dumper(dumper)
+ PyDict_SetItem(<object>cache, key, row_dumper)
+ ptr = <PyObject *>row_dumper
+
+ # Check if the dumper requires an upgrade to handle this specific value
+ if (<RowDumper>ptr).cdumper is not None:
+ key1 = (<RowDumper>ptr).cdumper.get_key(<object>obj, <object>fmt)
+ else:
+ key1 = PyObject_CallFunctionObjArgs(
+ (<RowDumper>ptr).pydumper.get_key, obj, fmt, NULL)
+ if key1 is key:
+ return ptr
+
+ # If it does, ask the dumper to create its own upgraded version
+ ptr1 = PyDict_GetItem(<object>cache, key1)
+ if ptr1 != NULL:
+ return ptr1
+
+ if (<RowDumper>ptr).cdumper is not None:
+ dumper = (<RowDumper>ptr).cdumper.upgrade(<object>obj, <object>fmt)
+ else:
+ dumper = PyObject_CallFunctionObjArgs(
+ (<RowDumper>ptr).pydumper.upgrade, obj, fmt, NULL)
+
+ row_dumper = _as_row_dumper(dumper)
+ PyDict_SetItem(<object>cache, key1, row_dumper)
+ return <PyObject *>row_dumper
+
+ cdef PyObject *get_dumper_by_oid(self, PyObject *oid, PyObject *fmt) except NULL:
+ """
+ Return a borrowed reference to the RowDumper for the given oid/fmt.
+ """
+ cdef PyObject *ptr
+ cdef PyObject *cache
+ cdef RowDumper row_dumper
+
+ # Establish where would the dumper be cached
+ cdef int cfmt = <object>fmt
+ if cfmt == 0:
+ if self._oid_text_dumpers is None:
+ self._oid_text_dumpers = {}
+ cache = <PyObject *>self._oid_text_dumpers
+ elif cfmt == 1:
+ if self._oid_binary_dumpers is None:
+ self._oid_binary_dumpers = {}
+ cache = <PyObject *>self._oid_binary_dumpers
+ else:
+ raise ValueError(
+ f"format should be a psycopg.pq.Format, not {<object>fmt}")
+
+ # Reuse an existing Dumper class for objects of the same type
+ ptr = PyDict_GetItem(<object>cache, <object>oid)
+ if ptr == NULL:
+ dcls = PyObject_CallFunctionObjArgs(
+ self.adapters.get_dumper_by_oid, oid, fmt, NULL)
+ dumper = PyObject_CallFunctionObjArgs(
+ dcls, <PyObject *>NoneType, <PyObject *>self, NULL)
+
+ row_dumper = _as_row_dumper(dumper)
+ PyDict_SetItem(<object>cache, <object>oid, row_dumper)
+ ptr = <PyObject *>row_dumper
+
+ return ptr
+
+ cpdef dump_sequence(self, object params, object formats):
+ # Verify that they are not none and that PyList_GET_ITEM won't blow up
+ cdef Py_ssize_t nparams = len(params)
+ cdef list out = PyList_New(nparams)
+
+ cdef int i
+ cdef PyObject *dumper_ptr # borrowed pointer to row dumper
+ cdef object dumped
+ cdef Py_ssize_t size
+
+ if self._none_oid < 0:
+ self._none_oid = self.adapters.get_dumper(NoneType, "s").oid
+
+ dumpers = self._row_dumpers
+
+ if dumpers:
+ for i in range(nparams):
+ param = params[i]
+ if param is not None:
+ dumper_ptr = PyList_GET_ITEM(dumpers, i)
+ if (<RowDumper>dumper_ptr).cdumper is not None:
+ dumped = PyByteArray_FromStringAndSize("", 0)
+ size = (<RowDumper>dumper_ptr).cdumper.cdump(
+ param, <bytearray>dumped, 0)
+ PyByteArray_Resize(dumped, size)
+ else:
+ dumped = PyObject_CallFunctionObjArgs(
+ (<RowDumper>dumper_ptr).dumpfunc,
+ <PyObject *>param, NULL)
+ else:
+ dumped = None
+
+ Py_INCREF(dumped)
+ PyList_SET_ITEM(out, i, dumped)
+
+ return out
+
+ cdef tuple types = PyTuple_New(nparams)
+ cdef list pqformats = PyList_New(nparams)
+
+ for i in range(nparams):
+ param = params[i]
+ if param is not None:
+ dumper_ptr = self.get_row_dumper(
+ <PyObject *>param, <PyObject *>formats[i])
+ if (<RowDumper>dumper_ptr).cdumper is not None:
+ dumped = PyByteArray_FromStringAndSize("", 0)
+ size = (<RowDumper>dumper_ptr).cdumper.cdump(
+ param, <bytearray>dumped, 0)
+ PyByteArray_Resize(dumped, size)
+ else:
+ dumped = PyObject_CallFunctionObjArgs(
+ (<RowDumper>dumper_ptr).dumpfunc,
+ <PyObject *>param, NULL)
+ oid = (<RowDumper>dumper_ptr).oid
+ fmt = (<RowDumper>dumper_ptr).format
+ else:
+ dumped = None
+ oid = self._none_oid
+ fmt = PQ_TEXT
+
+ Py_INCREF(dumped)
+ PyList_SET_ITEM(out, i, dumped)
+
+ Py_INCREF(oid)
+ PyTuple_SET_ITEM(types, i, oid)
+
+ Py_INCREF(fmt)
+ PyList_SET_ITEM(pqformats, i, fmt)
+
+ self.types = types
+ self.formats = pqformats
+ return out
+
+ def load_rows(self, int row0, int row1, object make_row) -> List[Row]:
+ if self._pgresult is None:
+ raise e.InterfaceError("result not set")
+
+ if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
+ raise e.InterfaceError(
+ f"rows must be included between 0 and {self._ntuples}"
+ )
+
+ cdef libpq.PGresult *res = self._pgresult._pgresult_ptr
+ # cheeky access to the internal PGresult structure
+ cdef pg_result_int *ires = <pg_result_int*>res
+
+ cdef int row
+ cdef int col
+ cdef PGresAttValue *attval
+ cdef object record # not 'tuple' as it would check on assignment
+
+ cdef object records = PyList_New(row1 - row0)
+ for row in range(row0, row1):
+ record = PyTuple_New(self._nfields)
+ Py_INCREF(record)
+ PyList_SET_ITEM(records, row - row0, record)
+
+ cdef PyObject *loader # borrowed RowLoader
+ cdef PyObject *brecord # borrowed
+ row_loaders = self._row_loaders # avoid an incref/decref per item
+
+ for col in range(self._nfields):
+ loader = PyList_GET_ITEM(row_loaders, col)
+ if (<RowLoader>loader).cloader is not None:
+ for row in range(row0, row1):
+ brecord = PyList_GET_ITEM(records, row - row0)
+ attval = &(ires.tuples[row][col])
+ if attval.len == -1: # NULL_LEN
+ pyval = None
+ else:
+ pyval = (<RowLoader>loader).cloader.cload(
+ attval.value, attval.len)
+
+ Py_INCREF(pyval)
+ PyTuple_SET_ITEM(<object>brecord, col, pyval)
+
+ else:
+ for row in range(row0, row1):
+ brecord = PyList_GET_ITEM(records, row - row0)
+ attval = &(ires.tuples[row][col])
+ if attval.len == -1: # NULL_LEN
+ pyval = None
+ else:
+ b = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(
+ self._pgresult,
+ <unsigned char *>attval.value, attval.len))
+ pyval = PyObject_CallFunctionObjArgs(
+ (<RowLoader>loader).loadfunc, <PyObject *>b, NULL)
+
+ Py_INCREF(pyval)
+ PyTuple_SET_ITEM(<object>brecord, col, pyval)
+
+ if make_row is not tuple:
+ for i in range(row1 - row0):
+ brecord = PyList_GET_ITEM(records, i)
+ record = PyObject_CallFunctionObjArgs(
+ make_row, <PyObject *>brecord, NULL)
+ Py_INCREF(record)
+ PyList_SET_ITEM(records, i, record)
+ Py_DECREF(<object>brecord)
+ return records
+
+ def load_row(self, int row, object make_row) -> Optional[Row]:
+ if self._pgresult is None:
+ return None
+
+ if not 0 <= row < self._ntuples:
+ return None
+
+ cdef libpq.PGresult *res = self._pgresult._pgresult_ptr
+ # cheeky access to the internal PGresult structure
+ cdef pg_result_int *ires = <pg_result_int*>res
+
+ cdef PyObject *loader # borrowed RowLoader
+ cdef int col
+ cdef PGresAttValue *attval
+ cdef object record # not 'tuple' as it would check on assignment
+
+ record = PyTuple_New(self._nfields)
+ row_loaders = self._row_loaders # avoid an incref/decref per item
+
+ for col in range(self._nfields):
+ attval = &(ires.tuples[row][col])
+ if attval.len == -1: # NULL_LEN
+ pyval = None
+ else:
+ loader = PyList_GET_ITEM(row_loaders, col)
+ if (<RowLoader>loader).cloader is not None:
+ pyval = (<RowLoader>loader).cloader.cload(
+ attval.value, attval.len)
+ else:
+ b = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(
+ self._pgresult,
+ <unsigned char *>attval.value, attval.len))
+ pyval = PyObject_CallFunctionObjArgs(
+ (<RowLoader>loader).loadfunc, <PyObject *>b, NULL)
+
+ Py_INCREF(pyval)
+ PyTuple_SET_ITEM(record, col, pyval)
+
+ if make_row is not tuple:
+ record = PyObject_CallFunctionObjArgs(
+ make_row, <PyObject *>record, NULL)
+ return record
+
+ cpdef object load_sequence(self, record: Sequence[Optional[Buffer]]):
+ cdef Py_ssize_t nfields = len(record)
+ out = PyTuple_New(nfields)
+ cdef PyObject *loader # borrowed RowLoader
+ cdef int col
+ cdef char *ptr
+ cdef Py_ssize_t size
+
+ row_loaders = self._row_loaders # avoid an incref/decref per item
+ if PyList_GET_SIZE(row_loaders) != nfields:
+ raise e.ProgrammingError(
+ f"cannot load sequence of {nfields} items:"
+ f" {len(self._row_loaders)} loaders registered")
+
+ for col in range(nfields):
+ item = record[col]
+ if item is None:
+ Py_INCREF(None)
+ PyTuple_SET_ITEM(out, col, None)
+ continue
+
+ loader = PyList_GET_ITEM(row_loaders, col)
+ if (<RowLoader>loader).cloader is not None:
+ _buffer_as_string_and_size(item, &ptr, &size)
+ pyval = (<RowLoader>loader).cloader.cload(ptr, size)
+ else:
+ pyval = PyObject_CallFunctionObjArgs(
+ (<RowLoader>loader).loadfunc, <PyObject *>item, NULL)
+
+ Py_INCREF(pyval)
+ PyTuple_SET_ITEM(out, col, pyval)
+
+ return out
+
+ def get_loader(self, oid: int, format: pq.Format) -> "Loader":
+ cdef PyObject *row_loader = self._c_get_loader(
+ <PyObject *>oid, <PyObject *>format)
+ return (<RowLoader>row_loader).pyloader
+
+ cdef PyObject *_c_get_loader(self, PyObject *oid, PyObject *fmt) except NULL:
+ """
+ Return a borrowed reference to the RowLoader instance for given oid/fmt
+ """
+ cdef PyObject *ptr
+ cdef PyObject *cache
+
+ if <object>fmt == PQ_TEXT:
+ if self._text_loaders is None:
+ self._text_loaders = {}
+ cache = <PyObject *>self._text_loaders
+ elif <object>fmt == PQ_BINARY:
+ if self._binary_loaders is None:
+ self._binary_loaders = {}
+ cache = <PyObject *>self._binary_loaders
+ else:
+ raise ValueError(
+ f"format should be a psycopg.pq.Format, not {format}")
+
+ ptr = PyDict_GetItem(<object>cache, <object>oid)
+ if ptr != NULL:
+ return ptr
+
+ loader_cls = self.adapters.get_loader(<object>oid, <object>fmt)
+ if loader_cls is None:
+ loader_cls = self.adapters.get_loader(oids.INVALID_OID, <object>fmt)
+ if loader_cls is None:
+ raise e.InterfaceError("unknown oid loader not found")
+
+ loader = PyObject_CallFunctionObjArgs(
+ loader_cls, oid, <PyObject *>self, NULL)
+
+ cdef RowLoader row_loader = RowLoader()
+ row_loader.pyloader = loader
+ row_loader.loadfunc = loader.load
+ if isinstance(loader, CLoader):
+ row_loader.cloader = <CLoader>loader
+
+ PyDict_SetItem(<object>cache, <object>oid, row_loader)
+ return <PyObject *>row_loader
+
+
+cdef object _as_row_dumper(object dumper):
+ cdef RowDumper row_dumper = RowDumper()
+
+ row_dumper.pydumper = dumper
+ row_dumper.dumpfunc = dumper.dump
+ row_dumper.oid = dumper.oid
+ row_dumper.format = dumper.format
+
+ if isinstance(dumper, CDumper):
+ row_dumper.cdumper = <CDumper>dumper
+
+ return row_dumper
+
+
+cdef Transformer _tx_from_context(object context):
+ if isinstance(context, Transformer):
+ return context
+ else:
+ return Transformer(context)
diff --git a/psycopg_c/psycopg_c/_psycopg/waiting.pyx b/psycopg_c/psycopg_c/_psycopg/waiting.pyx
new file mode 100644
index 0000000..0af6c57
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/waiting.pyx
@@ -0,0 +1,197 @@
+"""
+C implementation of waiting functions
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from cpython.object cimport PyObject_CallFunctionObjArgs
+
+cdef extern from *:
+ """
+#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
+
+#if defined(HAVE_POLL_H)
+#include <poll.h>
+#elif defined(HAVE_SYS_POLL_H)
+#include <sys/poll.h>
+#endif
+
+#else /* no poll available */
+
+#ifdef MS_WINDOWS
+#include <winsock2.h>
+#else
+#include <sys/select.h>
+#endif
+
+#endif /* HAVE_POLL */
+
+#define SELECT_EV_READ 1
+#define SELECT_EV_WRITE 2
+
+#define SEC_TO_MS 1000
+#define SEC_TO_US (1000 * 1000)
+
+/* Use select to wait for readiness on fileno.
+ *
+ * - Return SELECT_EV_* if the file is ready
+ * - Return 0 on timeout
+ * - Return -1 (and set an exception) on error.
+ *
+ * The wisdom of this function comes from:
+ *
+ * - PostgreSQL libpq (see src/interfaces/libpq/fe-misc.c)
+ * - Python select module (see Modules/selectmodule.c)
+ */
+static int
+wait_c_impl(int fileno, int wait, float timeout)
+{
+ int select_rv;
+ int rv = 0;
+
+#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
+
+ struct pollfd input_fd;
+ int timeout_ms;
+
+ input_fd.fd = fileno;
+ input_fd.events = POLLERR;
+ input_fd.revents = 0;
+
+ if (wait & SELECT_EV_READ) { input_fd.events |= POLLIN; }
+ if (wait & SELECT_EV_WRITE) { input_fd.events |= POLLOUT; }
+
+ if (timeout < 0.0) {
+ timeout_ms = -1;
+ } else {
+ timeout_ms = (int)(timeout * SEC_TO_MS);
+ }
+
+ Py_BEGIN_ALLOW_THREADS
+ errno = 0;
+ select_rv = poll(&input_fd, 1, timeout_ms);
+ Py_END_ALLOW_THREADS
+
+ if (PyErr_CheckSignals()) { goto finally; }
+
+ if (select_rv < 0) {
+ goto error;
+ }
+
+ if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; }
+ if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; }
+
+#else
+
+ fd_set ifds;
+ fd_set ofds;
+ fd_set efds;
+ struct timeval tv, *tvptr;
+
+#ifndef MS_WINDOWS
+ if (fileno >= 1024) {
+ PyErr_SetString(
+ PyExc_ValueError, /* same exception of Python's 'select.select()' */
+ "connection file descriptor out of range for 'select()'");
+ return -1;
+ }
+#endif
+
+ FD_ZERO(&ifds);
+ FD_ZERO(&ofds);
+ FD_ZERO(&efds);
+
+ if (wait & SELECT_EV_READ) { FD_SET(fileno, &ifds); }
+ if (wait & SELECT_EV_WRITE) { FD_SET(fileno, &ofds); }
+ FD_SET(fileno, &efds);
+
+ /* Compute appropriate timeout interval */
+ if (timeout < 0.0) {
+ tvptr = NULL;
+ }
+ else {
+ tv.tv_sec = (int)timeout;
+ tv.tv_usec = (int)(((long)timeout * SEC_TO_US) % SEC_TO_US);
+ tvptr = &tv;
+ }
+
+ Py_BEGIN_ALLOW_THREADS
+ errno = 0;
+ select_rv = select(fileno + 1, &ifds, &ofds, &efds, tvptr);
+ Py_END_ALLOW_THREADS
+
+ if (PyErr_CheckSignals()) { goto finally; }
+
+ if (select_rv < 0) {
+ goto error;
+ }
+
+ if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
+ if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+
+#endif /* HAVE_POLL */
+
+ return rv;
+
+error:
+
+#ifdef MS_WINDOWS
+ if (select_rv == SOCKET_ERROR) {
+ PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError());
+ }
+#else
+ if (select_rv < 0) {
+ PyErr_SetFromErrno(PyExc_OSError);
+ }
+#endif
+ else {
+ PyErr_SetString(PyExc_OSError, "unexpected error from select()");
+ }
+
+finally:
+
+ return -1;
+
+}
+ """
+ cdef int wait_c_impl(int fileno, int wait, float timeout) except -1
+
+
+def wait_c(gen: PQGen[RV], int fileno, timeout = None) -> RV:
+ """
+ Wait for a generator using poll or select.
+ """
+ cdef float ctimeout
+ cdef int wait, ready
+ cdef PyObject *pyready
+
+ if timeout is None:
+ ctimeout = -1.0
+ else:
+ ctimeout = float(timeout)
+ if ctimeout < 0.0:
+ ctimeout = -1.0
+
+ send = gen.send
+
+ try:
+ wait = next(gen)
+
+ while True:
+ ready = wait_c_impl(fileno, wait, ctimeout)
+ if ready == 0:
+ continue
+ elif ready == READY_R:
+ pyready = <PyObject *>PY_READY_R
+ elif ready == READY_RW:
+ pyready = <PyObject *>PY_READY_RW
+ elif ready == READY_W:
+ pyready = <PyObject *>PY_READY_W
+ else:
+ raise AssertionError(f"unexpected ready value: {ready}")
+
+ wait = PyObject_CallFunctionObjArgs(send, pyready, NULL)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
diff --git a/psycopg_c/psycopg_c/pq.pxd b/psycopg_c/psycopg_c/pq.pxd
new file mode 100644
index 0000000..57825dd
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq.pxd
@@ -0,0 +1,78 @@
+# Include pid_t but Windows doesn't have it
+# Don't use "IF" so that the generated C is portable and can be included
+# in the sdist.
+cdef extern from * nogil:
+ """
+#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS)
+ typedef signed pid_t;
+#else
+ #include <fcntl.h>
+#endif
+ """
+ ctypedef signed pid_t
+
+from psycopg_c.pq cimport libpq
+
+ctypedef char *(*conn_bytes_f) (const libpq.PGconn *)
+ctypedef int(*conn_int_f) (const libpq.PGconn *)
+
+
+cdef class PGconn:
+ cdef libpq.PGconn* _pgconn_ptr
+ cdef object __weakref__
+ cdef public object notice_handler
+ cdef public object notify_handler
+ cdef pid_t _procpid
+
+ @staticmethod
+ cdef PGconn _from_ptr(libpq.PGconn *ptr)
+
+ cpdef int flush(self) except -1
+ cpdef object notifies(self)
+
+
+cdef class PGresult:
+ cdef libpq.PGresult* _pgresult_ptr
+
+ @staticmethod
+ cdef PGresult _from_ptr(libpq.PGresult *ptr)
+
+
+cdef class PGcancel:
+ cdef libpq.PGcancel* pgcancel_ptr
+
+ @staticmethod
+ cdef PGcancel _from_ptr(libpq.PGcancel *ptr)
+
+
+cdef class Escaping:
+ cdef PGconn conn
+
+ cpdef escape_literal(self, data)
+ cpdef escape_identifier(self, data)
+ cpdef escape_string(self, data)
+ cpdef escape_bytea(self, data)
+ cpdef unescape_bytea(self, const unsigned char *data)
+
+
+cdef class PQBuffer:
+ cdef unsigned char *buf
+ cdef Py_ssize_t len
+
+ @staticmethod
+ cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t length)
+
+
+cdef class ViewBuffer:
+ cdef unsigned char *buf
+ cdef Py_ssize_t len
+ cdef object obj
+
+ @staticmethod
+ cdef ViewBuffer _from_buffer(
+ object obj, unsigned char *buf, Py_ssize_t length)
+
+
+cdef int _buffer_as_string_and_size(
+ data: "Buffer", char **ptr, Py_ssize_t *length
+) except -1
diff --git a/psycopg_c/psycopg_c/pq.pyx b/psycopg_c/psycopg_c/pq.pyx
new file mode 100644
index 0000000..d397c17
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq.pyx
@@ -0,0 +1,38 @@
+"""
+libpq Python wrapper using cython bindings.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg_c.pq cimport libpq
+
+import logging
+
+from psycopg import errors as e
+from psycopg.pq import Format
+from psycopg.pq.misc import error_message
+
+logger = logging.getLogger("psycopg")
+
+__impl__ = 'c'
+__build_version__ = libpq.PG_VERSION_NUM
+
+
+def version():
+ return libpq.PQlibVersion()
+
+
+include "pq/pgconn.pyx"
+include "pq/pgresult.pyx"
+include "pq/pgcancel.pyx"
+include "pq/conninfo.pyx"
+include "pq/escaping.pyx"
+include "pq/pqbuffer.pyx"
+
+
+# importing the ssl module sets up Python's libcrypto callbacks
+import ssl # noqa
+
+# disable libcrypto setup in libpq, so it won't stomp on the callbacks
+# that have already been set up
+libpq.PQinitOpenSSL(1, 0)
diff --git a/psycopg_c/psycopg_c/pq/__init__.pxd b/psycopg_c/psycopg_c/pq/__init__.pxd
new file mode 100644
index 0000000..ce8c60c
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/__init__.pxd
@@ -0,0 +1,9 @@
+"""
+psycopg_c.pq cython module.
+
+This file is necessary to allow c-importing pxd files from this directory.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg_c.pq cimport libpq
diff --git a/psycopg_c/psycopg_c/pq/conninfo.pyx b/psycopg_c/psycopg_c/pq/conninfo.pyx
new file mode 100644
index 0000000..3443de1
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/conninfo.pyx
@@ -0,0 +1,61 @@
+"""
+psycopg_c.pq.Conninfo object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg.pq.misc import ConninfoOption
+
+
+class Conninfo:
+ @classmethod
+ def get_defaults(cls) -> List[ConninfoOption]:
+ cdef libpq.PQconninfoOption *opts = libpq.PQconndefaults()
+ if opts is NULL :
+ raise MemoryError("couldn't allocate connection defaults")
+ rv = _options_from_array(opts)
+ libpq.PQconninfoFree(opts)
+ return rv
+
+ @classmethod
+ def parse(cls, const char *conninfo) -> List[ConninfoOption]:
+ cdef char *errmsg = NULL
+ cdef libpq.PQconninfoOption *opts = libpq.PQconninfoParse(conninfo, &errmsg)
+ if opts is NULL:
+ if errmsg is NULL:
+ raise MemoryError("couldn't allocate on conninfo parse")
+ else:
+ exc = e.OperationalError(errmsg.decode("utf8", "replace"))
+ libpq.PQfreemem(errmsg)
+ raise exc
+
+ rv = _options_from_array(opts)
+ libpq.PQconninfoFree(opts)
+ return rv
+
+ def __repr__(self):
+ return f"<{type(self).__name__} ({self.keyword.decode('ascii')})>"
+
+
+cdef _options_from_array(libpq.PQconninfoOption *opts):
+ rv = []
+ cdef int i = 0
+ cdef libpq.PQconninfoOption* opt
+ while True:
+ opt = opts + i
+ if opt.keyword is NULL:
+ break
+ rv.append(
+ ConninfoOption(
+ keyword=opt.keyword,
+ envvar=opt.envvar if opt.envvar is not NULL else None,
+ compiled=opt.compiled if opt.compiled is not NULL else None,
+ val=opt.val if opt.val is not NULL else None,
+ label=opt.label if opt.label is not NULL else None,
+ dispchar=opt.dispchar if opt.dispchar is not NULL else None,
+ dispsize=opt.dispsize,
+ )
+ )
+ i += 1
+
+ return rv
diff --git a/psycopg_c/psycopg_c/pq/escaping.pyx b/psycopg_c/psycopg_c/pq/escaping.pyx
new file mode 100644
index 0000000..f0a44d3
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/escaping.pyx
@@ -0,0 +1,132 @@
+"""
+psycopg_c.pq.Escaping object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from libc.string cimport strlen
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+
+
+cdef class Escaping:
+ def __init__(self, PGconn conn = None):
+ self.conn = conn
+
+ cpdef escape_literal(self, data):
+ cdef char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ if self.conn is None:
+ raise e.OperationalError("escape_literal failed: no connection provided")
+ if self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ _buffer_as_string_and_size(data, &ptr, &length)
+
+ out = libpq.PQescapeLiteral(self.conn._pgconn_ptr, ptr, length)
+ if out is NULL:
+ raise e.OperationalError(
+ f"escape_literal failed: {error_message(self.conn)}"
+ )
+
+ rv = out[:strlen(out)]
+ libpq.PQfreemem(out)
+ return rv
+
+ cpdef escape_identifier(self, data):
+ cdef char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ _buffer_as_string_and_size(data, &ptr, &length)
+
+ if self.conn is None:
+ raise e.OperationalError("escape_identifier failed: no connection provided")
+ if self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ out = libpq.PQescapeIdentifier(self.conn._pgconn_ptr, ptr, length)
+ if out is NULL:
+ raise e.OperationalError(
+ f"escape_identifier failed: {error_message(self.conn)}"
+ )
+
+ rv = out[:strlen(out)]
+ libpq.PQfreemem(out)
+ return rv
+
+ cpdef escape_string(self, data):
+ cdef int error
+ cdef size_t len_out
+ cdef char *ptr
+ cdef char *buf_out
+ cdef Py_ssize_t length
+
+ _buffer_as_string_and_size(data, &ptr, &length)
+
+ if self.conn is not None:
+ if self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ buf_out = <char *>PyMem_Malloc(length * 2 + 1)
+ len_out = libpq.PQescapeStringConn(
+ self.conn._pgconn_ptr, buf_out, ptr, length, &error
+ )
+ if error:
+ PyMem_Free(buf_out)
+ raise e.OperationalError(
+ f"escape_string failed: {error_message(self.conn)}"
+ )
+
+ else:
+ buf_out = <char *>PyMem_Malloc(length * 2 + 1)
+ len_out = libpq.PQescapeString(buf_out, ptr, length)
+
+ rv = buf_out[:len_out]
+ PyMem_Free(buf_out)
+ return rv
+
+ cpdef escape_bytea(self, data):
+ cdef size_t len_out
+ cdef unsigned char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ if self.conn is not None and self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ _buffer_as_string_and_size(data, &ptr, &length)
+
+ if self.conn is not None:
+ out = libpq.PQescapeByteaConn(
+ self.conn._pgconn_ptr, <unsigned char *>ptr, length, &len_out)
+ else:
+ out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_out)
+
+ if out is NULL:
+ raise MemoryError(
+ f"couldn't allocate for escape_bytea of {len(data)} bytes"
+ )
+
+ rv = out[:len_out - 1] # out includes final 0
+ libpq.PQfreemem(out)
+ return rv
+
+ cpdef unescape_bytea(self, const unsigned char *data):
+ # not needed, but let's keep it symmetric with the escaping:
+ # if a connection is passed in, it must be valid.
+ if self.conn is not None:
+ if self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ cdef size_t len_out
+ cdef unsigned char *out = libpq.PQunescapeBytea(data, &len_out)
+ if out is NULL:
+ raise MemoryError(
+ f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+ )
+
+ rv = out[:len_out]
+ libpq.PQfreemem(out)
+ return rv
diff --git a/psycopg_c/psycopg_c/pq/libpq.pxd b/psycopg_c/psycopg_c/pq/libpq.pxd
new file mode 100644
index 0000000..5e05e40
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/libpq.pxd
@@ -0,0 +1,321 @@
+"""
+Libpq header definition for the cython psycopg.pq implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cdef extern from "stdio.h":
+
+ ctypedef struct FILE:
+ pass
+
+cdef extern from "pg_config.h":
+
+ int PG_VERSION_NUM
+
+
+cdef extern from "libpq-fe.h":
+
+ # structures and types
+
+ ctypedef unsigned int Oid
+
+ ctypedef struct PGconn:
+ pass
+
+ ctypedef struct PGresult:
+ pass
+
+ ctypedef struct PQconninfoOption:
+ char *keyword
+ char *envvar
+ char *compiled
+ char *val
+ char *label
+ char *dispchar
+ int dispsize
+
+ ctypedef struct PGnotify:
+ char *relname
+ int be_pid
+ char *extra
+
+ ctypedef struct PGcancel:
+ pass
+
+ ctypedef struct PGresAttDesc:
+ char *name
+ Oid tableid
+ int columnid
+ int format
+ Oid typid
+ int typlen
+ int atttypmod
+
+ # enums
+
+ ctypedef enum PostgresPollingStatusType:
+ PGRES_POLLING_FAILED = 0
+ PGRES_POLLING_READING
+ PGRES_POLLING_WRITING
+ PGRES_POLLING_OK
+ PGRES_POLLING_ACTIVE
+
+
+ ctypedef enum PGPing:
+ PQPING_OK
+ PQPING_REJECT
+ PQPING_NO_RESPONSE
+ PQPING_NO_ATTEMPT
+
+ ctypedef enum ConnStatusType:
+ CONNECTION_OK
+ CONNECTION_BAD
+ CONNECTION_STARTED
+ CONNECTION_MADE
+ CONNECTION_AWAITING_RESPONSE
+ CONNECTION_AUTH_OK
+ CONNECTION_SETENV
+ CONNECTION_SSL_STARTUP
+ CONNECTION_NEEDED
+ CONNECTION_CHECK_WRITABLE
+ CONNECTION_GSS_STARTUP
+ # CONNECTION_CHECK_TARGET PG 12
+
+ ctypedef enum PGTransactionStatusType:
+ PQTRANS_IDLE
+ PQTRANS_ACTIVE
+ PQTRANS_INTRANS
+ PQTRANS_INERROR
+ PQTRANS_UNKNOWN
+
+ ctypedef enum ExecStatusType:
+ PGRES_EMPTY_QUERY = 0
+ PGRES_COMMAND_OK
+ PGRES_TUPLES_OK
+ PGRES_COPY_OUT
+ PGRES_COPY_IN
+ PGRES_BAD_RESPONSE
+ PGRES_NONFATAL_ERROR
+ PGRES_FATAL_ERROR
+ PGRES_COPY_BOTH
+ PGRES_SINGLE_TUPLE
+ PGRES_PIPELINE_SYNC
+ PGRES_PIPELINE_ABORT
+
+ # 33.1. Database Connection Control Functions
+ PGconn *PQconnectdb(const char *conninfo)
+ PGconn *PQconnectStart(const char *conninfo)
+ PostgresPollingStatusType PQconnectPoll(PGconn *conn) nogil
+ PQconninfoOption *PQconndefaults()
+ PQconninfoOption *PQconninfo(PGconn *conn)
+ PQconninfoOption *PQconninfoParse(const char *conninfo, char **errmsg)
+ void PQfinish(PGconn *conn)
+ void PQreset(PGconn *conn)
+ int PQresetStart(PGconn *conn)
+ PostgresPollingStatusType PQresetPoll(PGconn *conn)
+ PGPing PQping(const char *conninfo)
+
+ # 33.2. Connection Status Functions
+ char *PQdb(const PGconn *conn)
+ char *PQuser(const PGconn *conn)
+ char *PQpass(const PGconn *conn)
+ char *PQhost(const PGconn *conn)
+ char *PQhostaddr(const PGconn *conn)
+ char *PQport(const PGconn *conn)
+ char *PQtty(const PGconn *conn)
+ char *PQoptions(const PGconn *conn)
+ ConnStatusType PQstatus(const PGconn *conn)
+ PGTransactionStatusType PQtransactionStatus(const PGconn *conn)
+ const char *PQparameterStatus(const PGconn *conn, const char *paramName)
+ int PQprotocolVersion(const PGconn *conn)
+ int PQserverVersion(const PGconn *conn)
+ char *PQerrorMessage(const PGconn *conn)
+ int PQsocket(const PGconn *conn) nogil
+ int PQbackendPID(const PGconn *conn)
+ int PQconnectionNeedsPassword(const PGconn *conn)
+ int PQconnectionUsedPassword(const PGconn *conn)
+ int PQsslInUse(PGconn *conn) # TODO: const in PG 12 docs - verify/report
+ # TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl
+
+ # 33.3. Command Execution Functions
+ PGresult *PQexec(PGconn *conn, const char *command) nogil
+ PGresult *PQexecParams(PGconn *conn,
+ const char *command,
+ int nParams,
+ const Oid *paramTypes,
+ const char * const *paramValues,
+ const int *paramLengths,
+ const int *paramFormats,
+ int resultFormat) nogil
+ PGresult *PQprepare(PGconn *conn,
+ const char *stmtName,
+ const char *query,
+ int nParams,
+ const Oid *paramTypes) nogil
+ PGresult *PQexecPrepared(PGconn *conn,
+ const char *stmtName,
+ int nParams,
+ const char * const *paramValues,
+ const int *paramLengths,
+ const int *paramFormats,
+ int resultFormat) nogil
+ PGresult *PQdescribePrepared(PGconn *conn, const char *stmtName) nogil
+ PGresult *PQdescribePortal(PGconn *conn, const char *portalName) nogil
+ ExecStatusType PQresultStatus(const PGresult *res) nogil
+ # PQresStatus: not needed, we have pretty enums
+ char *PQresultErrorMessage(const PGresult *res) nogil
+ # TODO: PQresultVerboseErrorMessage
+ char *PQresultErrorField(const PGresult *res, int fieldcode) nogil
+ void PQclear(PGresult *res) nogil
+
+ # 33.3.2. Retrieving Query Result Information
+ int PQntuples(const PGresult *res)
+ int PQnfields(const PGresult *res)
+ char *PQfname(const PGresult *res, int column_number)
+ int PQfnumber(const PGresult *res, const char *column_name)
+ Oid PQftable(const PGresult *res, int column_number)
+ int PQftablecol(const PGresult *res, int column_number)
+ int PQfformat(const PGresult *res, int column_number)
+ Oid PQftype(const PGresult *res, int column_number)
+ int PQfmod(const PGresult *res, int column_number)
+ int PQfsize(const PGresult *res, int column_number)
+ int PQbinaryTuples(const PGresult *res)
+ char *PQgetvalue(const PGresult *res, int row_number, int column_number)
+ int PQgetisnull(const PGresult *res, int row_number, int column_number)
+ int PQgetlength(const PGresult *res, int row_number, int column_number)
+ int PQnparams(const PGresult *res)
+ Oid PQparamtype(const PGresult *res, int param_number)
+ # PQprint: pretty useless
+
+ # 33.3.3. Retrieving Other Result Information
+ char *PQcmdStatus(PGresult *res)
+ char *PQcmdTuples(PGresult *res)
+ Oid PQoidValue(const PGresult *res)
+
+ # 33.3.4. Escaping Strings for Inclusion in SQL Commands
+ char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length)
+ char *PQescapeLiteral(PGconn *conn, const char *str, size_t length)
+ size_t PQescapeStringConn(PGconn *conn,
+ char *to, const char *from_, size_t length,
+ int *error)
+ size_t PQescapeString(char *to, const char *from_, size_t length)
+ unsigned char *PQescapeByteaConn(PGconn *conn,
+ const unsigned char *src,
+ size_t from_length,
+ size_t *to_length)
+ unsigned char *PQescapeBytea(const unsigned char *src,
+ size_t from_length,
+ size_t *to_length)
+ unsigned char *PQunescapeBytea(const unsigned char *src, size_t *to_length)
+
+
+ # 33.4. Asynchronous Command Processing
+ int PQsendQuery(PGconn *conn, const char *command) nogil
+ int PQsendQueryParams(PGconn *conn,
+ const char *command,
+ int nParams,
+ const Oid *paramTypes,
+ const char * const *paramValues,
+ const int *paramLengths,
+ const int *paramFormats,
+ int resultFormat) nogil
+ int PQsendPrepare(PGconn *conn,
+ const char *stmtName,
+ const char *query,
+ int nParams,
+ const Oid *paramTypes) nogil
+ int PQsendQueryPrepared(PGconn *conn,
+ const char *stmtName,
+ int nParams,
+ const char * const *paramValues,
+ const int *paramLengths,
+ const int *paramFormats,
+ int resultFormat) nogil
+ int PQsendDescribePrepared(PGconn *conn, const char *stmtName) nogil
+ int PQsendDescribePortal(PGconn *conn, const char *portalName) nogil
+ PGresult *PQgetResult(PGconn *conn) nogil
+ int PQconsumeInput(PGconn *conn) nogil
+ int PQisBusy(PGconn *conn) nogil
+ int PQsetnonblocking(PGconn *conn, int arg) nogil
+ int PQisnonblocking(const PGconn *conn)
+ int PQflush(PGconn *conn) nogil
+
+ # 33.5. Retrieving Query Results Row-by-Row
+ int PQsetSingleRowMode(PGconn *conn)
+
+ # 33.6. Canceling Queries in Progress
+ PGcancel *PQgetCancel(PGconn *conn)
+ void PQfreeCancel(PGcancel *cancel)
+ int PQcancel(PGcancel *cancel, char *errbuf, int errbufsize)
+
+ # 33.8. Asynchronous Notification
+ PGnotify *PQnotifies(PGconn *conn) nogil
+
+ # 33.9. Functions Associated with the COPY Command
+ int PQputCopyData(PGconn *conn, const char *buffer, int nbytes) nogil
+ int PQputCopyEnd(PGconn *conn, const char *errormsg) nogil
+ int PQgetCopyData(PGconn *conn, char **buffer, int async) nogil
+
+ # 33.10. Control Functions
+ void PQtrace(PGconn *conn, FILE *stream);
+ void PQsetTraceFlags(PGconn *conn, int flags);
+ void PQuntrace(PGconn *conn);
+
+ # 33.11. Miscellaneous Functions
+ void PQfreemem(void *ptr) nogil
+ void PQconninfoFree(PQconninfoOption *connOptions)
+ char *PQencryptPasswordConn(
+ PGconn *conn, const char *passwd, const char *user, const char *algorithm);
+ PGresult *PQmakeEmptyPGresult(PGconn *conn, ExecStatusType status)
+ int PQsetResultAttrs(PGresult *res, int numAttributes, PGresAttDesc *attDescs)
+ int PQlibVersion()
+
+ # 33.12. Notice Processing
+ ctypedef void (*PQnoticeReceiver)(void *arg, const PGresult *res)
+ PQnoticeReceiver PQsetNoticeReceiver(
+ PGconn *conn, PQnoticeReceiver prog, void *arg)
+
+ # 33.18. SSL Support
+ void PQinitOpenSSL(int do_ssl, int do_crypto)
+
+ # 34.5 Pipeline Mode
+
+ ctypedef enum PGpipelineStatus:
+ PQ_PIPELINE_OFF
+ PQ_PIPELINE_ON
+ PQ_PIPELINE_ABORTED
+
+ PGpipelineStatus PQpipelineStatus(const PGconn *conn)
+ int PQenterPipelineMode(PGconn *conn)
+ int PQexitPipelineMode(PGconn *conn)
+ int PQpipelineSync(PGconn *conn)
+ int PQsendFlushRequest(PGconn *conn)
+
+cdef extern from *:
+ """
+/* Hack to allow the use of old libpq versions */
+#if PG_VERSION_NUM < 100000
+#define PQencryptPasswordConn(conn, passwd, user, algorithm) NULL
+#endif
+
+#if PG_VERSION_NUM < 120000
+#define PQhostaddr(conn) NULL
+#endif
+
+#if PG_VERSION_NUM < 140000
+#define PGRES_PIPELINE_SYNC 10
+#define PGRES_PIPELINE_ABORTED 11
+typedef enum {
+ PQ_PIPELINE_OFF,
+ PQ_PIPELINE_ON,
+ PQ_PIPELINE_ABORTED
+} PGpipelineStatus;
+#define PQpipelineStatus(conn) PQ_PIPELINE_OFF
+#define PQenterPipelineMode(conn) 0
+#define PQexitPipelineMode(conn) 1
+#define PQpipelineSync(conn) 0
+#define PQsendFlushRequest(conn) 0
+#define PQsetTraceFlags(conn, stream) do {} while (0)
+#endif
+"""
diff --git a/psycopg_c/psycopg_c/pq/pgcancel.pyx b/psycopg_c/psycopg_c/pq/pgcancel.pyx
new file mode 100644
index 0000000..b7cbb70
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/pgcancel.pyx
@@ -0,0 +1,32 @@
+"""
+psycopg_c.pq.PGcancel object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+
+cdef class PGcancel:
+ def __cinit__(self):
+ self.pgcancel_ptr = NULL
+
+ @staticmethod
+ cdef PGcancel _from_ptr(libpq.PGcancel *ptr):
+ cdef PGcancel rv = PGcancel.__new__(PGcancel)
+ rv.pgcancel_ptr = ptr
+ return rv
+
+ def __dealloc__(self) -> None:
+ self.free()
+
+ def free(self) -> None:
+ if self.pgcancel_ptr is not NULL:
+ libpq.PQfreeCancel(self.pgcancel_ptr)
+ self.pgcancel_ptr = NULL
+
+ def cancel(self) -> None:
+ cdef char buf[256]
+ cdef int res = libpq.PQcancel(self.pgcancel_ptr, buf, sizeof(buf))
+ if not res:
+ raise e.OperationalError(
+ f"cancel failed: {buf.decode('utf8', 'ignore')}"
+ )
diff --git a/psycopg_c/psycopg_c/pq/pgconn.pyx b/psycopg_c/psycopg_c/pq/pgconn.pyx
new file mode 100644
index 0000000..4a60530
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/pgconn.pyx
@@ -0,0 +1,733 @@
+"""
+psycopg_c.pq.PGconn object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cdef extern from * nogil:
+ """
+#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS)
+ /* We don't need a real definition for this because Windows is not affected
+ * by the issue caused by closing the fds after fork.
+ */
+ #define getpid() (0)
+#else
+ #include <unistd.h>
+#endif
+ """
+ pid_t getpid()
+
+from libc.stdio cimport fdopen
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+from cpython.bytes cimport PyBytes_AsString
+from cpython.memoryview cimport PyMemoryView_FromObject
+
+import sys
+
+from psycopg.pq import Format as PqFormat, Trace
+from psycopg.pq.misc import PGnotify, connection_summary
+from psycopg_c.pq cimport PQBuffer
+
+
+cdef class PGconn:
+ @staticmethod
+ cdef PGconn _from_ptr(libpq.PGconn *ptr):
+ cdef PGconn rv = PGconn.__new__(PGconn)
+ rv._pgconn_ptr = ptr
+
+ libpq.PQsetNoticeReceiver(ptr, notice_receiver, <void *>rv)
+ return rv
+
+ def __cinit__(self):
+ self._pgconn_ptr = NULL
+ self._procpid = getpid()
+
+ def __dealloc__(self):
+ # Close the connection only if it was created in this process,
+ # not if this object is being GC'd after fork.
+ if self._procpid == getpid():
+ self.finish()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = connection_summary(self)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @classmethod
+ def connect(cls, const char *conninfo) -> PGconn:
+ cdef libpq.PGconn* pgconn = libpq.PQconnectdb(conninfo)
+ if not pgconn:
+ raise MemoryError("couldn't allocate PGconn")
+
+ return PGconn._from_ptr(pgconn)
+
+ @classmethod
+ def connect_start(cls, const char *conninfo) -> PGconn:
+ cdef libpq.PGconn* pgconn = libpq.PQconnectStart(conninfo)
+ if not pgconn:
+ raise MemoryError("couldn't allocate PGconn")
+
+ return PGconn._from_ptr(pgconn)
+
+ def connect_poll(self) -> int:
+ return _call_int(self, <conn_int_f>libpq.PQconnectPoll)
+
+ def finish(self) -> None:
+ if self._pgconn_ptr is not NULL:
+ libpq.PQfinish(self._pgconn_ptr)
+ self._pgconn_ptr = NULL
+
+ @property
+ def pgconn_ptr(self) -> Optional[int]:
+ if self._pgconn_ptr:
+ return <long long><void *>self._pgconn_ptr
+ else:
+ return None
+
+ @property
+ def info(self) -> List["ConninfoOption"]:
+ _ensure_pgconn(self)
+ cdef libpq.PQconninfoOption *opts = libpq.PQconninfo(self._pgconn_ptr)
+ if opts is NULL:
+ raise MemoryError("couldn't allocate connection info")
+ rv = _options_from_array(opts)
+ libpq.PQconninfoFree(opts)
+ return rv
+
+ def reset(self) -> None:
+ _ensure_pgconn(self)
+ libpq.PQreset(self._pgconn_ptr)
+
+ def reset_start(self) -> None:
+ if not libpq.PQresetStart(self._pgconn_ptr):
+ raise e.OperationalError("couldn't reset connection")
+
+ def reset_poll(self) -> int:
+ return _call_int(self, <conn_int_f>libpq.PQresetPoll)
+
+ @classmethod
+ def ping(self, const char *conninfo) -> int:
+ return libpq.PQping(conninfo)
+
+ @property
+ def db(self) -> bytes:
+ return _call_bytes(self, libpq.PQdb)
+
+ @property
+ def user(self) -> bytes:
+ return _call_bytes(self, libpq.PQuser)
+
+ @property
+ def password(self) -> bytes:
+ return _call_bytes(self, libpq.PQpass)
+
+ @property
+ def host(self) -> bytes:
+ return _call_bytes(self, libpq.PQhost)
+
+ @property
+ def hostaddr(self) -> bytes:
+ if libpq.PG_VERSION_NUM < 120000:
+ raise e.NotSupportedError(
+ f"PQhostaddr requires libpq from PostgreSQL 12,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+
+ _ensure_pgconn(self)
+ cdef char *rv = libpq.PQhostaddr(self._pgconn_ptr)
+ assert rv is not NULL
+ return rv
+
+ @property
+ def port(self) -> bytes:
+ return _call_bytes(self, libpq.PQport)
+
+ @property
+ def tty(self) -> bytes:
+ return _call_bytes(self, libpq.PQtty)
+
+ @property
+ def options(self) -> bytes:
+ return _call_bytes(self, libpq.PQoptions)
+
+ @property
+ def status(self) -> int:
+ return libpq.PQstatus(self._pgconn_ptr)
+
+ @property
+ def transaction_status(self) -> int:
+ return libpq.PQtransactionStatus(self._pgconn_ptr)
+
+ def parameter_status(self, const char *name) -> Optional[bytes]:
+ _ensure_pgconn(self)
+ cdef const char *rv = libpq.PQparameterStatus(self._pgconn_ptr, name)
+ if rv is not NULL:
+ return rv
+ else:
+ return None
+
+ @property
+ def error_message(self) -> bytes:
+ return libpq.PQerrorMessage(self._pgconn_ptr)
+
+ @property
+ def protocol_version(self) -> int:
+ return _call_int(self, libpq.PQprotocolVersion)
+
+ @property
+ def server_version(self) -> int:
+ return _call_int(self, libpq.PQserverVersion)
+
+ @property
+ def socket(self) -> int:
+ rv = _call_int(self, libpq.PQsocket)
+ if rv == -1:
+ raise e.OperationalError("the connection is lost")
+ return rv
+
+ @property
+ def backend_pid(self) -> int:
+ return _call_int(self, libpq.PQbackendPID)
+
+ @property
+ def needs_password(self) -> bool:
+ return bool(libpq.PQconnectionNeedsPassword(self._pgconn_ptr))
+
+ @property
+ def used_password(self) -> bool:
+ return bool(libpq.PQconnectionUsedPassword(self._pgconn_ptr))
+
+ @property
+ def ssl_in_use(self) -> bool:
+ return bool(_call_int(self, <conn_int_f>libpq.PQsslInUse))
+
+ def exec_(self, const char *command) -> PGresult:
+ _ensure_pgconn(self)
+ cdef libpq.PGresult *pgresult
+ with nogil:
+ pgresult = libpq.PQexec(self._pgconn_ptr, command)
+ if pgresult is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+
+ return PGresult._from_ptr(pgresult)
+
+ def send_query(self, const char *command) -> None:
+ _ensure_pgconn(self)
+ cdef int rv
+ with nogil:
+ rv = libpq.PQsendQuery(self._pgconn_ptr, command)
+ if not rv:
+ raise e.OperationalError(f"sending query failed: {error_message(self)}")
+
+ def exec_params(
+ self,
+ const char *command,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ int result_format = PqFormat.TEXT,
+ ) -> PGresult:
+ _ensure_pgconn(self)
+
+ cdef Py_ssize_t cnparams
+ cdef libpq.Oid *ctypes
+ cdef char *const *cvalues
+ cdef int *clengths
+ cdef int *cformats
+ cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
+ param_values, param_types, param_formats)
+
+ cdef libpq.PGresult *pgresult
+ with nogil:
+ pgresult = libpq.PQexecParams(
+ self._pgconn_ptr, command, <int>cnparams, ctypes,
+ <const char *const *>cvalues, clengths, cformats, result_format)
+ _clear_query_params(ctypes, cvalues, clengths, cformats)
+ if pgresult is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(pgresult)
+
+ def send_query_params(
+ self,
+ const char *command,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ int result_format = PqFormat.TEXT,
+ ) -> None:
+ _ensure_pgconn(self)
+
+ cdef Py_ssize_t cnparams
+ cdef libpq.Oid *ctypes
+ cdef char *const *cvalues
+ cdef int *clengths
+ cdef int *cformats
+ cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
+ param_values, param_types, param_formats)
+
+ cdef int rv
+ with nogil:
+ rv = libpq.PQsendQueryParams(
+ self._pgconn_ptr, command, <int>cnparams, ctypes,
+ <const char *const *>cvalues, clengths, cformats, result_format)
+ _clear_query_params(ctypes, cvalues, clengths, cformats)
+ if not rv:
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_prepare(
+ self,
+ const char *name,
+ const char *command,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ _ensure_pgconn(self)
+
+ cdef int i
+ cdef Py_ssize_t nparams = len(param_types) if param_types else 0
+ cdef libpq.Oid *atypes = NULL
+ if nparams:
+ atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid))
+ for i in range(nparams):
+ atypes[i] = param_types[i]
+
+ cdef int rv
+ with nogil:
+ rv = libpq.PQsendPrepare(
+ self._pgconn_ptr, name, command, <int>nparams, atypes
+ )
+ PyMem_Free(atypes)
+ if not rv:
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_query_prepared(
+ self,
+ const char *name,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_formats: Optional[Sequence[int]] = None,
+ int result_format = PqFormat.TEXT,
+ ) -> None:
+ _ensure_pgconn(self)
+
+ cdef Py_ssize_t cnparams
+ cdef libpq.Oid *ctypes
+ cdef char *const *cvalues
+ cdef int *clengths
+ cdef int *cformats
+ cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
+ param_values, None, param_formats)
+
+ cdef int rv
+ with nogil:
+ rv = libpq.PQsendQueryPrepared(
+ self._pgconn_ptr, name, <int>cnparams, <const char *const *>cvalues,
+ clengths, cformats, result_format)
+ _clear_query_params(ctypes, cvalues, clengths, cformats)
+ if not rv:
+ raise e.OperationalError(
+ f"sending prepared query failed: {error_message(self)}"
+ )
+
+ def prepare(
+ self,
+ const char *name,
+ const char *command,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> PGresult:
+ _ensure_pgconn(self)
+
+ cdef int i
+ cdef Py_ssize_t nparams = len(param_types) if param_types else 0
+ cdef libpq.Oid *atypes = NULL
+ if nparams:
+ atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid))
+ for i in range(nparams):
+ atypes[i] = param_types[i]
+
+ cdef libpq.PGresult *rv
+ with nogil:
+ rv = libpq.PQprepare(
+ self._pgconn_ptr, name, command, <int>nparams, atypes)
+ PyMem_Free(atypes)
+ if rv is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(rv)
+
+ def exec_prepared(
+ self,
+ const char *name,
+ param_values: Optional[Sequence[bytes]],
+ param_formats: Optional[Sequence[int]] = None,
+ int result_format = PqFormat.TEXT,
+ ) -> PGresult:
+ _ensure_pgconn(self)
+
+ cdef Py_ssize_t cnparams
+ cdef libpq.Oid *ctypes
+ cdef char *const *cvalues
+ cdef int *clengths
+ cdef int *cformats
+ cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
+ param_values, None, param_formats)
+
+ cdef libpq.PGresult *rv
+ with nogil:
+ rv = libpq.PQexecPrepared(
+ self._pgconn_ptr, name, <int>cnparams,
+ <const char *const *>cvalues,
+ clengths, cformats, result_format)
+
+ _clear_query_params(ctypes, cvalues, clengths, cformats)
+ if rv is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(rv)
+
+ def describe_prepared(self, const char *name) -> PGresult:
+ _ensure_pgconn(self)
+ cdef libpq.PGresult *rv = libpq.PQdescribePrepared(self._pgconn_ptr, name)
+ if rv is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(rv)
+
+ def send_describe_prepared(self, const char *name) -> None:
+ _ensure_pgconn(self)
+ cdef int rv = libpq.PQsendDescribePrepared(self._pgconn_ptr, name)
+ if not rv:
+ raise e.OperationalError(
+ f"sending describe prepared failed: {error_message(self)}"
+ )
+
+ def describe_portal(self, const char *name) -> PGresult:
+ _ensure_pgconn(self)
+ cdef libpq.PGresult *rv = libpq.PQdescribePortal(self._pgconn_ptr, name)
+ if rv is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(rv)
+
+ def send_describe_portal(self, const char *name) -> None:
+ _ensure_pgconn(self)
+ cdef int rv = libpq.PQsendDescribePortal(self._pgconn_ptr, name)
+ if not rv:
+ raise e.OperationalError(
+ f"sending describe prepared failed: {error_message(self)}"
+ )
+
+ def get_result(self) -> Optional["PGresult"]:
+ cdef libpq.PGresult *pgresult = libpq.PQgetResult(self._pgconn_ptr)
+ if pgresult is NULL:
+ return None
+ return PGresult._from_ptr(pgresult)
+
+ def consume_input(self) -> None:
+ if 1 != libpq.PQconsumeInput(self._pgconn_ptr):
+ raise e.OperationalError(f"consuming input failed: {error_message(self)}")
+
+ def is_busy(self) -> int:
+ cdef int rv
+ with nogil:
+ rv = libpq.PQisBusy(self._pgconn_ptr)
+ return rv
+
+ @property
+ def nonblocking(self) -> int:
+ return libpq.PQisnonblocking(self._pgconn_ptr)
+
+ @nonblocking.setter
+ def nonblocking(self, int arg) -> None:
+ if 0 > libpq.PQsetnonblocking(self._pgconn_ptr, arg):
+ raise e.OperationalError(f"setting nonblocking failed: {error_message(self)}")
+
+ cpdef int flush(self) except -1:
+ if self._pgconn_ptr == NULL:
+ raise e.OperationalError(f"flushing failed: the connection is closed")
+ cdef int rv = libpq.PQflush(self._pgconn_ptr)
+ if rv < 0:
+ raise e.OperationalError(f"flushing failed: {error_message(self)}")
+ return rv
+
+ def set_single_row_mode(self) -> None:
+ if not libpq.PQsetSingleRowMode(self._pgconn_ptr):
+ raise e.OperationalError("setting single row mode failed")
+
+ def get_cancel(self) -> PGcancel:
+ cdef libpq.PGcancel *ptr = libpq.PQgetCancel(self._pgconn_ptr)
+ if not ptr:
+ raise e.OperationalError("couldn't create cancel object")
+ return PGcancel._from_ptr(ptr)
+
+ cpdef object notifies(self):
+ cdef libpq.PGnotify *ptr
+ with nogil:
+ ptr = libpq.PQnotifies(self._pgconn_ptr)
+ if ptr:
+ ret = PGnotify(ptr.relname, ptr.be_pid, ptr.extra)
+ libpq.PQfreemem(ptr)
+ return ret
+ else:
+ return None
+
+ def put_copy_data(self, buffer) -> int:
+ cdef int rv
+ cdef char *cbuffer
+ cdef Py_ssize_t length
+
+ _buffer_as_string_and_size(buffer, &cbuffer, &length)
+ rv = libpq.PQputCopyData(self._pgconn_ptr, cbuffer, <int>length)
+ if rv < 0:
+ raise e.OperationalError(f"sending copy data failed: {error_message(self)}")
+ return rv
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ cdef int rv
+ cdef const char *cerr = NULL
+ if error is not None:
+ cerr = PyBytes_AsString(error)
+ rv = libpq.PQputCopyEnd(self._pgconn_ptr, cerr)
+ if rv < 0:
+ raise e.OperationalError(f"sending copy end failed: {error_message(self)}")
+ return rv
+
+ def get_copy_data(self, int async_) -> Tuple[int, memoryview]:
+ cdef char *buffer_ptr = NULL
+ cdef int nbytes
+ nbytes = libpq.PQgetCopyData(self._pgconn_ptr, &buffer_ptr, async_)
+ if nbytes == -2:
+ raise e.OperationalError(f"receiving copy data failed: {error_message(self)}")
+ if buffer_ptr is not NULL:
+ data = PyMemoryView_FromObject(
+ PQBuffer._from_buffer(<unsigned char *>buffer_ptr, nbytes))
+ return nbytes, data
+ else:
+ return nbytes, b"" # won't parse it, doesn't really be memoryview
+
+ def trace(self, fileno: int) -> None:
+ if sys.platform != "linux":
+ raise e.NotSupportedError("currently only supported on Linux")
+ stream = fdopen(fileno, b"w")
+ libpq.PQtrace(self._pgconn_ptr, stream)
+
+ def set_trace_flags(self, flags: Trace) -> None:
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQsetTraceFlags requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ libpq.PQsetTraceFlags(self._pgconn_ptr, flags)
+
+ def untrace(self) -> None:
+ libpq.PQuntrace(self._pgconn_ptr)
+
+ def encrypt_password(
+ self, const char *passwd, const char *user, algorithm = None
+ ) -> bytes:
+ if libpq.PG_VERSION_NUM < 100000:
+ raise e.NotSupportedError(
+ f"PQencryptPasswordConn requires libpq from PostgreSQL 10,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+
+ cdef char *out
+ cdef const char *calgo = NULL
+ if algorithm:
+ calgo = algorithm
+ out = libpq.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, calgo)
+ if not out:
+ raise e.OperationalError(
+ f"password encryption failed: {error_message(self)}"
+ )
+
+ rv = bytes(out)
+ libpq.PQfreemem(out)
+ return rv
+
+ def make_empty_result(self, int exec_status) -> PGresult:
+ cdef libpq.PGresult *rv = libpq.PQmakeEmptyPGresult(
+ self._pgconn_ptr, <libpq.ExecStatusType>exec_status)
+ if not rv:
+ raise MemoryError("couldn't allocate empty PGresult")
+ return PGresult._from_ptr(rv)
+
+ @property
+ def pipeline_status(self) -> int:
+ """The current pipeline mode status.
+
+ For libpq < 14.0, always return 0 (PQ_PIPELINE_OFF).
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ return libpq.PQ_PIPELINE_OFF
+ cdef int status = libpq.PQpipelineStatus(self._pgconn_ptr)
+ return status
+
+ def enter_pipeline_mode(self) -> None:
+ """Enter pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to enter the pipeline
+ mode.
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQenterPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ if libpq.PQenterPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError("failed to enter pipeline mode")
+
+ def exit_pipeline_mode(self) -> None:
+ """Exit pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to exit the pipeline
+ mode.
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQexitPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ if libpq.PQexitPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError(error_message(self))
+
+ def pipeline_sync(self) -> None:
+ """Mark a synchronization point in a pipeline.
+
+ :raises ~e.OperationalError: if the connection is not in pipeline mode
+ or if sync failed.
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQpipelineSync requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ rv = libpq.PQpipelineSync(self._pgconn_ptr)
+ if rv == 0:
+ raise e.OperationalError("connection not in pipeline mode")
+ if rv != 1:
+ raise e.OperationalError("failed to sync pipeline")
+
+ def send_flush_request(self) -> None:
+ """Sends a request for the server to flush its output buffer.
+
+ :raises ~e.OperationalError: if the flush request failed.
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQsendFlushRequest requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ cdef int rv = libpq.PQsendFlushRequest(self._pgconn_ptr)
+ if rv == 0:
+ raise e.OperationalError(f"flush request failed: {error_message(self)}")
+
+
+cdef int _ensure_pgconn(PGconn pgconn) except 0:
+ if pgconn._pgconn_ptr is not NULL:
+ return 1
+
+ raise e.OperationalError("the connection is closed")
+
+
+cdef char *_call_bytes(PGconn pgconn, conn_bytes_f func) except NULL:
+ """
+ Call one of the pgconn libpq functions returning a bytes pointer.
+ """
+ if not _ensure_pgconn(pgconn):
+ return NULL
+ cdef char *rv = func(pgconn._pgconn_ptr)
+ assert rv is not NULL
+ return rv
+
+
+cdef int _call_int(PGconn pgconn, conn_int_f func) except -2:
+ """
+ Call one of the pgconn libpq functions returning an int.
+ """
+ if not _ensure_pgconn(pgconn):
+ return -2
+ return func(pgconn._pgconn_ptr)
+
+
+cdef void notice_receiver(void *arg, const libpq.PGresult *res_ptr) with gil:
+ cdef PGconn pgconn = <object>arg
+ if pgconn.notice_handler is None:
+ return
+
+ cdef PGresult res = PGresult._from_ptr(<libpq.PGresult *>res_ptr)
+ try:
+ pgconn.notice_handler(res)
+ except Exception as e:
+ logger.exception("error in notice receiver: %s", e)
+ finally:
+ res._pgresult_ptr = NULL # avoid destroying the pgresult_ptr
+
+
+cdef (Py_ssize_t, libpq.Oid *, char * const*, int *, int *) _query_params_args(
+ list param_values: Optional[Sequence[Optional[bytes]]],
+ param_types: Optional[Sequence[int]],
+ list param_formats: Optional[Sequence[int]],
+) except *:
+ cdef int i
+
+ # the PostgresQuery converts the param_types to tuple, so this operation
+ # is most often no-op
+ cdef tuple tparam_types
+ if param_types is not None and not isinstance(param_types, tuple):
+ tparam_types = tuple(param_types)
+ else:
+ tparam_types = param_types
+
+ cdef Py_ssize_t nparams = len(param_values) if param_values else 0
+ if tparam_types is not None and len(tparam_types) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_types"
+ % (nparams, len(tparam_types))
+ )
+ if param_formats is not None and len(param_formats) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_formats"
+ % (nparams, len(param_formats))
+ )
+
+ cdef char **aparams = NULL
+ cdef int *alenghts = NULL
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ if nparams:
+ aparams = <char **>PyMem_Malloc(nparams * sizeof(char *))
+ alenghts = <int *>PyMem_Malloc(nparams * sizeof(int))
+ for i in range(nparams):
+ obj = param_values[i]
+ if obj is None:
+ aparams[i] = NULL
+ alenghts[i] = 0
+ else:
+ # TODO: it is a leak if this fails (but it should only fail
+ # on internal error, e.g. if obj is not a buffer)
+ _buffer_as_string_and_size(obj, &ptr, &length)
+ aparams[i] = ptr
+ alenghts[i] = <int>length
+
+ cdef libpq.Oid *atypes = NULL
+ if tparam_types:
+ atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid))
+ for i in range(nparams):
+ atypes[i] = tparam_types[i]
+
+ cdef int *aformats = NULL
+ if param_formats is not None:
+ aformats = <int *>PyMem_Malloc(nparams * sizeof(int *))
+ for i in range(nparams):
+ aformats[i] = param_formats[i]
+
+ return (nparams, atypes, aparams, alenghts, aformats)
+
+
+cdef void _clear_query_params(
+ libpq.Oid *ctypes, char *const *cvalues, int *clenghst, int *cformats
+):
+ PyMem_Free(ctypes)
+ PyMem_Free(<char **>cvalues)
+ PyMem_Free(clenghst)
+ PyMem_Free(cformats)
diff --git a/psycopg_c/psycopg_c/pq/pgresult.pyx b/psycopg_c/psycopg_c/pq/pgresult.pyx
new file mode 100644
index 0000000..6df42e8
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/pgresult.pyx
@@ -0,0 +1,157 @@
+"""
+psycopg_c.pq.PGresult object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+
+from psycopg.pq.misc import PGresAttDesc
+from psycopg.pq._enums import ExecStatus
+
+
+@cython.freelist(8)
+cdef class PGresult:
+ def __cinit__(self):
+ self._pgresult_ptr = NULL
+
+ @staticmethod
+ cdef PGresult _from_ptr(libpq.PGresult *ptr):
+ cdef PGresult rv = PGresult.__new__(PGresult)
+ rv._pgresult_ptr = ptr
+ return rv
+
+ def __dealloc__(self) -> None:
+ self.clear()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ status = ExecStatus(self.status)
+ return f"<{cls} [{status.name}] at 0x{id(self):x}>"
+
+ def clear(self) -> None:
+ if self._pgresult_ptr is not NULL:
+ libpq.PQclear(self._pgresult_ptr)
+ self._pgresult_ptr = NULL
+
+ @property
+ def pgresult_ptr(self) -> Optional[int]:
+ if self._pgresult_ptr:
+ return <long long><void *>self._pgresult_ptr
+ else:
+ return None
+
+ @property
+ def status(self) -> int:
+ return libpq.PQresultStatus(self._pgresult_ptr)
+
+ @property
+ def error_message(self) -> bytes:
+ return libpq.PQresultErrorMessage(self._pgresult_ptr)
+
+ def error_field(self, int fieldcode) -> Optional[bytes]:
+ cdef char * rv = libpq.PQresultErrorField(self._pgresult_ptr, fieldcode)
+ if rv is not NULL:
+ return rv
+ else:
+ return None
+
+ @property
+ def ntuples(self) -> int:
+ return libpq.PQntuples(self._pgresult_ptr)
+
+ @property
+ def nfields(self) -> int:
+ return libpq.PQnfields(self._pgresult_ptr)
+
+ def fname(self, int column_number) -> Optional[bytes]:
+ cdef char *rv = libpq.PQfname(self._pgresult_ptr, column_number)
+ if rv is not NULL:
+ return rv
+ else:
+ return None
+
+ def ftable(self, int column_number) -> int:
+ return libpq.PQftable(self._pgresult_ptr, column_number)
+
+ def ftablecol(self, int column_number) -> int:
+ return libpq.PQftablecol(self._pgresult_ptr, column_number)
+
+ def fformat(self, int column_number) -> int:
+ return libpq.PQfformat(self._pgresult_ptr, column_number)
+
+ def ftype(self, int column_number) -> int:
+ return libpq.PQftype(self._pgresult_ptr, column_number)
+
+ def fmod(self, int column_number) -> int:
+ return libpq.PQfmod(self._pgresult_ptr, column_number)
+
+ def fsize(self, int column_number) -> int:
+ return libpq.PQfsize(self._pgresult_ptr, column_number)
+
+ @property
+ def binary_tuples(self) -> int:
+ return libpq.PQbinaryTuples(self._pgresult_ptr)
+
+ def get_value(self, int row_number, int column_number) -> Optional[bytes]:
+ cdef int crow = row_number
+ cdef int ccol = column_number
+ cdef int length = libpq.PQgetlength(self._pgresult_ptr, crow, ccol)
+ cdef char *v
+ if length:
+ v = libpq.PQgetvalue(self._pgresult_ptr, crow, ccol)
+ # TODO: avoid copy
+ return v[:length]
+ else:
+ if libpq.PQgetisnull(self._pgresult_ptr, crow, ccol):
+ return None
+ else:
+ return b""
+
+ @property
+ def nparams(self) -> int:
+ return libpq.PQnparams(self._pgresult_ptr)
+
+ def param_type(self, int param_number) -> int:
+ return libpq.PQparamtype(self._pgresult_ptr, param_number)
+
+ @property
+ def command_status(self) -> Optional[bytes]:
+ cdef char *rv = libpq.PQcmdStatus(self._pgresult_ptr)
+ if rv is not NULL:
+ return rv
+ else:
+ return None
+
+ @property
+ def command_tuples(self) -> Optional[int]:
+ cdef char *rv = libpq.PQcmdTuples(self._pgresult_ptr)
+ if rv is NULL:
+ return None
+ cdef bytes brv = rv
+ return int(brv) if brv else None
+
+ @property
+ def oid_value(self) -> int:
+ return libpq.PQoidValue(self._pgresult_ptr)
+
+ def set_attributes(self, descriptions: List[PGresAttDesc]):
+ cdef Py_ssize_t num = len(descriptions)
+ cdef libpq.PGresAttDesc *attrs = <libpq.PGresAttDesc *>PyMem_Malloc(
+ num * sizeof(libpq.PGresAttDesc))
+
+ for i in range(num):
+ descr = descriptions[i]
+ attrs[i].name = descr.name
+ attrs[i].tableid = descr.tableid
+ attrs[i].columnid = descr.columnid
+ attrs[i].format = descr.format
+ attrs[i].typid = descr.typid
+ attrs[i].typlen = descr.typlen
+ attrs[i].atttypmod = descr.atttypmod
+
+ cdef int res = libpq.PQsetResultAttrs(self._pgresult_ptr, <int>num, attrs)
+ PyMem_Free(attrs)
+ if (res == 0):
+ raise e.OperationalError("PQsetResultAttrs failed")
diff --git a/psycopg_c/psycopg_c/pq/pqbuffer.pyx b/psycopg_c/psycopg_c/pq/pqbuffer.pyx
new file mode 100644
index 0000000..eb5d648
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/pqbuffer.pyx
@@ -0,0 +1,111 @@
+"""
+PQbuffer object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+from cpython.bytes cimport PyBytes_AsStringAndSize
+from cpython.buffer cimport PyObject_CheckBuffer, PyBUF_SIMPLE
+from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release
+
+
+@cython.freelist(32)
+cdef class PQBuffer:
+ """
+ Wrap a chunk of memory allocated by the libpq and expose it as memoryview.
+ """
+ @staticmethod
+ cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t length):
+ cdef PQBuffer rv = PQBuffer.__new__(PQBuffer)
+ rv.buf = buf
+ rv.len = length
+ return rv
+
+ def __cinit__(self):
+ self.buf = NULL
+ self.len = 0
+
+ def __dealloc__(self):
+ if self.buf:
+ libpq.PQfreemem(self.buf)
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ f"({bytes(self)})"
+ )
+
+ def __getbuffer__(self, Py_buffer *buffer, int flags):
+ buffer.buf = self.buf
+ buffer.obj = self
+ buffer.len = self.len
+ buffer.itemsize = sizeof(unsigned char)
+ buffer.readonly = 1
+ buffer.ndim = 1
+ buffer.format = NULL # unsigned char
+ buffer.shape = &self.len
+ buffer.strides = NULL
+ buffer.suboffsets = NULL
+ buffer.internal = NULL
+
+ def __releasebuffer__(self, Py_buffer *buffer):
+ pass
+
+
+@cython.freelist(32)
+cdef class ViewBuffer:
+ """
+ Wrap a chunk of memory owned by a different object.
+ """
+ @staticmethod
+ cdef ViewBuffer _from_buffer(
+ object obj, unsigned char *buf, Py_ssize_t length
+ ):
+ cdef ViewBuffer rv = ViewBuffer.__new__(ViewBuffer)
+ rv.obj = obj
+ rv.buf = buf
+ rv.len = length
+ return rv
+
+ def __cinit__(self):
+ self.buf = NULL
+ self.len = 0
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ f"({bytes(self)})"
+ )
+
+ def __getbuffer__(self, Py_buffer *buffer, int flags):
+ buffer.buf = self.buf
+ buffer.obj = self
+ buffer.len = self.len
+ buffer.itemsize = sizeof(unsigned char)
+ buffer.readonly = 1
+ buffer.ndim = 1
+ buffer.format = NULL # unsigned char
+ buffer.shape = &self.len
+ buffer.strides = NULL
+ buffer.suboffsets = NULL
+ buffer.internal = NULL
+
+ def __releasebuffer__(self, Py_buffer *buffer):
+ pass
+
+
+cdef int _buffer_as_string_and_size(
+ data: "Buffer", char **ptr, Py_ssize_t *length
+) except -1:
+ cdef Py_buffer buf
+
+ if isinstance(data, bytes):
+ PyBytes_AsStringAndSize(data, ptr, length)
+ elif PyObject_CheckBuffer(data):
+ PyObject_GetBuffer(data, &buf, PyBUF_SIMPLE)
+ ptr[0] = <char *>buf.buf
+ length[0] = buf.len
+ PyBuffer_Release(&buf)
+ else:
+ raise TypeError(f"bytes or buffer expected, got {type(data)}")
diff --git a/psycopg_c/psycopg_c/py.typed b/psycopg_c/psycopg_c/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/psycopg_c/psycopg_c/py.typed
diff --git a/psycopg_c/psycopg_c/types/array.pyx b/psycopg_c/psycopg_c/types/array.pyx
new file mode 100644
index 0000000..9abaef9
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/array.pyx
@@ -0,0 +1,276 @@
+"""
+C optimized functions to manipulate arrays
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import cython
+
+from libc.stdint cimport int32_t, uint32_t
+from libc.string cimport memset, strchr
+from cpython.mem cimport PyMem_Realloc, PyMem_Free
+from cpython.ref cimport Py_INCREF
+from cpython.list cimport PyList_New,PyList_Append, PyList_GetSlice
+from cpython.list cimport PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE
+from cpython.object cimport PyObject
+
+from psycopg_c.pq cimport _buffer_as_string_and_size
+from psycopg_c.pq.libpq cimport Oid
+from psycopg_c._psycopg cimport endian
+
+from psycopg import errors as e
+
+cdef extern from *:
+ """
+/* Defined in PostgreSQL in src/include/utils/array.h */
+#define MAXDIM 6
+ """
+ const int MAXDIM
+
+
+cdef class ArrayLoader(_CRecursiveLoader):
+
+ format = PQ_TEXT
+ base_oid = 0
+ delimiter = b","
+
+ cdef PyObject *row_loader
+ cdef char cdelim
+
+ # A memory area which used to unescape elements.
+ # Keep it here to avoid a malloc per element and to set up exceptions
+ # to make sure to free it on error.
+ cdef char *scratch
+ cdef size_t sclen
+
+ cdef object cload(self, const char *data, size_t length):
+ if self.cdelim == b"\x00":
+ self.row_loader = self._tx._c_get_loader(
+ <PyObject *>self.base_oid, <PyObject *>PQ_TEXT)
+ self.cdelim = self.delimiter[0]
+
+ return _array_load_text(
+ data, length, self.row_loader, self.cdelim,
+ &(self.scratch), &(self.sclen))
+
+ def __dealloc__(self):
+ PyMem_Free(self.scratch)
+
+
+@cython.final
+cdef class ArrayBinaryLoader(_CRecursiveLoader):
+
+ format = PQ_BINARY
+
+ cdef PyObject *row_loader
+
+ cdef object cload(self, const char *data, size_t length):
+ rv = _array_load_binary(data, length, self._tx, &(self.row_loader))
+ return rv
+
+
+cdef object _array_load_text(
+ const char *buf, size_t length, PyObject *row_loader, char cdelim,
+ char **scratch, size_t *sclen
+):
+ if length == 0:
+ raise e.DataError("malformed array: empty data")
+
+ cdef const char *end = buf + length
+
+ # Remove the dimensions information prefix (``[...]=``)
+ if buf[0] == b"[":
+ buf = strchr(buf + 1, b'=')
+ if buf == NULL:
+ raise e.DataError("malformed array: no '=' after dimension information")
+ buf += 1
+
+ # TODO: further optimization: pre-scan the array to find the array
+ # dimensions, so that we can preallocate the list sized instead of calling
+ # append, which is the dominating operation
+
+ cdef list stack = []
+ cdef list a = []
+ rv = a
+ cdef PyObject *tmp
+
+ cdef CLoader cloader = None
+ cdef object pyload = None
+ if (<RowLoader>row_loader).cloader is not None:
+ cloader = (<RowLoader>row_loader).cloader
+ else:
+ pyload = (<RowLoader>row_loader).loadfunc
+
+ while buf < end:
+ if buf[0] == b'{':
+ if stack:
+ tmp = PyList_GET_ITEM(stack, PyList_GET_SIZE(stack) - 1)
+ PyList_Append(<object>tmp, a)
+ PyList_Append(stack, a)
+ a = []
+ buf += 1
+
+ elif buf[0] == b'}':
+ if not stack:
+ raise e.DataError("malformed array: unexpected '}'")
+ rv = stack.pop()
+ buf += 1
+
+ elif buf[0] == cdelim:
+ buf += 1
+
+ else:
+ v = _parse_token(
+ &buf, end, cdelim, scratch, sclen, cloader, pyload)
+ if not stack:
+ raise e.DataError("malformed array: missing initial '{'")
+ tmp = PyList_GET_ITEM(stack, PyList_GET_SIZE(stack) - 1)
+ PyList_Append(<object>tmp, v)
+
+ return rv
+
+
+cdef object _parse_token(
+ const char **bufptr, const char *bufend, char cdelim,
+ char **scratch, size_t *sclen, CLoader cloader, object load
+):
+ cdef const char *start = bufptr[0]
+ cdef int has_quotes = start[0] == b'"'
+ cdef int quoted = has_quotes
+ cdef int num_escapes = 0
+ cdef int escaped = 0
+
+ if has_quotes:
+ start += 1
+ cdef const char *end = start
+
+ while end < bufend:
+ if (end[0] == cdelim or end[0] == b'}') and not quoted:
+ break
+ elif end[0] == b'\\' and not escaped:
+ num_escapes += 1
+ escaped = 1
+ end += 1
+ continue
+ elif end[0] == b'"' and not escaped:
+ quoted = 0
+ escaped = 0
+ end += 1
+ else:
+ raise e.DataError("malformed array: hit the end of the buffer")
+
+ # Return the new position for the buffer
+ bufptr[0] = end
+ if has_quotes:
+ end -= 1
+
+ cdef int length = (end - start)
+ if length == 4 and not has_quotes \
+ and start[0] == b'N' and start[1] == b'U' \
+ and start[2] == b'L' and start[3] == b'L':
+ return None
+
+ cdef const char *src
+ cdef char *tgt
+ cdef size_t unesclen
+
+ if not num_escapes:
+ if cloader is not None:
+ return cloader.cload(start, length)
+ else:
+ b = start[:length]
+ return load(b)
+
+ else:
+ unesclen = length - num_escapes + 1
+ if unesclen > sclen[0]:
+ scratch[0] = <char *>PyMem_Realloc(scratch[0], unesclen)
+ sclen[0] = unesclen
+
+ src = start
+ tgt = scratch[0]
+ while src < end:
+ if src[0] == b'\\':
+ src += 1
+ tgt[0] = src[0]
+ src += 1
+ tgt += 1
+
+ tgt[0] = b'\x00'
+
+ if cloader is not None:
+ return cloader.cload(scratch[0], length - num_escapes)
+ else:
+ b = scratch[0][:length - num_escapes]
+ return load(b)
+
+
+@cython.cdivision(True)
+cdef object _array_load_binary(
+ const char *buf, size_t length, Transformer tx, PyObject **row_loader_ptr
+):
+ # head is ndims, hasnull, elem oid
+ cdef uint32_t *buf32 = <uint32_t *>buf
+ cdef int ndims = endian.be32toh(buf32[0])
+
+ if ndims <= 0:
+ return []
+ elif ndims > MAXDIM:
+ raise e.DataError(
+ r"unexpected number of dimensions %s exceeding the maximum allowed %s"
+ % (ndims, MAXDIM)
+ )
+
+ cdef object oid
+ if row_loader_ptr[0] == NULL:
+ oid = <Oid>endian.be32toh(buf32[2])
+ row_loader_ptr[0] = tx._c_get_loader(<PyObject *>oid, <PyObject *>PQ_BINARY)
+
+ cdef Py_ssize_t[MAXDIM] dims
+ cdef int i
+ for i in range(ndims):
+ # Every dimension is dim, lower bound
+ dims[i] = endian.be32toh(buf32[3 + 2 * i])
+
+ buf += (3 + 2 * ndims) * sizeof(uint32_t)
+ out = _array_load_binary_rec(ndims, dims, &buf, row_loader_ptr[0])
+ return out
+
+
+cdef object _array_load_binary_rec(
+ Py_ssize_t ndims, Py_ssize_t *dims, const char **bufptr, PyObject *row_loader
+):
+ cdef const char *buf
+ cdef int i
+ cdef int32_t size
+ cdef object val
+
+ cdef Py_ssize_t nelems = dims[0]
+ cdef list out = PyList_New(nelems)
+
+ if ndims == 1:
+ buf = bufptr[0]
+ for i in range(nelems):
+ size = <int32_t>endian.be32toh((<uint32_t *>buf)[0])
+ buf += sizeof(uint32_t)
+ if size == -1:
+ val = None
+ else:
+ if (<RowLoader>row_loader).cloader is not None:
+ val = (<RowLoader>row_loader).cloader.cload(buf, size)
+ else:
+ val = (<RowLoader>row_loader).loadfunc(buf[:size])
+ buf += size
+
+ Py_INCREF(val)
+ PyList_SET_ITEM(out, i, val)
+
+ bufptr[0] = buf
+
+ else:
+ for i in range(nelems):
+ val = _array_load_binary_rec(ndims - 1, dims + 1, bufptr, row_loader)
+ Py_INCREF(val)
+ PyList_SET_ITEM(out, i, val)
+
+ return out
diff --git a/psycopg_c/psycopg_c/types/bool.pyx b/psycopg_c/psycopg_c/types/bool.pyx
new file mode 100644
index 0000000..86cf88e
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/bool.pyx
@@ -0,0 +1,78 @@
+"""
+Cython adapters for boolean.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+
+
+@cython.final
+cdef class BoolDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.BOOL_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef char *buf = CDumper.ensure_size(rv, offset, 1)
+
+ # Fast paths, just a pointer comparison
+ if obj is True:
+ buf[0] = b"t"
+ elif obj is False:
+ buf[0] = b"f"
+ elif obj:
+ buf[0] = b"t"
+ else:
+ buf[0] = b"f"
+
+ return 1
+
+ def quote(self, obj: bool) -> bytes:
+ if obj is True:
+ return b"true"
+ elif obj is False:
+ return b"false"
+ else:
+ return b"true" if obj else b"false"
+
+
+@cython.final
+cdef class BoolBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.BOOL_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef char *buf = CDumper.ensure_size(rv, offset, 1)
+
+ # Fast paths, just a pointer comparison
+ if obj is True:
+ buf[0] = b"\x01"
+ elif obj is False:
+ buf[0] = b"\x00"
+ elif obj:
+ buf[0] = b"\x01"
+ else:
+ buf[0] = b"\x00"
+
+ return 1
+
+
+@cython.final
+cdef class BoolLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+ # this creates better C than `return data[0] == b't'`
+ return True if data[0] == b't' else False
+
+
+@cython.final
+cdef class BoolBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return True if data[0] else False
diff --git a/psycopg_c/psycopg_c/types/datetime.pyx b/psycopg_c/psycopg_c/types/datetime.pyx
new file mode 100644
index 0000000..51e7dcf
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/datetime.pyx
@@ -0,0 +1,1136 @@
+"""
+Cython adapters for date/time types.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from libc.string cimport memset, strchr
+from cpython cimport datetime as cdt
+from cpython.dict cimport PyDict_GetItem
+from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
+
+cdef extern from "Python.h":
+ const char *PyUnicode_AsUTF8AndSize(unicode obj, Py_ssize_t *size) except NULL
+ object PyTimeZone_FromOffset(object offset)
+
+cdef extern from *:
+ """
+/* Multipliers from fraction of seconds to microseconds */
+static int _uspad[] = {0, 100000, 10000, 1000, 100, 10, 1};
+ """
+ cdef int *_uspad
+
+from datetime import date, time, timedelta, datetime, timezone
+
+from psycopg_c._psycopg cimport endian
+
+from psycopg import errors as e
+from psycopg._compat import ZoneInfo
+
+
+# Initialise the datetime C API
+cdt.import_datetime()
+
+cdef enum:
+ ORDER_YMD = 0
+ ORDER_DMY = 1
+ ORDER_MDY = 2
+ ORDER_PGDM = 3
+ ORDER_PGMD = 4
+
+cdef enum:
+ INTERVALSTYLE_OTHERS = 0
+ INTERVALSTYLE_SQL_STANDARD = 1
+ INTERVALSTYLE_POSTGRES = 2
+
+cdef enum:
+ PG_DATE_EPOCH_DAYS = 730120 # date(2000, 1, 1).toordinal()
+ PY_DATE_MIN_DAYS = 1 # date.min.toordinal()
+
+cdef object date_toordinal = date.toordinal
+cdef object date_fromordinal = date.fromordinal
+cdef object datetime_astimezone = datetime.astimezone
+cdef object time_utcoffset = time.utcoffset
+cdef object timedelta_total_seconds = timedelta.total_seconds
+cdef object timezone_utc = timezone.utc
+cdef object pg_datetime_epoch = datetime(2000, 1, 1)
+cdef object pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=timezone.utc)
+
+cdef object _month_abbr = {
+ n: i
+ for i, n in enumerate(
+ b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1
+ )
+}
+
+
+@cython.final
+cdef class DateDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.DATE_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ cdef str s = str(obj)
+ src = PyUnicode_AsUTF8AndSize(s, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class DateBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.DATE_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int32_t days = PyObject_CallFunctionObjArgs(
+ date_toordinal, <PyObject *>obj, NULL)
+ days -= PG_DATE_EPOCH_DAYS
+ cdef int32_t *buf = <int32_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int32_t))
+ buf[0] = endian.htobe32(days)
+ return sizeof(int32_t)
+
+
+cdef class _BaseTimeDumper(CDumper):
+
+ cpdef get_key(self, obj, format):
+ # Use (cls,) to report the need to upgrade to a dumper for timetz (the
+ # Frankenstein of the data types).
+ if not obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ cpdef upgrade(self, obj: time, format):
+ raise NotImplementedError
+
+
+cdef class _BaseTimeTextDumper(_BaseTimeDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ cdef str s = str(obj)
+ src = PyUnicode_AsUTF8AndSize(s, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class TimeDumper(_BaseTimeTextDumper):
+
+ oid = oids.TIME_OID
+
+ cpdef upgrade(self, obj, format):
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzDumper(self.cls)
+
+
+@cython.final
+cdef class TimeTzDumper(_BaseTimeTextDumper):
+
+ oid = oids.TIMETZ_OID
+
+
+@cython.final
+cdef class TimeBinaryDumper(_BaseTimeDumper):
+
+ format = PQ_BINARY
+ oid = oids.TIME_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int64_t micros = cdt.time_microsecond(obj) + 1000000 * (
+ cdt.time_second(obj)
+ + 60 * (cdt.time_minute(obj) + 60 * <int64_t>cdt.time_hour(obj))
+ )
+
+ cdef int64_t *buf = <int64_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int64_t))
+ buf[0] = endian.htobe64(micros)
+ return sizeof(int64_t)
+
+ cpdef upgrade(self, obj, format):
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzBinaryDumper(self.cls)
+
+
+@cython.final
+cdef class TimeTzBinaryDumper(_BaseTimeDumper):
+
+ format = PQ_BINARY
+ oid = oids.TIMETZ_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int64_t micros = cdt.time_microsecond(obj) + 1_000_000 * (
+ cdt.time_second(obj)
+ + 60 * (cdt.time_minute(obj) + 60 * <int64_t>cdt.time_hour(obj))
+ )
+
+ off = PyObject_CallFunctionObjArgs(time_utcoffset, <PyObject *>obj, NULL)
+ cdef int32_t offsec = int(PyObject_CallFunctionObjArgs(
+ timedelta_total_seconds, <PyObject *>off, NULL))
+
+ cdef char *buf = CDumper.ensure_size(
+ rv, offset, sizeof(int64_t) + sizeof(int32_t))
+ (<int64_t *>buf)[0] = endian.htobe64(micros)
+ (<int32_t *>(buf + sizeof(int64_t)))[0] = endian.htobe32(-offsec)
+
+ return sizeof(int64_t) + sizeof(int32_t)
+
+
+cdef class _BaseDatetimeDumper(CDumper):
+
+ cpdef get_key(self, obj, format):
+ # Use (cls,) to report the need to upgrade (downgrade, actually) to a
+ # dumper for naive timestamp.
+ if obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ cpdef upgrade(self, obj: time, format):
+ raise NotImplementedError
+
+
+cdef class _BaseDatetimeTextDumper(_BaseDatetimeDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ cdef str s = str(obj)
+ src = PyUnicode_AsUTF8AndSize(s, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class DatetimeDumper(_BaseDatetimeTextDumper):
+
+ oid = oids.TIMESTAMPTZ_OID
+
+ cpdef upgrade(self, obj, format):
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzDumper(self.cls)
+
+
+@cython.final
+cdef class DatetimeNoTzDumper(_BaseDatetimeTextDumper):
+
+ oid = oids.TIMESTAMP_OID
+
+
+@cython.final
+cdef class DatetimeBinaryDumper(_BaseDatetimeDumper):
+
+ format = PQ_BINARY
+ oid = oids.TIMESTAMPTZ_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ delta = obj - pg_datetimetz_epoch
+
+ cdef int64_t micros = cdt.timedelta_microseconds(delta) + 1_000_000 * (
+ 86_400 * <int64_t>cdt.timedelta_days(delta)
+ + <int64_t>cdt.timedelta_seconds(delta))
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t))
+ (<int64_t *>buf)[0] = endian.htobe64(micros)
+ return sizeof(int64_t)
+
+ cpdef upgrade(self, obj, format):
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzBinaryDumper(self.cls)
+
+
+@cython.final
+cdef class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
+
+ format = PQ_BINARY
+ oid = oids.TIMESTAMP_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ delta = obj - pg_datetime_epoch
+
+ cdef int64_t micros = cdt.timedelta_microseconds(delta) + 1_000_000 * (
+ 86_400 * <int64_t>cdt.timedelta_days(delta)
+ + <int64_t>cdt.timedelta_seconds(delta))
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t))
+ (<int64_t *>buf)[0] = endian.htobe64(micros)
+ return sizeof(int64_t)
+
+
+@cython.final
+cdef class TimedeltaDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.INTERVAL_OID
+ cdef int _style
+
+ def __cinit__(self, cls, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_intervalstyle(self._pgconn)
+ if ds[0] == b's': # sql_standard
+ self._style = INTERVALSTYLE_SQL_STANDARD
+ else: # iso_8601, postgres, postgres_verbose
+ self._style = INTERVALSTYLE_OTHERS
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ cdef str s
+ if self._style == INTERVALSTYLE_OTHERS:
+ # The comma is parsed ok by PostgreSQL but it's not documented
+ # and it seems brittle to rely on it. CRDB doesn't consume it well.
+ s = str(obj).replace(",", "")
+ else:
+ # sql_standard format needs explicit signs
+ # otherwise -1 day 1 sec will mean -1 sec
+ s = "%+d day %+d second %+d microsecond" % (
+ obj.days, obj.seconds, obj.microseconds)
+
+ src = PyUnicode_AsUTF8AndSize(s, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class TimedeltaBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.INTERVAL_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int64_t micros = (
+ 1_000_000 * <int64_t>cdt.timedelta_seconds(obj)
+ + cdt.timedelta_microseconds(obj))
+ cdef int32_t days = cdt.timedelta_days(obj)
+
+ cdef char *buf = CDumper.ensure_size(
+ rv, offset, sizeof(int64_t) + sizeof(int32_t) + sizeof(int32_t))
+ (<int64_t *>buf)[0] = endian.htobe64(micros)
+ (<int32_t *>(buf + sizeof(int64_t)))[0] = endian.htobe32(days)
+ (<int32_t *>(buf + sizeof(int64_t) + sizeof(int32_t)))[0] = 0
+
+ return sizeof(int64_t) + sizeof(int32_t) + sizeof(int32_t)
+
+
+@cython.final
+cdef class DateLoader(CLoader):
+
+ format = PQ_TEXT
+ cdef int _order
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_datestyle(self._pgconn)
+ if ds[0] == b'I': # ISO
+ self._order = ORDER_YMD
+ elif ds[0] == b'G': # German
+ self._order = ORDER_DMY
+ elif ds[0] == b'S': # SQL, DMY / MDY
+ self._order = ORDER_DMY if ds[5] == b'D' else ORDER_MDY
+ elif ds[0] == b'P': # Postgres, DMY / MDY
+ self._order = ORDER_DMY if ds[10] == b'D' else ORDER_MDY
+ else:
+ raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ cdef object _error_date(self, const char *data, str msg):
+ s = bytes(data).decode("utf8", "replace")
+ if s == "infinity" or len(s.split()[0]) > 10:
+ raise e.DataError(f"date too large (after year 10K): {s!r}") from None
+ elif s == "-infinity" or "BC" in s:
+ raise e.DataError(f"date too small (before year 1): {s!r}") from None
+ else:
+ raise e.DataError(f"can't parse date {s!r}: {msg}") from None
+
+ cdef object cload(self, const char *data, size_t length):
+ if length != 10:
+ self._error_date(data, "unexpected length")
+
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+
+ cdef const char *ptr
+ cdef const char *end = data + length
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse date {s!r}")
+
+ try:
+ if self._order == ORDER_YMD:
+ return cdt.date_new(vals[0], vals[1], vals[2])
+ elif self._order == ORDER_DMY:
+ return cdt.date_new(vals[2], vals[1], vals[0])
+ else:
+ return cdt.date_new(vals[2], vals[0], vals[1])
+ except ValueError as ex:
+ self._error_date(data, str(ex))
+
+
+@cython.final
+cdef class DateBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int days = endian.be32toh((<uint32_t *>data)[0])
+ cdef object pydays = days + PG_DATE_EPOCH_DAYS
+ try:
+ return PyObject_CallFunctionObjArgs(
+ date_fromordinal, <PyObject *>pydays, NULL)
+ except ValueError:
+ if days < PY_DATE_MIN_DAYS:
+ raise e.DataError("date too small (before year 1)") from None
+ else:
+ raise e.DataError("date too large (after year 10K)") from None
+
+
+@cython.final
+cdef class TimeLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+ cdef const char *ptr
+ cdef const char *end = data + length
+
+ # Parse the first 3 groups of digits
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse time {s!r}")
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ try:
+ return cdt.time_new(vals[0], vals[1], vals[2], us, None)
+ except ValueError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse time {s!r}: {ex}") from None
+
+
+@cython.final
+cdef class TimeBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int h, m, s, us
+
+ with cython.cdivision(True):
+ us = val % 1_000_000
+ val //= 1_000_000
+
+ s = val % 60
+ val //= 60
+
+ m = val % 60
+ h = <int>(val // 60)
+
+ try:
+ return cdt.time_new(h, m, s, us, None)
+ except ValueError:
+ raise e.DataError(
+ f"time not supported by Python: hour={h}"
+ ) from None
+
+
+@cython.final
+cdef class TimetzLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+ cdef const char *ptr
+ cdef const char *end = data + length
+
+ # Parse the first 3 groups of digits (time)
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse timetz {s!r}")
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ # Parse the timezone
+ cdef int offsecs = _parse_timezone_to_seconds(&ptr, end)
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse timetz {s!r}")
+
+ tz = _timezone_from_seconds(offsecs)
+ try:
+ return cdt.time_new(vals[0], vals[1], vals[2], us, tz)
+ except ValueError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse timetz {s!r}: {ex}") from None
+
+
+@cython.final
+cdef class TimetzBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int32_t off = endian.be32toh((<uint32_t *>(data + sizeof(int64_t)))[0])
+ cdef int h, m, s, us
+
+ with cython.cdivision(True):
+ us = val % 1_000_000
+ val //= 1_000_000
+
+ s = val % 60
+ val //= 60
+
+ m = val % 60
+ h = <int>(val // 60)
+
+ tz = _timezone_from_seconds(-off)
+ try:
+ return cdt.time_new(h, m, s, us, tz)
+ except ValueError:
+ raise e.DataError(
+ f"time not supported by Python: hour={h}"
+ ) from None
+
+
+@cython.final
+cdef class TimestampLoader(CLoader):
+
+ format = PQ_TEXT
+ cdef int _order
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_datestyle(self._pgconn)
+ if ds[0] == b'I': # ISO
+ self._order = ORDER_YMD
+ elif ds[0] == b'G': # German
+ self._order = ORDER_DMY
+ elif ds[0] == b'S': # SQL, DMY / MDY
+ self._order = ORDER_DMY if ds[5] == b'D' else ORDER_MDY
+ elif ds[0] == b'P': # Postgres, DMY / MDY
+ self._order = ORDER_PGDM if ds[10] == b'D' else ORDER_PGMD
+ else:
+ raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef const char *end = data + length
+ if end[-1] == b'C': # ends with BC
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ if self._order == ORDER_PGDM or self._order == ORDER_PGMD:
+ return self._cload_pg(data, end)
+
+ cdef int vals[6]
+ memset(vals, 0, sizeof(vals))
+ cdef const char *ptr
+
+ # Parse the first 6 groups of digits (date and time)
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ # Resolve the YMD order
+ cdef int y, m, d
+ if self._order == ORDER_YMD:
+ y, m, d = vals[0], vals[1], vals[2]
+ elif self._order == ORDER_DMY:
+ d, m, y = vals[0], vals[1], vals[2]
+ else: # self._order == ORDER_MDY
+ m, d, y = vals[0], vals[1], vals[2]
+
+ try:
+ return cdt.datetime_new(
+ y, m, d, vals[3], vals[4], vals[5], us, None)
+ except ValueError as ex:
+ raise _get_timestamp_load_error(self._pgconn, data, ex) from None
+
+ cdef object _cload_pg(self, const char *data, const char *end):
+ cdef int vals[4]
+ memset(vals, 0, sizeof(vals))
+ cdef const char *ptr
+
+ # Find Wed Jun 02 or Wed 02 Jun
+ cdef char *seps[3]
+ seps[0] = strchr(data, b' ')
+ seps[1] = strchr(seps[0] + 1, b' ') if seps[0] != NULL else NULL
+ seps[2] = strchr(seps[1] + 1, b' ') if seps[1] != NULL else NULL
+ if seps[2] == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Parse the following 3 groups of digits (time)
+ ptr = _parse_date_values(seps[2] + 1, end, vals, 3)
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ # Parse the year
+ ptr = _parse_date_values(ptr + 1, end, vals + 3, 1)
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Resolve the MD order
+ cdef int m, d
+ try:
+ if self._order == ORDER_PGDM:
+ d = int(seps[0][1 : seps[1] - seps[0]])
+ m = _month_abbr[seps[1][1 : seps[2] - seps[1]]]
+ else: # self._order == ORDER_PGMD
+ m = _month_abbr[seps[0][1 : seps[1] - seps[0]]]
+ d = int(seps[1][1 : seps[2] - seps[1]])
+ except (KeyError, ValueError) as ex:
+ raise _get_timestamp_load_error(self._pgconn, data, ex) from None
+
+ try:
+ return cdt.datetime_new(
+ vals[3], m, d, vals[0], vals[1], vals[2], us, None)
+ except ValueError as ex:
+ raise _get_timestamp_load_error(self._pgconn, data, ex) from None
+
+
+@cython.final
+cdef class TimestampBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int64_t micros, secs, days
+
+ # Work only with positive values as the cdivision behaves differently
+ # with negative values, and cdivision=False adds overhead.
+ cdef int64_t aval = val if val >= 0 else -val
+
+ # Group the micros in biggers stuff or timedelta_new might overflow
+ with cython.cdivision(True):
+ secs = aval // 1_000_000
+ micros = aval % 1_000_000
+
+ days = secs // 86_400
+ secs %= 86_400
+
+ try:
+ delta = cdt.timedelta_new(<int>days, <int>secs, <int>micros)
+ if val > 0:
+ return pg_datetime_epoch + delta
+ else:
+ return pg_datetime_epoch - delta
+
+ except OverflowError:
+ if val <= 0:
+ raise e.DataError("timestamp too small (before year 1)") from None
+ else:
+ raise e.DataError("timestamp too large (after year 10K)") from None
+
+
+cdef class _BaseTimestamptzLoader(CLoader):
+ cdef object _time_zone
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+ self._time_zone = _timezone_from_connection(self._pgconn)
+
+
+@cython.final
+cdef class TimestamptzLoader(_BaseTimestamptzLoader):
+
+ format = PQ_TEXT
+ cdef int _order
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_datestyle(self._pgconn)
+ if ds[0] == b'I': # ISO
+ self._order = ORDER_YMD
+ else: # Not true, but any non-YMD will do.
+ self._order = ORDER_DMY
+
+ cdef object cload(self, const char *data, size_t length):
+ if self._order != ORDER_YMD:
+ return self._cload_notimpl(data, length)
+
+ cdef const char *end = data + length
+ if end[-1] == b'C': # ends with BC
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ cdef int vals[6]
+ memset(vals, 0, sizeof(vals))
+
+ # Parse the first 6 groups of digits (date and time)
+ cdef const char *ptr
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ # Resolve the YMD order
+ cdef int y, m, d
+ if self._order == ORDER_YMD:
+ y, m, d = vals[0], vals[1], vals[2]
+ elif self._order == ORDER_DMY:
+ d, m, y = vals[0], vals[1], vals[2]
+ else: # self._order == ORDER_MDY
+ m, d, y = vals[0], vals[1], vals[2]
+
+ # Parse the timezone
+ cdef int offsecs = _parse_timezone_to_seconds(&ptr, end)
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ tzoff = cdt.timedelta_new(0, offsecs, 0)
+
+ # The return value is a datetime with the timezone of the connection
+ # (in order to be consistent with the binary loader, which is the only
+ # thing it can return). So create a temporary datetime object, in utc,
+ # shift it by the offset parsed from the timestamp, and then move it to
+ # the connection timezone.
+ dt = None
+ try:
+ dt = cdt.datetime_new(
+ y, m, d, vals[3], vals[4], vals[5], us, timezone_utc)
+ dt -= tzoff
+ return PyObject_CallFunctionObjArgs(datetime_astimezone,
+ <PyObject *>dt, <PyObject *>self._time_zone, NULL)
+ except OverflowError as ex:
+ # If we have created the temporary 'dt' it means that we have a
+ # datetime close to max, the shift pushed it past max, overflowing.
+ # In this case return the datetime in a fixed offset timezone.
+ if dt is not None:
+ return dt.replace(tzinfo=timezone(tzoff))
+ else:
+ ex1 = ex
+ except ValueError as ex:
+ ex1 = ex
+
+ raise _get_timestamp_load_error(self._pgconn, data, ex1) from None
+
+ cdef object _cload_notimpl(self, const char *data, size_t length):
+ s = bytes(data)[:length].decode("utf8", "replace")
+ ds = _get_datestyle(self._pgconn).decode()
+ raise NotImplementedError(
+ f"can't parse timestamptz with DateStyle {ds!r}: {s!r}"
+ )
+
+
+@cython.final
+cdef class TimestamptzBinaryLoader(_BaseTimestamptzLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int64_t micros, secs, days
+
+ # Work only with positive values as the cdivision behaves differently
+ # with negative values, and cdivision=False adds overhead.
+ cdef int64_t aval = val if val >= 0 else -val
+
+ # Group the micros in biggers stuff or timedelta_new might overflow
+ with cython.cdivision(True):
+ secs = aval // 1_000_000
+ micros = aval % 1_000_000
+
+ days = secs // 86_400
+ secs %= 86_400
+
+ try:
+ delta = cdt.timedelta_new(<int>days, <int>secs, <int>micros)
+ if val > 0:
+ dt = pg_datetimetz_epoch + delta
+ else:
+ dt = pg_datetimetz_epoch - delta
+ return PyObject_CallFunctionObjArgs(datetime_astimezone,
+ <PyObject *>dt, <PyObject *>self._time_zone, NULL)
+
+ except OverflowError:
+ # If we were asked about a timestamp which would overflow in UTC,
+ # but not in the desired timezone (e.g. datetime.max at Chicago
+ # timezone) we can still save the day by shifting the value by the
+ # timezone offset and then replacing the timezone.
+ if self._time_zone is not None:
+ utcoff = self._time_zone.utcoffset(
+ datetime.min if val < 0 else datetime.max
+ )
+ if utcoff:
+ usoff = 1_000_000 * int(utcoff.total_seconds())
+ try:
+ ts = pg_datetime_epoch + timedelta(
+ microseconds=val + usoff
+ )
+ except OverflowError:
+ pass # will raise downstream
+ else:
+ return ts.replace(tzinfo=self._time_zone)
+
+ if val <= 0:
+ raise e.DataError(
+ "timestamp too small (before year 1)"
+ ) from None
+ else:
+ raise e.DataError(
+ "timestamp too large (after year 10K)"
+ ) from None
+
+
+@cython.final
+cdef class IntervalLoader(CLoader):
+
+ format = PQ_TEXT
+ cdef int _style
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_intervalstyle(self._pgconn)
+ if ds[0] == b'p' and ds[8] == 0: # postgres
+ self._style = INTERVALSTYLE_POSTGRES
+ else: # iso_8601, sql_standard, postgres_verbose
+ self._style = INTERVALSTYLE_OTHERS
+
+ cdef object cload(self, const char *data, size_t length):
+ if self._style == INTERVALSTYLE_OTHERS:
+ return self._cload_notimpl(data, length)
+
+ cdef int days = 0, secs = 0, us = 0
+ cdef char sign
+ cdef int val
+ cdef const char *ptr = data
+ cdef const char *sep
+ cdef const char *end = ptr + length
+
+ # If there are spaces, there is a [+|-]n [days|months|years]
+ while True:
+ if ptr[0] == b'-' or ptr[0] == b'+':
+ sign = ptr[0]
+ ptr += 1
+ else:
+ sign = 0
+
+ sep = strchr(ptr, b' ')
+ if sep == NULL or sep > end:
+ break
+
+ val = 0
+ ptr = _parse_date_values(ptr, end, &val, 1)
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse interval {s!r}")
+
+ if sign == b'-':
+ val = -val
+
+ if ptr[1] == b'y':
+ days = 365 * val
+ elif ptr[1] == b'm':
+ days = 30 * val
+ elif ptr[1] == b'd':
+ days = val
+ else:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse interval {s!r}")
+
+ # Skip the date part word.
+ ptr = strchr(ptr + 1, b' ')
+ if ptr != NULL and ptr < end:
+ ptr += 1
+ else:
+ break
+
+ # Parse the time part. An eventual sign was already consumed in the loop
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+ if ptr != NULL:
+ ptr = _parse_date_values(ptr, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse interval {s!r}")
+
+ secs = vals[2] + 60 * (vals[1] + 60 * vals[0])
+
+ if ptr[0] == b'.':
+ ptr = _parse_micros(ptr + 1, &us)
+
+ if sign == b'-':
+ secs = -secs
+ us = -us
+
+ try:
+ return cdt.timedelta_new(days, secs, us)
+ except OverflowError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse interval {s!r}: {ex}") from None
+
+ cdef object _cload_notimpl(self, const char *data, size_t length):
+ s = bytes(data).decode("utf8", "replace")
+ style = _get_intervalstyle(self._pgconn).decode()
+ raise NotImplementedError(
+ f"can't parse interval with IntervalStyle {style!r}: {s!r}"
+ )
+
+
+@cython.final
+cdef class IntervalBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int32_t days = endian.be32toh(
+ (<uint32_t *>(data + sizeof(int64_t)))[0])
+ cdef int32_t months = endian.be32toh(
+ (<uint32_t *>(data + sizeof(int64_t) + sizeof(int32_t)))[0])
+
+ cdef int years
+ with cython.cdivision(True):
+ if months > 0:
+ years = months // 12
+ months %= 12
+ days += 30 * months + 365 * years
+ elif months < 0:
+ months = -months
+ years = months // 12
+ months %= 12
+ days -= 30 * months + 365 * years
+
+ # Work only with positive values as the cdivision behaves differently
+ # with negative values, and cdivision=False adds overhead.
+ cdef int64_t aval = val if val >= 0 else -val
+ cdef int us, ussecs, usdays
+
+ # Group the micros in biggers stuff or timedelta_new might overflow
+ with cython.cdivision(True):
+ ussecs = <int>(aval // 1_000_000)
+ us = aval % 1_000_000
+
+ usdays = ussecs // 86_400
+ ussecs %= 86_400
+
+ if val < 0:
+ ussecs = -ussecs
+ usdays = -usdays
+ us = -us
+
+ try:
+ return cdt.timedelta_new(days + usdays, ussecs, us)
+ except OverflowError as ex:
+ raise e.DataError(f"can't parse interval: {ex}")
+
+
+cdef const char *_parse_date_values(
+ const char *ptr, const char *end, int *vals, int nvals
+):
+ """
+ Parse *nvals* numeric values separated by non-numeric chars.
+
+ Write the result in the *vals* array (assumed zeroed) starting from *start*.
+
+ Return the pointer at the separator after the final digit.
+ """
+ cdef int ival = 0
+ while ptr < end:
+ if b'0' <= ptr[0] <= b'9':
+ vals[ival] = vals[ival] * 10 + (ptr[0] - <char>b'0')
+ else:
+ ival += 1
+ if ival >= nvals:
+ break
+
+ ptr += 1
+
+ return ptr
+
+
+cdef const char *_parse_micros(const char *start, int *us):
+ """
+ Parse microseconds from a string.
+
+ Micros are assumed up to 6 digit chars separated by a non-digit.
+
+ Return the pointer at the separator after the final digit.
+ """
+ cdef const char *ptr = start
+ while ptr[0]:
+ if b'0' <= ptr[0] <= b'9':
+ us[0] = us[0] * 10 + (ptr[0] - <char>b'0')
+ else:
+ break
+
+ ptr += 1
+
+ # Pad the fraction of second to get millis
+ if us[0] and ptr - start < 6:
+ us[0] *= _uspad[ptr - start]
+
+ return ptr
+
+
+cdef int _parse_timezone_to_seconds(const char **bufptr, const char *end):
+ """
+ Parse a timezone from a string, return Python timezone object.
+
+ Modify the buffer pointer to point at the first character after the
+ timezone parsed. In case of parse error make it NULL.
+ """
+ cdef const char *ptr = bufptr[0]
+ cdef char sgn = ptr[0]
+
+ # Parse at most three groups of digits
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+
+ ptr = _parse_date_values(ptr + 1, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ return 0
+
+ cdef int off = 60 * (60 * vals[0] + vals[1]) + vals[2]
+ return -off if sgn == b"-" else off
+
+
+cdef object _timezone_from_seconds(int sec, __cache={}):
+ cdef object pysec = sec
+ cdef PyObject *ptr = PyDict_GetItem(__cache, pysec)
+ if ptr != NULL:
+ return <object>ptr
+
+ delta = cdt.timedelta_new(0, sec, 0)
+ tz = timezone(delta)
+ __cache[pysec] = tz
+ return tz
+
+
+cdef object _get_timestamp_load_error(
+ pq.PGconn pgconn, const char *data, ex: Optional[Exception] = None
+):
+ s = bytes(data).decode("utf8", "replace")
+
+ def is_overflow(s):
+ if not s:
+ return False
+
+ ds = _get_datestyle(pgconn)
+ if not ds.startswith(b"P"): # Postgres
+ return len(s.split()[0]) > 10 # date is first token
+ else:
+ return len(s.split()[-1]) > 4 # year is last token
+
+ if s == "-infinity" or s.endswith("BC"):
+ return e.DataError("timestamp too small (before year 1): {s!r}")
+ elif s == "infinity" or is_overflow(s):
+ return e.DataError(f"timestamp too large (after year 10K): {s!r}")
+ else:
+ return e.DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}")
+
+
+cdef _timezones = {}
+_timezones[None] = timezone_utc
+_timezones[b"UTC"] = timezone_utc
+
+
+cdef object _timezone_from_connection(pq.PGconn pgconn):
+ """Return the Python timezone info of the connection's timezone."""
+ if pgconn is None:
+ return timezone_utc
+
+ cdef bytes tzname = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"TimeZone")
+ cdef PyObject *ptr = PyDict_GetItem(_timezones, tzname)
+ if ptr != NULL:
+ return <object>ptr
+
+ sname = tzname.decode() if tzname else "UTC"
+ try:
+ zi = ZoneInfo(sname)
+ except (KeyError, OSError):
+ logger.warning(
+ "unknown PostgreSQL timezone: %r; will use UTC", sname
+ )
+ zi = timezone_utc
+ except Exception as ex:
+ logger.warning(
+ "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
+ sname,
+ type(ex).__name__,
+ ex,
+ )
+ zi = timezone.utc
+
+ _timezones[tzname] = zi
+ return zi
+
+
+cdef const char *_get_datestyle(pq.PGconn pgconn):
+ cdef const char *ds
+ if pgconn is not None:
+ ds = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"DateStyle")
+ if ds is not NULL and ds[0]:
+ return ds
+
+ return b"ISO, DMY"
+
+
+cdef const char *_get_intervalstyle(pq.PGconn pgconn):
+ cdef const char *ds
+ if pgconn is not None:
+ ds = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"IntervalStyle")
+ if ds is not NULL and ds[0]:
+ return ds
+
+ return b"postgres"
diff --git a/psycopg_c/psycopg_c/types/numeric.pyx b/psycopg_c/psycopg_c/types/numeric.pyx
new file mode 100644
index 0000000..893bdc2
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/numeric.pyx
@@ -0,0 +1,715 @@
+"""
+Cython adapters for numeric types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+
+from libc.stdint cimport *
+from libc.string cimport memcpy, strlen
+from cpython.mem cimport PyMem_Free
+from cpython.dict cimport PyDict_GetItem, PyDict_SetItem
+from cpython.long cimport (
+ PyLong_FromString, PyLong_FromLong, PyLong_FromLongLong,
+ PyLong_FromUnsignedLong, PyLong_AsLongLong)
+from cpython.bytes cimport PyBytes_AsStringAndSize
+from cpython.float cimport PyFloat_FromDouble, PyFloat_AsDouble
+from cpython.unicode cimport PyUnicode_DecodeUTF8
+
+from decimal import Decimal, Context, DefaultContext
+
+from psycopg_c._psycopg cimport endian
+from psycopg import errors as e
+from psycopg._wrappers import Int2, Int4, Int8, IntNumeric
+
+cdef extern from "Python.h":
+ # work around https://github.com/cython/cython/issues/3909
+ double PyOS_string_to_double(
+ const char *s, char **endptr, PyObject *overflow_exception) except? -1.0
+ char *PyOS_double_to_string(
+ double val, char format_code, int precision, int flags, int *ptype
+ ) except NULL
+ int Py_DTSF_ADD_DOT_0
+ long long PyLong_AsLongLongAndOverflow(object pylong, int *overflow) except? -1
+
+ # Missing in cpython/unicode.pxd
+ const char *PyUnicode_AsUTF8(object unicode) except NULL
+
+
+# defined in numutils.c
+cdef extern from *:
+ """
+int pg_lltoa(int64_t value, char *a);
+#define MAXINT8LEN 20
+ """
+ int pg_lltoa(int64_t value, char *a)
+ const int MAXINT8LEN
+
+
+cdef class _NumberDumper(CDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ return dump_int_to_text(obj, rv, offset)
+
+ def quote(self, obj) -> bytearray:
+ cdef Py_ssize_t length
+
+ rv = PyByteArray_FromStringAndSize("", 0)
+ if obj >= 0:
+ length = self.cdump(obj, rv, 0)
+ else:
+ PyByteArray_Resize(rv, 23)
+ rv[0] = b' '
+ length = 1 + self.cdump(obj, rv, 1)
+
+ PyByteArray_Resize(rv, length)
+ return rv
+
+
+@cython.final
+cdef class Int2Dumper(_NumberDumper):
+
+ oid = oids.INT2_OID
+
+
+@cython.final
+cdef class Int4Dumper(_NumberDumper):
+
+ oid = oids.INT4_OID
+
+
+@cython.final
+cdef class Int8Dumper(_NumberDumper):
+
+ oid = oids.INT8_OID
+
+
+@cython.final
+cdef class IntNumericDumper(_NumberDumper):
+
+ oid = oids.NUMERIC_OID
+
+
+@cython.final
+cdef class Int2BinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.INT2_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int16_t *buf = <int16_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int16_t))
+ cdef int16_t val = <int16_t>PyLong_AsLongLong(obj)
+ # swap bytes if needed
+ cdef uint16_t *ptvar = <uint16_t *>(&val)
+ buf[0] = endian.htobe16(ptvar[0])
+ return sizeof(int16_t)
+
+
+@cython.final
+cdef class Int4BinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.INT4_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int32_t *buf = <int32_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int32_t))
+ cdef int32_t val = <int32_t>PyLong_AsLongLong(obj)
+ # swap bytes if needed
+ cdef uint32_t *ptvar = <uint32_t *>(&val)
+ buf[0] = endian.htobe32(ptvar[0])
+ return sizeof(int32_t)
+
+
+@cython.final
+cdef class Int8BinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.INT8_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int64_t *buf = <int64_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int64_t))
+ cdef int64_t val = PyLong_AsLongLong(obj)
+ # swap bytes if needed
+ cdef uint64_t *ptvar = <uint64_t *>(&val)
+ buf[0] = endian.htobe64(ptvar[0])
+ return sizeof(int64_t)
+
+
+cdef extern from *:
+ """
+/* Ratio between number of bits required to store a number and number of pg
+ * decimal digits required (log(2) / log(10_000)).
+ */
+#define BIT_PER_PGDIGIT 0.07525749891599529
+
+/* decimal digits per Postgres "digit" */
+#define DEC_DIGITS 4
+
+#define NUMERIC_POS 0x0000
+#define NUMERIC_NEG 0x4000
+#define NUMERIC_NAN 0xC000
+#define NUMERIC_PINF 0xD000
+#define NUMERIC_NINF 0xF000
+"""
+ const double BIT_PER_PGDIGIT
+ const int DEC_DIGITS
+ const int NUMERIC_POS
+ const int NUMERIC_NEG
+ const int NUMERIC_NAN
+ const int NUMERIC_PINF
+ const int NUMERIC_NINF
+
+
+@cython.final
+cdef class IntNumericBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ return dump_int_to_numeric_binary(obj, rv, offset)
+
+
+cdef class IntDumper(_NumberDumper):
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ raise TypeError(
+ f"{type(self).__name__} is a dispatcher to other dumpers:"
+ " dump() is not supposed to be called"
+ )
+
+ cpdef get_key(self, obj, format):
+ cdef long long val
+ cdef int overflow
+
+ val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+ if overflow:
+ return IntNumeric
+
+ if INT32_MIN <= obj <= INT32_MAX:
+ if INT16_MIN <= obj <= INT16_MAX:
+ return Int2
+ else:
+ return Int4
+ else:
+ if INT64_MIN <= obj <= INT64_MAX:
+ return Int8
+ else:
+ return IntNumeric
+
+ _int2_dumper = Int2Dumper
+ _int4_dumper = Int4Dumper
+ _int8_dumper = Int8Dumper
+ _int_numeric_dumper = IntNumericDumper
+
+ cpdef upgrade(self, obj, format):
+ cdef long long val
+ cdef int overflow
+
+ val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+ if overflow:
+ return self._int_numeric_dumper(IntNumeric)
+
+ if INT32_MIN <= obj <= INT32_MAX:
+ if INT16_MIN <= obj <= INT16_MAX:
+ return self._int2_dumper(Int2)
+ else:
+ return self._int4_dumper(Int4)
+ else:
+ if INT64_MIN <= obj <= INT64_MAX:
+ return self._int8_dumper(Int8)
+ else:
+ return self._int_numeric_dumper(IntNumeric)
+
+
+@cython.final
+cdef class IntBinaryDumper(IntDumper):
+
+ format = PQ_BINARY
+
+ _int2_dumper = Int2BinaryDumper
+ _int4_dumper = Int4BinaryDumper
+ _int8_dumper = Int8BinaryDumper
+ _int_numeric_dumper = IntNumericBinaryDumper
+
+
+@cython.final
+cdef class IntLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+ # if the number ends with a 0 we don't need a copy
+ if data[length] == b'\0':
+ return PyLong_FromString(data, NULL, 10)
+
+ # Otherwise we have to copy it aside
+ if length > MAXINT8LEN:
+ raise ValueError("string too big for an int")
+
+ cdef char[MAXINT8LEN + 1] buf
+ memcpy(buf, data, length)
+ buf[length] = 0
+ return PyLong_FromString(buf, NULL, 10)
+
+
+
+@cython.final
+cdef class Int2BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return PyLong_FromLong(<int16_t>endian.be16toh((<uint16_t *>data)[0]))
+
+
+@cython.final
+cdef class Int4BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return PyLong_FromLong(<int32_t>endian.be32toh((<uint32_t *>data)[0]))
+
+
+@cython.final
+cdef class Int8BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return PyLong_FromLongLong(<int64_t>endian.be64toh((<uint64_t *>data)[0]))
+
+
+@cython.final
+cdef class OidBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return PyLong_FromUnsignedLong(endian.be32toh((<uint32_t *>data)[0]))
+
+
+cdef class _FloatDumper(CDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef double d = PyFloat_AsDouble(obj)
+ cdef char *out = PyOS_double_to_string(
+ d, b'r', 0, Py_DTSF_ADD_DOT_0, NULL)
+ cdef Py_ssize_t length = strlen(out)
+ cdef char *tgt = CDumper.ensure_size(rv, offset, length)
+ memcpy(tgt, out, length)
+ PyMem_Free(out)
+ return length
+
+ def quote(self, obj) -> bytes:
+ value = bytes(self.dump(obj))
+ cdef PyObject *ptr = PyDict_GetItem(_special_float, value)
+ if ptr != NULL:
+ return <object>ptr
+
+ return value if obj >= 0 else b" " + value
+
+cdef dict _special_float = {
+ b"inf": b"'Infinity'::float8",
+ b"-inf": b"'-Infinity'::float8",
+ b"nan": b"'NaN'::float8",
+}
+
+
+@cython.final
+cdef class FloatDumper(_FloatDumper):
+
+ oid = oids.FLOAT8_OID
+
+
+@cython.final
+cdef class Float4Dumper(_FloatDumper):
+
+ oid = oids.FLOAT4_OID
+
+
+@cython.final
+cdef class FloatBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.FLOAT8_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef double d = PyFloat_AsDouble(obj)
+ cdef uint64_t *intptr = <uint64_t *>&d
+ cdef uint64_t *buf = <uint64_t *>CDumper.ensure_size(
+ rv, offset, sizeof(uint64_t))
+ buf[0] = endian.htobe64(intptr[0])
+ return sizeof(uint64_t)
+
+
+@cython.final
+cdef class Float4BinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.FLOAT4_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef float f = <float>PyFloat_AsDouble(obj)
+ cdef uint32_t *intptr = <uint32_t *>&f
+ cdef uint32_t *buf = <uint32_t *>CDumper.ensure_size(
+ rv, offset, sizeof(uint32_t))
+ buf[0] = endian.htobe32(intptr[0])
+ return sizeof(uint32_t)
+
+
+@cython.final
+cdef class FloatLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef char *endptr
+ cdef double d = PyOS_string_to_double(
+ data, &endptr, <PyObject *>OverflowError)
+ return PyFloat_FromDouble(d)
+
+
+@cython.final
+cdef class Float4BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef uint32_t asint = endian.be32toh((<uint32_t *>data)[0])
+ # avoid warning:
+ # dereferencing type-punned pointer will break strict-aliasing rules
+ cdef char *swp = <char *>&asint
+ return PyFloat_FromDouble((<float *>swp)[0])
+
+
+@cython.final
+cdef class Float8BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef uint64_t asint = endian.be64toh((<uint64_t *>data)[0])
+ cdef char *swp = <char *>&asint
+ return PyFloat_FromDouble((<double *>swp)[0])
+
+
+@cython.final
+cdef class DecimalDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ return dump_decimal_to_text(obj, rv, offset)
+
+ def quote(self, obj) -> bytes:
+ value = bytes(self.dump(obj))
+ cdef PyObject *ptr = PyDict_GetItem(_special_decimal, value)
+ if ptr != NULL:
+ return <object>ptr
+
+ return value if obj >= 0 else b" " + value
+
+cdef dict _special_decimal = {
+ b"Infinity": b"'Infinity'::numeric",
+ b"-Infinity": b"'-Infinity'::numeric",
+ b"NaN": b"'NaN'::numeric",
+}
+
+
+@cython.final
+cdef class NumericLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+ s = PyUnicode_DecodeUTF8(<char *>data, length, NULL)
+ return Decimal(s)
+
+
+cdef dict _decimal_special = {
+ NUMERIC_NAN: Decimal("NaN"),
+ NUMERIC_PINF: Decimal("Infinity"),
+ NUMERIC_NINF: Decimal("-Infinity"),
+}
+
+cdef dict _contexts = {}
+for _i in range(DefaultContext.prec):
+ _contexts[_i] = DefaultContext
+
+
+@cython.final
+cdef class NumericBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+
+ cdef uint16_t *data16 = <uint16_t *>data
+ cdef uint16_t ndigits = endian.be16toh(data16[0])
+ cdef int16_t weight = <int16_t>endian.be16toh(data16[1])
+ cdef uint16_t sign = endian.be16toh(data16[2])
+ cdef uint16_t dscale = endian.be16toh(data16[3])
+ cdef int shift
+ cdef int i
+ cdef PyObject *pctx
+ cdef object key
+
+ if sign == NUMERIC_POS or sign == NUMERIC_NEG:
+ if length != (4 + ndigits) * sizeof(uint16_t):
+ raise e.DataError("bad ndigits in numeric binary representation")
+
+ val = 0
+ for i in range(ndigits):
+ val *= 10_000
+ val += endian.be16toh(data16[i + 4])
+
+ shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
+
+ key = (weight + 2) * DEC_DIGITS + dscale
+ pctx = PyDict_GetItem(_contexts, key)
+ if pctx == NULL:
+ ctx = Context(prec=key)
+ PyDict_SetItem(_contexts, key, ctx)
+ pctx = <PyObject *>ctx
+
+ return (
+ Decimal(val if sign == NUMERIC_POS else -val)
+ .scaleb(-dscale, <object>pctx)
+ .shift(shift, <object>pctx)
+ )
+ else:
+ try:
+ return _decimal_special[sign]
+ except KeyError:
+ raise e.DataError(f"bad value for numeric sign: 0x{sign:X}")
+
+
+@cython.final
+cdef class DecimalBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ return dump_decimal_to_numeric_binary(obj, rv, offset)
+
+
+@cython.final
+cdef class NumericDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ if isinstance(obj, int):
+ return dump_int_to_text(obj, rv, offset)
+ else:
+ return dump_decimal_to_text(obj, rv, offset)
+
+
+@cython.final
+cdef class NumericBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ if isinstance(obj, int):
+ return dump_int_to_numeric_binary(obj, rv, offset)
+ else:
+ return dump_decimal_to_numeric_binary(obj, rv, offset)
+
+
+cdef Py_ssize_t dump_decimal_to_text(obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef char *src
+ cdef Py_ssize_t length
+ cdef char *buf
+
+ b = bytes(str(obj), "utf-8")
+ PyBytes_AsStringAndSize(b, &src, &length)
+
+ if src[0] != b's':
+ buf = CDumper.ensure_size(rv, offset, length)
+ memcpy(buf, src, length)
+
+ else: # convert sNaN to NaN
+ length = 3 # NaN
+ buf = CDumper.ensure_size(rv, offset, length)
+ memcpy(buf, b"NaN", length)
+
+ return length
+
+
+cdef extern from *:
+ """
+/* Weights of py digits into a pg digit according to their positions. */
+static const int pydigit_weights[] = {1000, 100, 10, 1};
+"""
+ const int[4] pydigit_weights
+
+
+@cython.cdivision(True)
+cdef Py_ssize_t dump_decimal_to_numeric_binary(
+ obj, bytearray rv, Py_ssize_t offset
+) except -1:
+
+ # TODO: this implementation is about 30% slower than the text dump.
+ # This might be probably optimised by accessing the C structure of
+ # the Decimal object, if available, which would save the creation of
+ # several intermediate Python objects (the DecimalTuple, the digits
+ # tuple, and then accessing them).
+
+ cdef object t = obj.as_tuple()
+ cdef int sign = t[0]
+ cdef tuple digits = t[1]
+ cdef uint16_t *buf
+ cdef Py_ssize_t length
+
+ cdef object pyexp = t[2]
+ cdef const char *bexp
+ if not isinstance(pyexp, int):
+ # Handle inf, nan
+ length = 4 * sizeof(uint16_t)
+ buf = <uint16_t *>CDumper.ensure_size(rv, offset, length)
+ buf[0] = 0
+ buf[1] = 0
+ buf[3] = 0
+ bexp = PyUnicode_AsUTF8(pyexp)
+ if bexp[0] == b'n' or bexp[0] == b'N':
+ buf[2] = endian.htobe16(NUMERIC_NAN)
+ elif bexp[0] == b'F':
+ if sign:
+ buf[2] = endian.htobe16(NUMERIC_NINF)
+ else:
+ buf[2] = endian.htobe16(NUMERIC_PINF)
+ else:
+ raise e.DataError(f"unexpected decimal exponent: {pyexp}")
+ return length
+
+ cdef int exp = pyexp
+ cdef uint16_t ndigits = <uint16_t>len(digits)
+
+ # Find the last nonzero digit
+ cdef int nzdigits = ndigits
+ while nzdigits > 0 and digits[nzdigits - 1] == 0:
+ nzdigits -= 1
+
+ cdef uint16_t dscale
+ if exp <= 0:
+ dscale = -exp
+ else:
+ dscale = 0
+ # align the py digits to the pg digits if there's some py exponent
+ ndigits += exp % DEC_DIGITS
+
+ if nzdigits == 0:
+ length = 4 * sizeof(uint16_t)
+ buf = <uint16_t *>CDumper.ensure_size(rv, offset, length)
+ buf[0] = 0 # ndigits
+ buf[1] = 0 # weight
+ buf[2] = endian.htobe16(NUMERIC_POS) # sign
+ buf[3] = endian.htobe16(dscale)
+ return length
+
+ # Equivalent of 0-padding left to align the py digits to the pg digits
+ # but without changing the digits tuple.
+ cdef int wi = 0
+ cdef int mod = (ndigits - dscale) % DEC_DIGITS
+ if mod < 0:
+ # the difference between C and Py % operator
+ mod += 4
+ if mod:
+ wi = DEC_DIGITS - mod
+ ndigits += wi
+
+ cdef int tmp = nzdigits + wi
+ cdef int pgdigits = tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1)
+ length = (pgdigits + 4) * sizeof(uint16_t)
+ buf = <uint16_t*>CDumper.ensure_size(rv, offset, length)
+ buf[0] = endian.htobe16(pgdigits)
+ buf[1] = endian.htobe16(<int16_t>((ndigits + exp) // DEC_DIGITS - 1))
+ buf[2] = endian.htobe16(NUMERIC_NEG) if sign else endian.htobe16(NUMERIC_POS)
+ buf[3] = endian.htobe16(dscale)
+
+ cdef uint16_t pgdigit = 0
+ cdef int bi = 4
+ for i in range(nzdigits):
+ pgdigit += pydigit_weights[wi] * <int>(digits[i])
+ wi += 1
+ if wi >= DEC_DIGITS:
+ buf[bi] = endian.htobe16(pgdigit)
+ pgdigit = wi = 0
+ bi += 1
+
+ if pgdigit:
+ buf[bi] = endian.htobe16(pgdigit)
+
+ return length
+
+
+cdef Py_ssize_t dump_int_to_text(obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef long long val
+ cdef int overflow
+ cdef char *buf
+ cdef char *src
+ cdef Py_ssize_t length
+
+ # Ensure an int or a subclass. The 'is' type check is fast.
+ # Passing a float must give an error, but passing an Enum should work.
+ if type(obj) is not int and not isinstance(obj, int):
+ raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
+
+ val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+ if not overflow:
+ buf = CDumper.ensure_size(rv, offset, MAXINT8LEN + 1)
+ length = pg_lltoa(val, buf)
+ else:
+ b = bytes(str(obj), "utf-8")
+ PyBytes_AsStringAndSize(b, &src, &length)
+ buf = CDumper.ensure_size(rv, offset, length)
+ memcpy(buf, src, length)
+
+ return length
+
+
+cdef Py_ssize_t dump_int_to_numeric_binary(obj, bytearray rv, Py_ssize_t offset) except -1:
+ # Calculate the number of PG digits required to store the number
+ cdef uint16_t ndigits
+ ndigits = <uint16_t>((<int>obj.bit_length()) * BIT_PER_PGDIGIT) + 1
+
+ cdef uint16_t sign = NUMERIC_POS
+ if obj < 0:
+ sign = NUMERIC_NEG
+ obj = -obj
+
+ cdef Py_ssize_t length = sizeof(uint16_t) * (ndigits + 4)
+ cdef uint16_t *buf
+ buf = <uint16_t *><void *>CDumper.ensure_size(rv, offset, length)
+ buf[0] = endian.htobe16(ndigits)
+ buf[1] = endian.htobe16(ndigits - 1) # weight
+ buf[2] = endian.htobe16(sign)
+ buf[3] = 0 # dscale
+
+ cdef int i = 4 + ndigits - 1
+ cdef uint16_t rem
+ while obj:
+ rem = obj % 10000
+ obj //= 10000
+ buf[i] = endian.htobe16(rem)
+ i -= 1
+ while i > 3:
+ buf[i] = 0
+ i -= 1
+
+ return length
diff --git a/psycopg_c/psycopg_c/types/numutils.c b/psycopg_c/psycopg_c/types/numutils.c
new file mode 100644
index 0000000..4be7108
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/numutils.c
@@ -0,0 +1,243 @@
+/*
+ * Utilities to deal with numbers.
+ *
+ * Copyright (C) 2020 The Psycopg Team
+ * Portions Copyright (c) 1996-2020, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ */
+
+#include <stdint.h>
+#include <string.h>
+
+#include "pg_config.h"
+
+
+/*
+ * 64-bit integers
+ */
+#ifdef HAVE_LONG_INT_64
+/* Plain "long int" fits, use it */
+
+# ifndef HAVE_INT64
+typedef long int int64;
+# endif
+# ifndef HAVE_UINT64
+typedef unsigned long int uint64;
+# endif
+# define INT64CONST(x) (x##L)
+# define UINT64CONST(x) (x##UL)
+#elif defined(HAVE_LONG_LONG_INT_64)
+/* We have working support for "long long int", use that */
+
+# ifndef HAVE_INT64
+typedef long long int int64;
+# endif
+# ifndef HAVE_UINT64
+typedef unsigned long long int uint64;
+# endif
+# define INT64CONST(x) (x##LL)
+# define UINT64CONST(x) (x##ULL)
+#else
+/* neither HAVE_LONG_INT_64 nor HAVE_LONG_LONG_INT_64 */
+# error must have a working 64-bit integer datatype
+#endif
+
+
+#ifndef HAVE__BUILTIN_CLZ
+static const uint8_t pg_leftmost_one_pos[256] = {
+ 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
+ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
+};
+#endif
+
+static const char DIGIT_TABLE[200] = {
+ '0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0',
+ '7', '0', '8', '0', '9', '1', '0', '1', '1', '1', '2', '1', '3', '1', '4',
+ '1', '5', '1', '6', '1', '7', '1', '8', '1', '9', '2', '0', '2', '1', '2',
+ '2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9',
+ '3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3',
+ '7', '3', '8', '3', '9', '4', '0', '4', '1', '4', '2', '4', '3', '4', '4',
+ '4', '5', '4', '6', '4', '7', '4', '8', '4', '9', '5', '0', '5', '1', '5',
+ '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9',
+ '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6',
+ '7', '6', '8', '6', '9', '7', '0', '7', '1', '7', '2', '7', '3', '7', '4',
+ '7', '5', '7', '6', '7', '7', '7', '8', '7', '9', '8', '0', '8', '1', '8',
+ '2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9',
+ '9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9',
+ '7', '9', '8', '9', '9'
+};
+
+
+/*
+ * pg_leftmost_one_pos64
+ * As above, but for a 64-bit word.
+ */
+static inline int
+pg_leftmost_one_pos64(uint64_t word)
+{
+#ifdef HAVE__BUILTIN_CLZ
+#if defined(HAVE_LONG_INT_64)
+ return 63 - __builtin_clzl(word);
+#elif defined(HAVE_LONG_LONG_INT_64)
+ return 63 - __builtin_clzll(word);
+#else
+#error must have a working 64-bit integer datatype
+#endif
+#else /* !HAVE__BUILTIN_CLZ */
+ int shift = 64 - 8;
+
+ while ((word >> shift) == 0)
+ shift -= 8;
+
+ return shift + pg_leftmost_one_pos[(word >> shift) & 255];
+#endif /* HAVE__BUILTIN_CLZ */
+}
+
+
+static inline int
+decimalLength64(const uint64_t v)
+{
+ int t;
+ static const uint64_t PowersOfTen[] = {
+ UINT64CONST(1), UINT64CONST(10),
+ UINT64CONST(100), UINT64CONST(1000),
+ UINT64CONST(10000), UINT64CONST(100000),
+ UINT64CONST(1000000), UINT64CONST(10000000),
+ UINT64CONST(100000000), UINT64CONST(1000000000),
+ UINT64CONST(10000000000), UINT64CONST(100000000000),
+ UINT64CONST(1000000000000), UINT64CONST(10000000000000),
+ UINT64CONST(100000000000000), UINT64CONST(1000000000000000),
+ UINT64CONST(10000000000000000), UINT64CONST(100000000000000000),
+ UINT64CONST(1000000000000000000), UINT64CONST(10000000000000000000)
+ };
+
+ /*
+ * Compute base-10 logarithm by dividing the base-2 logarithm by a
+ * good-enough approximation of the base-2 logarithm of 10
+ */
+ t = (pg_leftmost_one_pos64(v) + 1) * 1233 / 4096;
+ return t + (v >= PowersOfTen[t]);
+}
+
+
+/*
+ * Get the decimal representation, not NUL-terminated, and return the length of
+ * same. Caller must ensure that a points to at least MAXINT8LEN bytes.
+ */
+int
+pg_ulltoa_n(uint64_t value, char *a)
+{
+ int olength,
+ i = 0;
+ uint32_t value2;
+
+ /* Degenerate case */
+ if (value == 0)
+ {
+ *a = '0';
+ return 1;
+ }
+
+ olength = decimalLength64(value);
+
+ /* Compute the result string. */
+ while (value >= 100000000)
+ {
+ const uint64_t q = value / 100000000;
+ uint32_t value2 = (uint32_t) (value - 100000000 * q);
+
+ const uint32_t c = value2 % 10000;
+ const uint32_t d = value2 / 10000;
+ const uint32_t c0 = (c % 100) << 1;
+ const uint32_t c1 = (c / 100) << 1;
+ const uint32_t d0 = (d % 100) << 1;
+ const uint32_t d1 = (d / 100) << 1;
+
+ char *pos = a + olength - i;
+
+ value = q;
+
+ memcpy(pos - 2, DIGIT_TABLE + c0, 2);
+ memcpy(pos - 4, DIGIT_TABLE + c1, 2);
+ memcpy(pos - 6, DIGIT_TABLE + d0, 2);
+ memcpy(pos - 8, DIGIT_TABLE + d1, 2);
+ i += 8;
+ }
+
+ /* Switch to 32-bit for speed */
+ value2 = (uint32_t) value;
+
+ if (value2 >= 10000)
+ {
+ const uint32_t c = value2 - 10000 * (value2 / 10000);
+ const uint32_t c0 = (c % 100) << 1;
+ const uint32_t c1 = (c / 100) << 1;
+
+ char *pos = a + olength - i;
+
+ value2 /= 10000;
+
+ memcpy(pos - 2, DIGIT_TABLE + c0, 2);
+ memcpy(pos - 4, DIGIT_TABLE + c1, 2);
+ i += 4;
+ }
+ if (value2 >= 100)
+ {
+ const uint32_t c = (value2 % 100) << 1;
+ char *pos = a + olength - i;
+
+ value2 /= 100;
+
+ memcpy(pos - 2, DIGIT_TABLE + c, 2);
+ i += 2;
+ }
+ if (value2 >= 10)
+ {
+ const uint32_t c = value2 << 1;
+ char *pos = a + olength - i;
+
+ memcpy(pos - 2, DIGIT_TABLE + c, 2);
+ }
+ else
+ *a = (char) ('0' + value2);
+
+ return olength;
+}
+
+/*
+ * pg_lltoa: converts a signed 64-bit integer to its string representation and
+ * returns strlen(a).
+ *
+ * Caller must ensure that 'a' points to enough memory to hold the result
+ * (at least MAXINT8LEN + 1 bytes, counting a leading sign and trailing NUL).
+ */
+int
+pg_lltoa(int64_t value, char *a)
+{
+ uint64_t uvalue = value;
+ int len = 0;
+
+ if (value < 0)
+ {
+ uvalue = (uint64_t) 0 - uvalue;
+ a[len++] = '-';
+ }
+
+ len += pg_ulltoa_n(uvalue, a + len);
+ a[len] = '\0';
+ return len;
+}
diff --git a/psycopg_c/psycopg_c/types/string.pyx b/psycopg_c/psycopg_c/types/string.pyx
new file mode 100644
index 0000000..da18b01
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/string.pyx
@@ -0,0 +1,315 @@
+"""
+Cython adapters for textual types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+
+from libc.string cimport memcpy, memchr
+from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize
+from cpython.unicode cimport (
+ PyUnicode_AsEncodedString,
+ PyUnicode_AsUTF8String,
+ PyUnicode_CheckExact,
+ PyUnicode_Decode,
+ PyUnicode_DecodeUTF8,
+)
+
+from psycopg_c.pq cimport libpq, Escaping, _buffer_as_string_and_size
+
+from psycopg import errors as e
+from psycopg._encodings import pg2pyenc
+
+cdef extern from "Python.h":
+ const char *PyUnicode_AsUTF8AndSize(unicode obj, Py_ssize_t *size) except NULL
+
+
+cdef class _BaseStrDumper(CDumper):
+ cdef int is_utf8
+ cdef char *encoding
+ cdef bytes _bytes_encoding # needed to keep `encoding` alive
+
+ def __cinit__(self, cls, context: Optional[AdaptContext] = None):
+
+ self.is_utf8 = 0
+ self.encoding = "utf-8"
+ cdef const char *pgenc
+
+ if self._pgconn is not None:
+ pgenc = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"client_encoding")
+ if pgenc == NULL or pgenc == b"UTF8":
+ self._bytes_encoding = b"utf-8"
+ self.is_utf8 = 1
+ else:
+ self._bytes_encoding = pg2pyenc(pgenc).encode()
+ if self._bytes_encoding == b"ascii":
+ self.is_utf8 = 1
+ self.encoding = PyBytes_AsString(self._bytes_encoding)
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ # the server will raise DataError subclass if the string contains 0x00
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ if self.is_utf8:
+ # Probably the fastest path, but doesn't work with subclasses
+ if PyUnicode_CheckExact(obj):
+ src = PyUnicode_AsUTF8AndSize(obj, &size)
+ else:
+ b = PyUnicode_AsUTF8String(obj)
+ PyBytes_AsStringAndSize(b, <char **>&src, &size)
+ else:
+ b = PyUnicode_AsEncodedString(obj, self.encoding, NULL)
+ PyBytes_AsStringAndSize(b, <char **>&src, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+cdef class _StrBinaryDumper(_BaseStrDumper):
+
+ format = PQ_BINARY
+
+
+@cython.final
+cdef class StrBinaryDumper(_StrBinaryDumper):
+
+ oid = oids.TEXT_OID
+
+
+@cython.final
+cdef class StrBinaryDumperVarchar(_StrBinaryDumper):
+
+ oid = oids.VARCHAR_OID
+
+
+@cython.final
+cdef class StrBinaryDumperName(_StrBinaryDumper):
+
+ oid = oids.NAME_OID
+
+
+cdef class _StrDumper(_BaseStrDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size = StrBinaryDumper.cdump(self, obj, rv, offset)
+
+ # Like the binary dump, but check for 0, or the string will be truncated
+ cdef const char *buf = PyByteArray_AS_STRING(rv)
+ if NULL != memchr(buf + offset, 0x00, size):
+ raise e.DataError(
+ "PostgreSQL text fields cannot contain NUL (0x00) bytes"
+ )
+ return size
+
+
+@cython.final
+cdef class StrDumper(_StrDumper):
+
+ oid = oids.TEXT_OID
+
+
+@cython.final
+cdef class StrDumperVarchar(_StrDumper):
+
+ oid = oids.VARCHAR_OID
+
+
+@cython.final
+cdef class StrDumperName(_StrDumper):
+
+ oid = oids.NAME_OID
+
+
+@cython.final
+cdef class StrDumperUnknown(_StrDumper):
+ pass
+
+
+cdef class _TextLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef int is_utf8
+ cdef char *encoding
+ cdef bytes _bytes_encoding # needed to keep `encoding` alive
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ self.is_utf8 = 0
+ self.encoding = "utf-8"
+ cdef const char *pgenc
+
+ if self._pgconn is not None:
+ pgenc = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"client_encoding")
+ if pgenc == NULL or pgenc == b"UTF8":
+ self._bytes_encoding = b"utf-8"
+ self.is_utf8 = 1
+ else:
+ self._bytes_encoding = pg2pyenc(pgenc).encode()
+
+ if pgenc == b"SQL_ASCII":
+ self.encoding = NULL
+ else:
+ self.encoding = PyBytes_AsString(self._bytes_encoding)
+
+ cdef object cload(self, const char *data, size_t length):
+ if self.is_utf8:
+ return PyUnicode_DecodeUTF8(<char *>data, length, NULL)
+ elif self.encoding:
+ return PyUnicode_Decode(<char *>data, length, self.encoding, NULL)
+ else:
+ return data[:length]
+
+@cython.final
+cdef class TextLoader(_TextLoader):
+
+ format = PQ_TEXT
+
+
+@cython.final
+cdef class TextBinaryLoader(_TextLoader):
+
+ format = PQ_BINARY
+
+
+@cython.final
+cdef class BytesDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.BYTEA_OID
+
+ # 0: not set, 1: just single "'" quote, 3: " E'" quote
+ cdef int _qplen
+
+ def __cinit__(self):
+ self._qplen = 0
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+
+ cdef size_t len_out
+ cdef unsigned char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ _buffer_as_string_and_size(obj, &ptr, &length)
+
+ if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL:
+ out = libpq.PQescapeByteaConn(
+ self._pgconn._pgconn_ptr, <unsigned char *>ptr, length, &len_out)
+ else:
+ out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_out)
+
+ if out is NULL:
+ raise MemoryError(
+ f"couldn't allocate for escape_bytea of {length} bytes"
+ )
+
+ len_out -= 1 # out includes final 0
+ cdef char *buf = CDumper.ensure_size(rv, offset, len_out)
+ memcpy(buf, out, len_out)
+ libpq.PQfreemem(out)
+ return len_out
+
+ def quote(self, obj):
+ cdef size_t len_out
+ cdef unsigned char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+ cdef const char *scs
+
+ escaped = self.dump(obj)
+ _buffer_as_string_and_size(escaped, &ptr, &length)
+
+ rv = PyByteArray_FromStringAndSize("", 0)
+
+ # We cannot use the base quoting because escape_bytea already returns
+ # the quotes content. if scs is off it will escape the backslashes in
+ # the format, otherwise it won't, but it doesn't tell us what quotes to
+ # use.
+ if self._pgconn is not None:
+ if not self._qplen:
+ scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr,
+ b"standard_conforming_strings")
+ if scs and scs[0] == b'o' and scs[1] == b"n": # == "on"
+ self._qplen = 1
+ else:
+ self._qplen = 3
+
+ PyByteArray_Resize(rv, length + self._qplen + 1) # Include quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ if self._qplen == 1:
+ ptr_out[0] = b"'"
+ else:
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + self._qplen, ptr, length)
+ ptr_out[length + self._qplen] = b"'"
+ return rv
+
+ # We don't have a connection, so someone is using us to generate a file
+ # to use off-line or something like that. PQescapeBytea, like its
+ # string counterpart, is not predictable whether it will escape
+ # backslashes.
+ PyByteArray_Resize(rv, length + 4) # Include quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + 3, ptr, length)
+ ptr_out[length + 3] = b"'"
+
+ esc = Escaping()
+ if esc.escape_bytea(b"\x00") == b"\\000":
+ rv = bytes(rv).replace(b"\\", b"\\\\")
+
+ return rv
+
+
+@cython.final
+cdef class BytesBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.BYTEA_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef char *src
+ cdef Py_ssize_t size;
+ _buffer_as_string_and_size(obj, &src, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class ByteaLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef size_t len_out
+ cdef unsigned char *out = libpq.PQunescapeBytea(
+ <const unsigned char *>data, &len_out)
+ if out is NULL:
+ raise MemoryError(
+ f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+ )
+
+ rv = out[:len_out]
+ libpq.PQfreemem(out)
+ return rv
+
+
+@cython.final
+cdef class ByteaBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return data[:length]
diff --git a/psycopg_c/psycopg_c/version.py b/psycopg_c/psycopg_c/version.py
new file mode 100644
index 0000000..5c989c2
--- /dev/null
+++ b/psycopg_c/psycopg_c/version.py
@@ -0,0 +1,11 @@
+"""
+psycopg-c distribution version file.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Use a versioning scheme as defined in
+# https://www.python.org/dev/peps/pep-0440/
+__version__ = "3.1.7"
+
+# also change psycopg/psycopg/version.py accordingly.
diff --git a/psycopg_c/pyproject.toml b/psycopg_c/pyproject.toml
new file mode 100644
index 0000000..f0d7a3f
--- /dev/null
+++ b/psycopg_c/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=49.2.0", "wheel>=0.37", "Cython>=3.0.0a11"]
+build-backend = "setuptools.build_meta"
diff --git a/psycopg_c/setup.cfg b/psycopg_c/setup.cfg
new file mode 100644
index 0000000..6c5c93c
--- /dev/null
+++ b/psycopg_c/setup.cfg
@@ -0,0 +1,57 @@
+[metadata]
+name = psycopg-c
+description = PostgreSQL database adapter for Python -- C optimisation distribution
+url = https://psycopg.org/psycopg3/
+author = Daniele Varrazzo
+author_email = daniele.varrazzo@gmail.com
+license = GNU Lesser General Public License v3 (LGPLv3)
+
+project_urls =
+ Homepage = https://psycopg.org/
+ Code = https://github.com/psycopg/psycopg
+ Issue Tracker = https://github.com/psycopg/psycopg/issues
+ Download = https://pypi.org/project/psycopg-c/
+
+classifiers =
+ Development Status :: 5 - Production/Stable
+ Intended Audience :: Developers
+ License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Operating System :: MacOS :: MacOS X
+ Operating System :: Microsoft :: Windows
+ Operating System :: POSIX
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Topic :: Database
+ Topic :: Database :: Front-Ends
+ Topic :: Software Development
+ Topic :: Software Development :: Libraries :: Python Modules
+
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+license_files = LICENSE.txt
+
+[options]
+python_requires = >= 3.7
+setup_requires = Cython >= 3.0.0a11
+packages = find:
+zip_safe = False
+
+[options.package_data]
+# NOTE: do not include .pyx files: they shouldn't be in the sdist
+# package, so that build is only performed from the .c files (which are
+# distributed instead).
+psycopg_c =
+ py.typed
+ *.pyi
+ *.pxd
+ _psycopg/*.pxd
+ pq/*.pxd
+
+# In the psycopg-binary distribution don't include cython-related files.
+psycopg_binary =
+ py.typed
+ *.pyi
diff --git a/psycopg_c/setup.py b/psycopg_c/setup.py
new file mode 100644
index 0000000..c6da3a1
--- /dev/null
+++ b/psycopg_c/setup.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+"""
+PostgreSQL database adapter for Python - optimisation package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+import sys
+import subprocess as sp
+
+from setuptools import setup, Extension
+from distutils.command.build_ext import build_ext
+from distutils import log
+
+# Move to the directory of setup.py: executing this file from another location
+# (e.g. from the project root) will fail
+here = os.path.abspath(os.path.dirname(__file__))
+if os.path.abspath(os.getcwd()) != here:
+ os.chdir(here)
+
+with open("psycopg_c/version.py") as f:
+ data = f.read()
+ m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data)
+ if m is None:
+ raise Exception(f"cannot find version in {f.name}")
+ version = m.group(1)
+
+
+def get_config(what: str) -> str:
+ pg_config = "pg_config"
+ try:
+ out = sp.run([pg_config, f"--{what}"], stdout=sp.PIPE, check=True)
+ except Exception as e:
+ log.error(f"couldn't run {pg_config!r} --{what}: %s", e)
+ raise
+ else:
+ return out.stdout.strip().decode()
+
+
+class psycopg_build_ext(build_ext):
+ def finalize_options(self) -> None:
+ self._setup_ext_build()
+ super().finalize_options()
+
+ def _setup_ext_build(self) -> None:
+ cythonize = None
+
+ # In the sdist there are not .pyx, only c, so we don't need Cython.
+ # Otherwise Cython is a requirement and it is used to compile pyx to c.
+ if os.path.exists("psycopg_c/_psycopg.pyx"):
+ from Cython.Build import cythonize
+
+ # Add include and lib dir for the libpq.
+ includedir = get_config("includedir")
+ libdir = get_config("libdir")
+ for ext in self.distribution.ext_modules:
+ ext.include_dirs.append(includedir)
+ ext.library_dirs.append(libdir)
+
+ if sys.platform == "win32":
+ # For __imp_htons and others
+ ext.libraries.append("ws2_32")
+
+ if cythonize is not None:
+ for ext in self.distribution.ext_modules:
+ for i in range(len(ext.sources)):
+ base, fext = os.path.splitext(ext.sources[i])
+ if fext == ".c" and os.path.exists(base + ".pyx"):
+ ext.sources[i] = base + ".pyx"
+
+ self.distribution.ext_modules = cythonize(
+ self.distribution.ext_modules,
+ language_level=3,
+ compiler_directives={
+ "always_allow_keywords": False,
+ },
+ annotate=False, # enable to get an html view of the C module
+ )
+ else:
+ self.distribution.ext_modules = [pgext, pqext]
+
+
+# MSVC requires an explicit "libpq"
+libpq = "pq" if sys.platform != "win32" else "libpq"
+
+# Some details missing, to be finished by psycopg_build_ext.finalize_options
+pgext = Extension(
+ "psycopg_c._psycopg",
+ [
+ "psycopg_c/_psycopg.c",
+ "psycopg_c/types/numutils.c",
+ ],
+ libraries=[libpq],
+ include_dirs=[],
+)
+
+pqext = Extension(
+ "psycopg_c.pq",
+ ["psycopg_c/pq.c"],
+ libraries=[libpq],
+ include_dirs=[],
+)
+
+setup(
+ version=version,
+ ext_modules=[pgext, pqext],
+ cmdclass={"build_ext": psycopg_build_ext},
+)
diff --git a/psycopg_pool/.flake8 b/psycopg_pool/.flake8
new file mode 100644
index 0000000..2ae629c
--- /dev/null
+++ b/psycopg_pool/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+ignore = W503, E203
diff --git a/psycopg_pool/LICENSE.txt b/psycopg_pool/LICENSE.txt
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/psycopg_pool/LICENSE.txt
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/psycopg_pool/README.rst b/psycopg_pool/README.rst
new file mode 100644
index 0000000..6e6b32c
--- /dev/null
+++ b/psycopg_pool/README.rst
@@ -0,0 +1,24 @@
+Psycopg 3: PostgreSQL database adapter for Python - Connection Pool
+===================================================================
+
+This distribution contains the optional connection pool package
+`psycopg_pool`__.
+
+.. __: https://www.psycopg.org/psycopg3/docs/advanced/pool.html
+
+This package is kept separate from the main ``psycopg`` package because it is
+likely that it will follow a different release cycle.
+
+You can also install this package using::
+
+ pip install "psycopg[pool]"
+
+Please read `the project readme`__ and `the installation documentation`__ for
+more details.
+
+.. __: https://github.com/psycopg/psycopg#readme
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+ #installing-the-connection-pool
+
+
+Copyright (C) 2020 The Psycopg Team
diff --git a/psycopg_pool/psycopg_pool/__init__.py b/psycopg_pool/psycopg_pool/__init__.py
new file mode 100644
index 0000000..e4d975f
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/__init__.py
@@ -0,0 +1,22 @@
+"""
+psycopg connection pool package
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from .pool import ConnectionPool
+from .pool_async import AsyncConnectionPool
+from .null_pool import NullConnectionPool
+from .null_pool_async import AsyncNullConnectionPool
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
+from .version import __version__ as __version__ # noqa: F401
+
+__all__ = [
+ "AsyncConnectionPool",
+ "AsyncNullConnectionPool",
+ "ConnectionPool",
+ "NullConnectionPool",
+ "PoolClosed",
+ "PoolTimeout",
+ "TooManyRequests",
+]
diff --git a/psycopg_pool/psycopg_pool/_compat.py b/psycopg_pool/psycopg_pool/_compat.py
new file mode 100644
index 0000000..9fb2b9b
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/_compat.py
@@ -0,0 +1,51 @@
+"""
+compatibility functions for different Python versions
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import sys
+import asyncio
+from typing import Any, Awaitable, Generator, Optional, Union, Type, TypeVar
+from typing_extensions import TypeAlias
+
+import psycopg.errors as e
+
+T = TypeVar("T")
+FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
+
+if sys.version_info >= (3, 8):
+ create_task = asyncio.create_task
+ Task = asyncio.Task
+
+else:
+
+ def create_task(
+ coro: FutureT[T], name: Optional[str] = None
+ ) -> "asyncio.Future[T]":
+ return asyncio.create_task(coro)
+
+ Task = asyncio.Future
+
+if sys.version_info >= (3, 9):
+ from collections import Counter, deque as Deque
+else:
+ from typing import Counter, Deque
+
+__all__ = [
+ "Counter",
+ "Deque",
+ "Task",
+ "create_task",
+]
+
+# Workaround for psycopg < 3.0.8.
+# Timeout on NullPool connection mignt not work correctly.
+try:
+ ConnectionTimeout: Type[e.OperationalError] = e.ConnectionTimeout
+except AttributeError:
+
+ class DummyConnectionTimeout(e.OperationalError):
+ pass
+
+ ConnectionTimeout = DummyConnectionTimeout
diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py
new file mode 100644
index 0000000..298ea68
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/base.py
@@ -0,0 +1,230 @@
+"""
+psycopg connection pool base class and functionalities.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from time import monotonic
+from random import random
+from typing import Any, Callable, Dict, Generic, Optional, Tuple
+
+from psycopg import errors as e
+from psycopg.abc import ConnectionType
+
+from .errors import PoolClosed
+from ._compat import Counter, Deque
+
+
+class BasePool(Generic[ConnectionType]):
+
+ # Used to generate pool names
+ _num_pool = 0
+
+ # Stats keys
+ _POOL_MIN = "pool_min"
+ _POOL_MAX = "pool_max"
+ _POOL_SIZE = "pool_size"
+ _POOL_AVAILABLE = "pool_available"
+ _REQUESTS_WAITING = "requests_waiting"
+ _REQUESTS_NUM = "requests_num"
+ _REQUESTS_QUEUED = "requests_queued"
+ _REQUESTS_WAIT_MS = "requests_wait_ms"
+ _REQUESTS_ERRORS = "requests_errors"
+ _USAGE_MS = "usage_ms"
+ _RETURNS_BAD = "returns_bad"
+ _CONNECTIONS_NUM = "connections_num"
+ _CONNECTIONS_MS = "connections_ms"
+ _CONNECTIONS_ERRORS = "connections_errors"
+ _CONNECTIONS_LOST = "connections_lost"
+
+ def __init__(
+ self,
+ conninfo: str = "",
+ *,
+ kwargs: Optional[Dict[str, Any]] = None,
+ min_size: int = 4,
+ max_size: Optional[int] = None,
+ open: bool = True,
+ name: Optional[str] = None,
+ timeout: float = 30.0,
+ max_waiting: int = 0,
+ max_lifetime: float = 60 * 60.0,
+ max_idle: float = 10 * 60.0,
+ reconnect_timeout: float = 5 * 60.0,
+ reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]] = None,
+ num_workers: int = 3,
+ ):
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ if not name:
+ num = BasePool._num_pool = BasePool._num_pool + 1
+ name = f"pool-{num}"
+
+ if num_workers < 1:
+ raise ValueError("num_workers must be at least 1")
+
+ self.conninfo = conninfo
+ self.kwargs: Dict[str, Any] = kwargs or {}
+ self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None]
+ self._reconnect_failed = reconnect_failed or (lambda pool: None)
+ self.name = name
+ self._min_size = min_size
+ self._max_size = max_size
+ self.timeout = timeout
+ self.max_waiting = max_waiting
+ self.reconnect_timeout = reconnect_timeout
+ self.max_lifetime = max_lifetime
+ self.max_idle = max_idle
+ self.num_workers = num_workers
+
+ self._nconns = min_size # currently in the pool, out, being prepared
+ self._pool = Deque[ConnectionType]()
+ self._stats = Counter[str]()
+
+ # Min number of connections in the pool in a max_idle unit of time.
+ # It is reset periodically by the ShrinkPool scheduled task.
+ # It is used to shrink back the pool if maxcon > min_size and extra
+ # connections have been acquired, if we notice that in the last
+ # max_idle interval they weren't all used.
+ self._nconns_min = min_size
+
+ # Flag to allow the pool to grow only one connection at time. In case
+ # of spike, if threads are allowed to grow in parallel and connection
+ # time is slow, there won't be any thread available to return the
+ # connections to the pool.
+ self._growing = False
+
+ self._opened = False
+ self._closed = True
+
+ def __repr__(self) -> str:
+ return (
+ f"<{self.__class__.__module__}.{self.__class__.__name__}"
+ f" {self.name!r} at 0x{id(self):x}>"
+ )
+
+ @property
+ def min_size(self) -> int:
+ return self._min_size
+
+ @property
+ def max_size(self) -> int:
+ return self._max_size
+
+ @property
+ def closed(self) -> bool:
+ """`!True` if the pool is closed."""
+ return self._closed
+
+ def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
+ if max_size is None:
+ max_size = min_size
+
+ if min_size < 0:
+ raise ValueError("min_size cannot be negative")
+ if max_size < min_size:
+ raise ValueError("max_size must be greater or equal than min_size")
+ if min_size == max_size == 0:
+ raise ValueError("if min_size is 0 max_size must be greater or than 0")
+
+ return min_size, max_size
+
+ def _check_open(self) -> None:
+ if self._closed and self._opened:
+ raise e.OperationalError(
+ "pool has already been opened/closed and cannot be reused"
+ )
+
+ def _check_open_getconn(self) -> None:
+ if self._closed:
+ if self._opened:
+ raise PoolClosed(f"the pool {self.name!r} is already closed")
+ else:
+ raise PoolClosed(f"the pool {self.name!r} is not open yet")
+
+ def _check_pool_putconn(self, conn: ConnectionType) -> None:
+ pool = getattr(conn, "_pool", None)
+ if pool is self:
+ return
+
+ if pool:
+ msg = f"it comes from pool {pool.name!r}"
+ else:
+ msg = "it doesn't come from any pool"
+ raise ValueError(
+ f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+ )
+
+ def get_stats(self) -> Dict[str, int]:
+ """
+ Return current stats about the pool usage.
+ """
+ rv = dict(self._stats)
+ rv.update(self._get_measures())
+ return rv
+
+ def pop_stats(self) -> Dict[str, int]:
+ """
+ Return current stats about the pool usage.
+
+ After the call, all the counters are reset to zero.
+ """
+ stats, self._stats = self._stats, Counter()
+ rv = dict(stats)
+ rv.update(self._get_measures())
+ return rv
+
+ def _get_measures(self) -> Dict[str, int]:
+ """
+ Return immediate measures of the pool (not counters).
+ """
+ return {
+ self._POOL_MIN: self._min_size,
+ self._POOL_MAX: self._max_size,
+ self._POOL_SIZE: self._nconns,
+ self._POOL_AVAILABLE: len(self._pool),
+ }
+
+ @classmethod
+ def _jitter(cls, value: float, min_pc: float, max_pc: float) -> float:
+ """
+ Add a random value to *value* between *min_pc* and *max_pc* percent.
+ """
+ return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc)
+
+ def _set_connection_expiry_date(self, conn: ConnectionType) -> None:
+ """Set an expiry date on a connection.
+
+ Add some randomness to avoid mass reconnection.
+ """
+ conn._expire_at = monotonic() + self._jitter(self.max_lifetime, -0.05, 0.0)
+
+
+class ConnectionAttempt:
+ """Keep the state of a connection attempt."""
+
+ INITIAL_DELAY = 1.0
+ DELAY_JITTER = 0.1
+ DELAY_BACKOFF = 2.0
+
+ def __init__(self, *, reconnect_timeout: float):
+ self.reconnect_timeout = reconnect_timeout
+ self.delay = 0.0
+ self.give_up_at = 0.0
+
+ def update_delay(self, now: float) -> None:
+ """Calculate how long to wait for a new connection attempt"""
+ if self.delay == 0.0:
+ self.give_up_at = now + self.reconnect_timeout
+ self.delay = BasePool._jitter(
+ self.INITIAL_DELAY, -self.DELAY_JITTER, self.DELAY_JITTER
+ )
+ else:
+ self.delay *= self.DELAY_BACKOFF
+
+ if self.delay + now > self.give_up_at:
+ self.delay = max(0.0, self.give_up_at - now)
+
+ def time_to_give_up(self, now: float) -> bool:
+ """Return True if we are tired of trying to connect. Meh."""
+ return self.give_up_at > 0.0 and now >= self.give_up_at
diff --git a/psycopg_pool/psycopg_pool/errors.py b/psycopg_pool/psycopg_pool/errors.py
new file mode 100644
index 0000000..9e672ad
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/errors.py
@@ -0,0 +1,25 @@
+"""
+Connection pool errors.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from psycopg import errors as e
+
+
+class PoolClosed(e.OperationalError):
+ """Attempt to get a connection from a closed pool."""
+
+ __module__ = "psycopg_pool"
+
+
+class PoolTimeout(e.OperationalError):
+ """The pool couldn't provide a connection in acceptable time."""
+
+ __module__ = "psycopg_pool"
+
+
+class TooManyRequests(e.OperationalError):
+ """Too many requests in the queue waiting for a connection from the pool."""
+
+ __module__ = "psycopg_pool"
diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py
new file mode 100644
index 0000000..c0a77c2
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/null_pool.py
@@ -0,0 +1,159 @@
+"""
+Psycopg null connection pools
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import logging
+import threading
+from typing import Any, Optional, Tuple
+
+from psycopg import Connection
+from psycopg.pq import TransactionStatus
+
+from .pool import ConnectionPool, AddConnection
+from .errors import PoolTimeout, TooManyRequests
+from ._compat import ConnectionTimeout
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class _BaseNullConnectionPool:
+ def __init__(
+ self, conninfo: str = "", min_size: int = 0, *args: Any, **kwargs: Any
+ ):
+ super().__init__( # type: ignore[call-arg]
+ conninfo, *args, min_size=min_size, **kwargs
+ )
+
+ def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
+ if max_size is None:
+ max_size = min_size
+
+ if min_size != 0:
+ raise ValueError("null pools must have min_size = 0")
+ if max_size < min_size:
+ raise ValueError("max_size must be greater or equal than min_size")
+
+ return min_size, max_size
+
+ def _start_initial_tasks(self) -> None:
+ # Null pools don't have background tasks to fill connections
+ # or to grow/shrink.
+ return
+
+ def _maybe_grow_pool(self) -> None:
+ # null pools don't grow
+ pass
+
+
+class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
+ def wait(self, timeout: float = 30.0) -> None:
+ """
+ Create a connection for test.
+
+ Calling this function will verify that the connectivity with the
+ database works as expected. However the connection will not be stored
+ in the pool.
+
+ Close the pool, and raise `PoolTimeout`, if not ready within *timeout*
+ sec.
+ """
+ self._check_open_getconn()
+
+ with self._lock:
+ assert not self._pool_full_event
+ self._pool_full_event = threading.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ self.run_task(AddConnection(self))
+ if not self._pool_full_event.wait(timeout):
+ self.close() # stop all the threads
+ raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
+
+ with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ def _get_ready_connection(
+ self, timeout: Optional[float]
+ ) -> Optional[Connection[Any]]:
+ conn: Optional[Connection[Any]] = None
+ if self.max_size == 0 or self._nconns < self.max_size:
+ # Create a new connection for the client
+ try:
+ conn = self._connect(timeout=timeout)
+ except ConnectionTimeout as ex:
+ raise PoolTimeout(str(ex)) from None
+ self._nconns += 1
+
+ elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has already"
+ f" {len(self._waiting)} requests waiting"
+ )
+ return conn
+
+ def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+ with self._lock:
+ if not self._closed and self._waiting:
+ return False
+
+ conn._pool = None
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+ conn.close()
+ self._nconns -= 1
+ return True
+
+ def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ """Change the size of the pool during runtime.
+
+ Only *max_size* can be changed; *min_size* must remain 0.
+ """
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+
+ def check(self) -> None:
+ """No-op, as the pool doesn't have connections in its state."""
+ pass
+
+ def _add_to_pool(self, conn: Connection[Any]) -> None:
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid to create a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: close the connection
+ conn.close()
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is ready.
+ if self._pool_full_event:
+ self._pool_full_event.set()
+ else:
+ # The connection created by wait shouldn't decrease the
+ # count of the number of connection used.
+ self._nconns -= 1
diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py
new file mode 100644
index 0000000..ae9d207
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/null_pool_async.py
@@ -0,0 +1,122 @@
+"""
+psycopg asynchronous null connection pool
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import asyncio
+import logging
+from typing import Any, Optional
+
+from psycopg import AsyncConnection
+from psycopg.pq import TransactionStatus
+
+from .errors import PoolTimeout, TooManyRequests
+from ._compat import ConnectionTimeout
+from .null_pool import _BaseNullConnectionPool
+from .pool_async import AsyncConnectionPool, AddConnection
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
+ async def wait(self, timeout: float = 30.0) -> None:
+ self._check_open_getconn()
+
+ async with self._lock:
+ assert not self._pool_full_event
+ self._pool_full_event = asyncio.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ self.run_task(AddConnection(self))
+ try:
+ await asyncio.wait_for(self._pool_full_event.wait(), timeout)
+ except asyncio.TimeoutError:
+ await self.close() # stop all the tasks
+ raise PoolTimeout(
+ f"pool initialization incomplete after {timeout} sec"
+ ) from None
+
+ async with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ async def _get_ready_connection(
+ self, timeout: Optional[float]
+ ) -> Optional[AsyncConnection[Any]]:
+ conn: Optional[AsyncConnection[Any]] = None
+ if self.max_size == 0 or self._nconns < self.max_size:
+ # Create a new connection for the client
+ try:
+ conn = await self._connect(timeout=timeout)
+ except ConnectionTimeout as ex:
+ raise PoolTimeout(str(ex)) from None
+ self._nconns += 1
+ elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has already"
+ f" {len(self._waiting)} requests waiting"
+ )
+ return conn
+
+ async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
+ # Close the connection if no client is waiting for it, or if the pool
+ # is closed. For extra refcare remove the pool reference from it.
+ # Maintain the stats.
+ async with self._lock:
+ if not self._closed and self._waiting:
+ return False
+
+ conn._pool = None
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+ await conn.close()
+ self._nconns -= 1
+ return True
+
+ async def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ async with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+
+ async def check(self) -> None:
+ pass
+
+ async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid to create a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ async with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if await pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: close the connection
+ await conn.close()
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is ready.
+ if self._pool_full_event:
+ self._pool_full_event.set()
+ else:
+ # The connection created by wait shouldn't decrease the
+ # count of the number of connection used.
+ self._nconns -= 1
diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py
new file mode 100644
index 0000000..609d95d
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/pool.py
@@ -0,0 +1,839 @@
+"""
+psycopg synchronous connection pool
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import logging
+import threading
+from abc import ABC, abstractmethod
+from time import monotonic
+from queue import Queue, Empty
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List
+from typing import Optional, Sequence, Type
+from weakref import ref
+from contextlib import contextmanager
+
+from psycopg import errors as e
+from psycopg import Connection
+from psycopg.pq import TransactionStatus
+
+from .base import ConnectionAttempt, BasePool
+from .sched import Scheduler
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
+from ._compat import Deque
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class ConnectionPool(BasePool[Connection[Any]]):
+ def __init__(
+ self,
+ conninfo: str = "",
+ *,
+ open: bool = True,
+ connection_class: Type[Connection[Any]] = Connection,
+ configure: Optional[Callable[[Connection[Any]], None]] = None,
+ reset: Optional[Callable[[Connection[Any]], None]] = None,
+ **kwargs: Any,
+ ):
+ self.connection_class = connection_class
+ self._configure = configure
+ self._reset = reset
+
+ self._lock = threading.RLock()
+ self._waiting = Deque["WaitingClient"]()
+
+ # to notify that the pool is full
+ self._pool_full_event: Optional[threading.Event] = None
+
+ self._sched = Scheduler()
+ self._sched_runner: Optional[threading.Thread] = None
+ self._tasks: "Queue[MaintenanceTask]" = Queue()
+ self._workers: List[threading.Thread] = []
+
+ super().__init__(conninfo, **kwargs)
+
+ if open:
+ self.open()
+
+ def __del__(self) -> None:
+ # If the '_closed' property is not set we probably failed in __init__.
+ # Don't try anything complicated as probably it won't work.
+ if getattr(self, "_closed", True):
+ return
+
+ self._stop_workers()
+
+ def wait(self, timeout: float = 30.0) -> None:
+ """
+ Wait for the pool to be full (with `min_size` connections) after creation.
+
+ Close the pool, and raise `PoolTimeout`, if not ready within *timeout*
+ sec.
+
+ Calling this method is not mandatory: you can try and use the pool
+ immediately after its creation. The first client will be served as soon
+ as a connection is ready. You can use this method if you prefer your
+ program to terminate in case the environment is not configured
+ properly, rather than trying to stay up the hardest it can.
+ """
+ self._check_open_getconn()
+
+ with self._lock:
+ assert not self._pool_full_event
+ if len(self._pool) >= self._min_size:
+ return
+ self._pool_full_event = threading.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ if not self._pool_full_event.wait(timeout):
+ self.close() # stop all the threads
+ raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
+
+ with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ @contextmanager
+ def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]:
+ """Context manager to obtain a connection from the pool.
+
+ Return the connection immediately if available, otherwise wait up to
+ *timeout* or `self.timeout` seconds and throw `PoolTimeout` if a
+ connection is not available in time.
+
+ Upon context exit, return the connection to the pool. Apply the normal
+ :ref:`connection context behaviour <with-connection>` (commit/rollback
+ the transaction in case of success/error). If the connection is no more
+ in working state, replace it with a new one.
+ """
+ conn = self.getconn(timeout=timeout)
+ t0 = monotonic()
+ try:
+ with conn:
+ yield conn
+ finally:
+ t1 = monotonic()
+ self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
+ self.putconn(conn)
+
+ def getconn(self, timeout: Optional[float] = None) -> Connection[Any]:
+ """Obtain a connection from the pool.
+
+ You should preferably use `connection()`. Use this function only if
+ it is not possible to use the connection as context manager.
+
+ After using this function you *must* call a corresponding `putconn()`:
+ failing to do so will deplete the pool. A depleted pool is a sad pool:
+ you don't want a depleted pool.
+ """
+ logger.info("connection requested from %r", self.name)
+ self._stats[self._REQUESTS_NUM] += 1
+
+ # Critical section: decide here if there's a connection ready
+ # or if the client needs to wait.
+ with self._lock:
+ self._check_open_getconn()
+ conn = self._get_ready_connection(timeout)
+ if not conn:
+ # No connection available: put the client in the waiting queue
+ t0 = monotonic()
+ pos = WaitingClient()
+ self._waiting.append(pos)
+ self._stats[self._REQUESTS_QUEUED] += 1
+
+ # If there is space for the pool to grow, let's do it
+ self._maybe_grow_pool()
+
+ # If we are in the waiting queue, wait to be assigned a connection
+ # (outside the critical section, so only the waiting client is locked)
+ if not conn:
+ if timeout is None:
+ timeout = self.timeout
+ try:
+ conn = pos.wait(timeout=timeout)
+ except Exception:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise
+ finally:
+ t1 = monotonic()
+ self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0))
+
+ # Tell the connection it belongs to a pool to avoid closing on __exit__
+ # Note that this property shouldn't be set while the connection is in
+ # the pool, to avoid to create a reference loop.
+ conn._pool = self
+ logger.info("connection given by %r", self.name)
+ return conn
+
+ def _get_ready_connection(
+ self, timeout: Optional[float]
+ ) -> Optional[Connection[Any]]:
+ """Return a connection, if the client deserves one."""
+ conn: Optional[Connection[Any]] = None
+ if self._pool:
+ # Take a connection ready out of the pool
+ conn = self._pool.popleft()
+ if len(self._pool) < self._nconns_min:
+ self._nconns_min = len(self._pool)
+ elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has already"
+ f" {len(self._waiting)} requests waiting"
+ )
+ return conn
+
+ def _maybe_grow_pool(self) -> None:
+ # Allow only one thread at time to grow the pool (or returning
+ # connections might be starved).
+ if self._nconns >= self._max_size or self._growing:
+ return
+ self._nconns += 1
+ logger.info("growing pool %r to %s", self.name, self._nconns)
+ self._growing = True
+ self.run_task(AddConnection(self, growing=True))
+
+ def putconn(self, conn: Connection[Any]) -> None:
+ """Return a connection to the loving hands of its pool.
+
+ Use this function only paired with a `getconn()`. You don't need to use
+ it if you use the much more comfortable `connection()` context manager.
+ """
+ # Quick check to discard the wrong connection
+ self._check_pool_putconn(conn)
+
+ logger.info("returning connection to %r", self.name)
+
+ if self._maybe_close_connection(conn):
+ return
+
+ # Use a worker to perform eventual maintenance work in a separate thread
+ if self._reset:
+ self.run_task(ReturnConnection(self, conn))
+ else:
+ self._return_connection(conn)
+
+ def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+ """Close a returned connection if necessary.
+
+ Return `!True if the connection was closed.
+ """
+ # If the pool is closed just close the connection instead of returning
+ # it to the pool. For extra refcare remove the pool reference from it.
+ if not self._closed:
+ return False
+
+ conn._pool = None
+ conn.close()
+ return True
+
+ def open(self, wait: bool = False, timeout: float = 30.0) -> None:
+ """Open the pool by starting connecting and and accepting clients.
+
+ If *wait* is `!False`, return immediately and let the background worker
+ fill the pool if `min_size` > 0. Otherwise wait up to *timeout* seconds
+ for the requested number of connections to be ready (see `wait()` for
+ details).
+
+ It is safe to call `!open()` again on a pool already open (because the
+ method was already called, or because the pool context was entered, or
+ because the pool was initialized with *open* = `!True`) but you cannot
+ currently re-open a closed pool.
+ """
+ with self._lock:
+ self._open()
+
+ if wait:
+ self.wait(timeout=timeout)
+
+ def _open(self) -> None:
+ if not self._closed:
+ return
+
+ self._check_open()
+
+ self._closed = False
+ self._opened = True
+
+ self._start_workers()
+ self._start_initial_tasks()
+
+ def _start_workers(self) -> None:
+ self._sched_runner = threading.Thread(
+ target=self._sched.run,
+ name=f"{self.name}-scheduler",
+ daemon=True,
+ )
+ assert not self._workers
+ for i in range(self.num_workers):
+ t = threading.Thread(
+ target=self.worker,
+ args=(self._tasks,),
+ name=f"{self.name}-worker-{i}",
+ daemon=True,
+ )
+ self._workers.append(t)
+
+ # The object state is complete. Start the worker threads
+ self._sched_runner.start()
+ for t in self._workers:
+ t.start()
+
+ def _start_initial_tasks(self) -> None:
+ # populate the pool with initial min_size connections in background
+ for i in range(self._nconns):
+ self.run_task(AddConnection(self))
+
+ # Schedule a task to shrink the pool if connections over min_size have
+ # remained unused.
+ self.schedule_task(ShrinkPool(self), self.max_idle)
+
+ def close(self, timeout: float = 5.0) -> None:
+ """Close the pool and make it unavailable to new clients.
+
+ All the waiting and future clients will fail to acquire a connection
+ with a `PoolClosed` exception. Currently used connections will not be
+ closed until returned to the pool.
+
+ Wait *timeout* seconds for threads to terminate their job, if positive.
+ If the timeout expires the pool is closed anyway, although it may raise
+ some warnings on exit.
+ """
+ if self._closed:
+ return
+
+ with self._lock:
+ self._closed = True
+ logger.debug("pool %r closed", self.name)
+
+ # Take waiting client and pool connections out of the state
+ waiting = list(self._waiting)
+ self._waiting.clear()
+ connections = list(self._pool)
+ self._pool.clear()
+
+ # Now that the flag _closed is set, getconn will fail immediately,
+ # putconn will just close the returned connection.
+ self._stop_workers(waiting, connections, timeout)
+
+ def _stop_workers(
+ self,
+ waiting_clients: Sequence["WaitingClient"] = (),
+ connections: Sequence[Connection[Any]] = (),
+ timeout: float = 0.0,
+ ) -> None:
+
+ # Stop the scheduler
+ self._sched.enter(0, None)
+
+ # Stop the worker threads
+ workers, self._workers = self._workers[:], []
+ for i in range(len(workers)):
+ self.run_task(StopWorker(self))
+
+ # Signal to eventual clients in the queue that business is closed.
+ for pos in waiting_clients:
+ pos.fail(PoolClosed(f"the pool {self.name!r} is closed"))
+
+ # Close the connections still in the pool
+ for conn in connections:
+ conn.close()
+
+ # Wait for the worker threads to terminate
+ assert self._sched_runner is not None
+ sched_runner, self._sched_runner = self._sched_runner, None
+ if timeout > 0:
+ for t in [sched_runner] + workers:
+ if not t.is_alive():
+ continue
+ t.join(timeout)
+ if t.is_alive():
+ logger.warning(
+ "couldn't stop thread %s in pool %r within %s seconds",
+ t,
+ self.name,
+ timeout,
+ )
+
+ def __enter__(self) -> "ConnectionPool":
+ self.open()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ """Change the size of the pool during runtime."""
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ ngrow = max(0, min_size - self._min_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+ self._nconns += ngrow
+
+ for i in range(ngrow):
+ self.run_task(AddConnection(self))
+
+ def check(self) -> None:
+ """Verify the state of the connections currently in the pool.
+
+ Test each connection: if it works return it to the pool, otherwise
+ dispose of it and create a new one.
+ """
+ with self._lock:
+ conns = list(self._pool)
+ self._pool.clear()
+
+ # Give a chance to the pool to grow if it has no connection.
+ # In case there are enough connection, or the pool is already
+ # growing, this is a no-op.
+ self._maybe_grow_pool()
+
+ while conns:
+ conn = conns.pop()
+ try:
+ conn.execute("SELECT 1")
+ if conn.pgconn.transaction_status == TransactionStatus.INTRANS:
+ conn.rollback()
+ except Exception:
+ self._stats[self._CONNECTIONS_LOST] += 1
+ logger.warning("discarding broken connection: %s", conn)
+ self.run_task(AddConnection(self))
+ else:
+ self._add_to_pool(conn)
+
+ def reconnect_failed(self) -> None:
+ """
+ Called when reconnection failed for longer than `reconnect_timeout`.
+ """
+ self._reconnect_failed(self)
+
+ def run_task(self, task: "MaintenanceTask") -> None:
+ """Run a maintenance task in a worker thread."""
+ self._tasks.put_nowait(task)
+
+ def schedule_task(self, task: "MaintenanceTask", delay: float) -> None:
+ """Run a maintenance task in a worker thread in the future."""
+ self._sched.enter(delay, task.tick)
+
+ _WORKER_TIMEOUT = 60.0
+
+ @classmethod
+ def worker(cls, q: "Queue[MaintenanceTask]") -> None:
+ """Runner to execute pending maintenance task.
+
+ The function is designed to run as a separate thread.
+
+ Block on the queue *q*, run a task received. Finish running if a
+ StopWorker is received.
+ """
+ # Don't make all the workers time out at the same moment
+ timeout = cls._jitter(cls._WORKER_TIMEOUT, -0.1, 0.1)
+ while True:
+ # Use a timeout to make the wait interruptible
+ try:
+ task = q.get(timeout=timeout)
+ except Empty:
+ continue
+
+ if isinstance(task, StopWorker):
+ logger.debug(
+ "terminating working thread %s",
+ threading.current_thread().name,
+ )
+ return
+
+ # Run the task. Make sure don't die in the attempt.
+ try:
+ task.run()
+ except Exception as ex:
+ logger.warning(
+ "task run %s failed: %s: %s",
+ task,
+ ex.__class__.__name__,
+ ex,
+ )
+
+ def _connect(self, timeout: Optional[float] = None) -> Connection[Any]:
+ """Return a new connection configured for the pool."""
+ self._stats[self._CONNECTIONS_NUM] += 1
+ kwargs = self.kwargs
+ if timeout:
+ kwargs = kwargs.copy()
+ kwargs["connect_timeout"] = max(round(timeout), 1)
+ t0 = monotonic()
+ try:
+ conn: Connection[Any]
+ conn = self.connection_class.connect(self.conninfo, **kwargs)
+ except Exception:
+ self._stats[self._CONNECTIONS_ERRORS] += 1
+ raise
+ else:
+ t1 = monotonic()
+ self._stats[self._CONNECTIONS_MS] += int(1000.0 * (t1 - t0))
+
+ conn._pool = self
+
+ if self._configure:
+ self._configure(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ sname = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {sname} by configure function"
+ f" {self._configure}: discarded"
+ )
+
+ # Set an expiry date, with some randomness to avoid mass reconnection
+ self._set_connection_expiry_date(conn)
+ return conn
+
+ def _add_connection(
+ self, attempt: Optional[ConnectionAttempt], growing: bool = False
+ ) -> None:
+ """Try to connect and add the connection to the pool.
+
+ If failed, reschedule a new attempt in the future for a few times, then
+ give up, decrease the pool connections number and call
+ `self.reconnect_failed()`.
+
+ """
+ now = monotonic()
+ if not attempt:
+ attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout)
+
+ try:
+ conn = self._connect()
+ except Exception as ex:
+ logger.warning(f"error connecting in {self.name!r}: {ex}")
+ if attempt.time_to_give_up(now):
+ logger.warning(
+ "reconnection attempt in pool %r failed after %s sec",
+ self.name,
+ self.reconnect_timeout,
+ )
+ with self._lock:
+ self._nconns -= 1
+ # If we have given up with a growing attempt, allow a new one.
+ if growing and self._growing:
+ self._growing = False
+ self.reconnect_failed()
+ else:
+ attempt.update_delay(now)
+ self.schedule_task(
+ AddConnection(self, attempt, growing=growing),
+ attempt.delay,
+ )
+ return
+
+ logger.info("adding new connection to the pool")
+ self._add_to_pool(conn)
+ if growing:
+ with self._lock:
+ # Keep on growing if the pool is not full yet, or if there are
+ # clients waiting and the pool can extend.
+ if self._nconns < self._min_size or (
+ self._nconns < self._max_size and self._waiting
+ ):
+ self._nconns += 1
+ logger.info("growing pool %r to %s", self.name, self._nconns)
+ self.run_task(AddConnection(self, growing=True))
+ else:
+ self._growing = False
+
+ def _return_connection(self, conn: Connection[Any]) -> None:
+ """
+ Return a connection to the pool after usage.
+ """
+ self._reset_connection(conn)
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+ # Connection no more in working state: create a new one.
+ self.run_task(AddConnection(self))
+ logger.warning("discarding closed connection: %s", conn)
+ return
+
+ # Check if the connection is past its best before date
+ if conn._expire_at <= monotonic():
+ self.run_task(AddConnection(self))
+ logger.info("discarding expired connection")
+ conn.close()
+ return
+
+ self._add_to_pool(conn)
+
+ def _add_to_pool(self, conn: Connection[Any]) -> None:
+ """
+ Add a connection to the pool.
+
+ The connection can be a fresh one or one already used in the pool.
+
+ If a client is already waiting for a connection pass it on, otherwise
+ put it back into the pool
+ """
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid to create a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: put it back into the pool
+ self._pool.append(conn)
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is full.
+ if self._pool_full_event and len(self._pool) >= self._min_size:
+ self._pool_full_event.set()
+
+ def _reset_connection(self, conn: Connection[Any]) -> None:
+ """
+ Bring a connection to IDLE state or close it.
+ """
+ status = conn.pgconn.transaction_status
+ if status == TransactionStatus.IDLE:
+ pass
+
+ elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+ # Connection returned with an active transaction
+ logger.warning("rolling back returned connection: %s", conn)
+ try:
+ conn.rollback()
+ except Exception as ex:
+ logger.warning(
+ "rollback failed: %s: %s. Discarding connection %s",
+ ex.__class__.__name__,
+ ex,
+ conn,
+ )
+ conn.close()
+
+ elif status == TransactionStatus.ACTIVE:
+ # Connection returned during an operation. Bad... just close it.
+ logger.warning("closing returned connection: %s", conn)
+ conn.close()
+
+ if not conn.closed and self._reset:
+ try:
+ self._reset(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ sname = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {sname} by reset function"
+ f" {self._reset}: discarded"
+ )
+ except Exception as ex:
+ logger.warning(f"error resetting connection: {ex}")
+ conn.close()
+
+ def _shrink_pool(self) -> None:
+ to_close: Optional[Connection[Any]] = None
+
+ with self._lock:
+ # Reset the min number of connections used
+ nconns_min = self._nconns_min
+ self._nconns_min = len(self._pool)
+
+ # If the pool can shrink and connections were unused, drop one
+ if self._nconns > self._min_size and nconns_min > 0:
+ to_close = self._pool.popleft()
+ self._nconns -= 1
+ self._nconns_min -= 1
+
+ if to_close:
+ logger.info(
+ "shrinking pool %r to %s because %s unused connections"
+ " in the last %s sec",
+ self.name,
+ self._nconns,
+ nconns_min,
+ self.max_idle,
+ )
+ to_close.close()
+
+ def _get_measures(self) -> Dict[str, int]:
+ rv = super()._get_measures()
+ rv[self._REQUESTS_WAITING] = len(self._waiting)
+ return rv
+
+
+class WaitingClient:
+ """A position in a queue for a client waiting for a connection."""
+
+ __slots__ = ("conn", "error", "_cond")
+
+ def __init__(self) -> None:
+ self.conn: Optional[Connection[Any]] = None
+ self.error: Optional[Exception] = None
+
+ # The WaitingClient behaves in a way similar to an Event, but we need
+ # to notify reliably the flagger that the waiter has "accepted" the
+ # message and it hasn't timed out yet, otherwise the pool may give a
+ # connection to a client that has already timed out getconn(), which
+ # will be lost.
+ self._cond = threading.Condition()
+
+ def wait(self, timeout: float) -> Connection[Any]:
+ """Wait for a connection to be set and return it.
+
+ Raise an exception if the wait times out or if fail() is called.
+ """
+ with self._cond:
+ if not (self.conn or self.error):
+ if not self._cond.wait(timeout):
+ self.error = PoolTimeout(
+ f"couldn't get a connection after {timeout} sec"
+ )
+
+ if self.conn:
+ return self.conn
+ else:
+ assert self.error
+ raise self.error
+
+ def set(self, conn: Connection[Any]) -> bool:
+ """Signal the client waiting that a connection is ready.
+
+ Return True if the client has "accepted" the connection, False
+ otherwise (typically because wait() has timed out).
+ """
+ with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.conn = conn
+ self._cond.notify_all()
+ return True
+
+ def fail(self, error: Exception) -> bool:
+ """Signal the client that, alas, they won't have a connection today.
+
+ Return True if the client has "accepted" the error, False otherwise
+ (typically because wait() has timed out).
+ """
+ with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.error = error
+ self._cond.notify_all()
+ return True
+
+
+class MaintenanceTask(ABC):
+ """A task to run asynchronously to maintain the pool state."""
+
+ def __init__(self, pool: "ConnectionPool"):
+ self.pool = ref(pool)
+
+ def __repr__(self) -> str:
+ pool = self.pool()
+ name = repr(pool.name) if pool else "<pool is gone>"
+ return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>"
+
+ def run(self) -> None:
+ """Run the task.
+
+ This usually happens in a worker thread. Call the concrete _run()
+ implementation, if the pool is still alive.
+ """
+ pool = self.pool()
+ if not pool or pool.closed:
+ # Pool is no more working. Quietly discard the operation.
+ logger.debug("task run discarded: %s", self)
+ return
+
+ logger.debug("task running in %s: %s", threading.current_thread().name, self)
+ self._run(pool)
+
+ def tick(self) -> None:
+ """Run the scheduled task
+
+ This function is called by the scheduler thread. Use a worker to
+ run the task for real in order to free the scheduler immediately.
+ """
+ pool = self.pool()
+ if not pool or pool.closed:
+ # Pool is no more working. Quietly discard the operation.
+ logger.debug("task tick discarded: %s", self)
+ return
+
+ pool.run_task(self)
+
+ @abstractmethod
+ def _run(self, pool: "ConnectionPool") -> None:
+ ...
+
+
+class StopWorker(MaintenanceTask):
+ """Signal the maintenance thread to terminate."""
+
+ def _run(self, pool: "ConnectionPool") -> None:
+ pass
+
+
+class AddConnection(MaintenanceTask):
+ def __init__(
+ self,
+ pool: "ConnectionPool",
+ attempt: Optional["ConnectionAttempt"] = None,
+ growing: bool = False,
+ ):
+ super().__init__(pool)
+ self.attempt = attempt
+ self.growing = growing
+
+ def _run(self, pool: "ConnectionPool") -> None:
+ pool._add_connection(self.attempt, growing=self.growing)
+
+
+class ReturnConnection(MaintenanceTask):
+ """Clean up and return a connection to the pool."""
+
+ def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"):
+ super().__init__(pool)
+ self.conn = conn
+
+ def _run(self, pool: "ConnectionPool") -> None:
+ pool._return_connection(self.conn)
+
+
+class ShrinkPool(MaintenanceTask):
+ """If the pool can shrink, remove one connection.
+
+ Re-schedule periodically and also reset the minimum number of connections
+ in the pool.
+ """
+
+ def _run(self, pool: "ConnectionPool") -> None:
+ # Reschedule the task now so that in case of any error we don't lose
+ # the periodic run.
+ pool.schedule_task(self, pool.max_idle)
+ pool._shrink_pool()
diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py
new file mode 100644
index 0000000..0ea6e9a
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/pool_async.py
@@ -0,0 +1,784 @@
+"""
+psycopg asynchronous connection pool
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import asyncio
+import logging
+from abc import ABC, abstractmethod
+from time import monotonic
+from types import TracebackType
+from typing import Any, AsyncIterator, Awaitable, Callable
+from typing import Dict, List, Optional, Sequence, Type
+from weakref import ref
+from contextlib import asynccontextmanager
+
+from psycopg import errors as e
+from psycopg import AsyncConnection
+from psycopg.pq import TransactionStatus
+
+from .base import ConnectionAttempt, BasePool
+from .sched import AsyncScheduler
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
+from ._compat import Task, create_task, Deque
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
+ def __init__(
+ self,
+ conninfo: str = "",
+ *,
+ open: bool = True,
+ connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
+ configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+ reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+ **kwargs: Any,
+ ):
+ self.connection_class = connection_class
+ self._configure = configure
+ self._reset = reset
+
+ # asyncio objects, created on open to attach them to the right loop.
+ self._lock: asyncio.Lock
+ self._sched: AsyncScheduler
+ self._tasks: "asyncio.Queue[MaintenanceTask]"
+
+ self._waiting = Deque["AsyncClient"]()
+
+ # to notify that the pool is full
+ self._pool_full_event: Optional[asyncio.Event] = None
+
+ self._sched_runner: Optional[Task[None]] = None
+ self._workers: List[Task[None]] = []
+
+ super().__init__(conninfo, **kwargs)
+
+ if open:
+ self._open()
+
+ async def wait(self, timeout: float = 30.0) -> None:
+ self._check_open_getconn()
+
+ async with self._lock:
+ assert not self._pool_full_event
+ if len(self._pool) >= self._min_size:
+ return
+ self._pool_full_event = asyncio.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ try:
+ await asyncio.wait_for(self._pool_full_event.wait(), timeout)
+ except asyncio.TimeoutError:
+ await self.close() # stop all the tasks
+ raise PoolTimeout(
+ f"pool initialization incomplete after {timeout} sec"
+ ) from None
+
+ async with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ @asynccontextmanager
+ async def connection(
+ self, timeout: Optional[float] = None
+ ) -> AsyncIterator[AsyncConnection[Any]]:
+ conn = await self.getconn(timeout=timeout)
+ t0 = monotonic()
+ try:
+ async with conn:
+ yield conn
+ finally:
+ t1 = monotonic()
+ self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
+ await self.putconn(conn)
+
+ async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
+ logger.info("connection requested from %r", self.name)
+ self._stats[self._REQUESTS_NUM] += 1
+
+ self._check_open_getconn()
+
+ # Critical section: decide here if there's a connection ready
+ # or if the client needs to wait.
+ async with self._lock:
+ conn = await self._get_ready_connection(timeout)
+ if not conn:
+ # No connection available: put the client in the waiting queue
+ t0 = monotonic()
+ pos = AsyncClient()
+ self._waiting.append(pos)
+ self._stats[self._REQUESTS_QUEUED] += 1
+
+ # If there is space for the pool to grow, let's do it
+ self._maybe_grow_pool()
+
+ # If we are in the waiting queue, wait to be assigned a connection
+ # (outside the critical section, so only the waiting client is locked)
+ if not conn:
+ if timeout is None:
+ timeout = self.timeout
+ try:
+ conn = await pos.wait(timeout=timeout)
+ except Exception:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise
+ finally:
+ t1 = monotonic()
+ self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0))
+
+ # Tell the connection it belongs to a pool to avoid closing on __exit__
+ # Note that this property shouldn't be set while the connection is in
+ # the pool, to avoid to create a reference loop.
+ conn._pool = self
+ logger.info("connection given by %r", self.name)
+ return conn
+
+ async def _get_ready_connection(
+ self, timeout: Optional[float]
+ ) -> Optional[AsyncConnection[Any]]:
+ conn: Optional[AsyncConnection[Any]] = None
+ if self._pool:
+ # Take a connection ready out of the pool
+ conn = self._pool.popleft()
+ if len(self._pool) < self._nconns_min:
+ self._nconns_min = len(self._pool)
+ elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has already"
+ f" {len(self._waiting)} requests waiting"
+ )
+ return conn
+
+ def _maybe_grow_pool(self) -> None:
+ # Allow only one task at time to grow the pool (or returning
+ # connections might be starved).
+ if self._nconns < self._max_size and not self._growing:
+ self._nconns += 1
+ logger.info("growing pool %r to %s", self.name, self._nconns)
+ self._growing = True
+ self.run_task(AddConnection(self, growing=True))
+
+ async def putconn(self, conn: AsyncConnection[Any]) -> None:
+ self._check_pool_putconn(conn)
+
+ logger.info("returning connection to %r", self.name)
+ if await self._maybe_close_connection(conn):
+ return
+
+ # Use a worker to perform eventual maintenance work in a separate task
+ if self._reset:
+ self.run_task(ReturnConnection(self, conn))
+ else:
+ await self._return_connection(conn)
+
+ async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
+ # If the pool is closed just close the connection instead of returning
+ # it to the pool. For extra refcare remove the pool reference from it.
+ if not self._closed:
+ return False
+
+ conn._pool = None
+ await conn.close()
+ return True
+
+ async def open(self, wait: bool = False, timeout: float = 30.0) -> None:
+ # Make sure the lock is created after there is an event loop
+ try:
+ self._lock
+ except AttributeError:
+ self._lock = asyncio.Lock()
+
+ async with self._lock:
+ self._open()
+
+ if wait:
+ await self.wait(timeout=timeout)
+
+ def _open(self) -> None:
+ if not self._closed:
+ return
+
+ # Throw a RuntimeError if the pool is open outside a running loop.
+ asyncio.get_running_loop()
+
+ self._check_open()
+
+ # Create these objects now to attach them to the right loop.
+ # See #219
+ self._tasks = asyncio.Queue()
+ self._sched = AsyncScheduler()
+ # This has been most likely, but not necessarily, created in `open()`.
+ try:
+ self._lock
+ except AttributeError:
+ self._lock = asyncio.Lock()
+
+ self._closed = False
+ self._opened = True
+
+ self._start_workers()
+ self._start_initial_tasks()
+
+ def _start_workers(self) -> None:
+ self._sched_runner = create_task(
+ self._sched.run(), name=f"{self.name}-scheduler"
+ )
+ for i in range(self.num_workers):
+ t = create_task(
+ self.worker(self._tasks),
+ name=f"{self.name}-worker-{i}",
+ )
+ self._workers.append(t)
+
+ def _start_initial_tasks(self) -> None:
+ # populate the pool with initial min_size connections in background
+ for i in range(self._nconns):
+ self.run_task(AddConnection(self))
+
+ # Schedule a task to shrink the pool if connections over min_size have
+ # remained unused.
+ self.run_task(Schedule(self, ShrinkPool(self), self.max_idle))
+
+ async def close(self, timeout: float = 5.0) -> None:
+ if self._closed:
+ return
+
+ async with self._lock:
+ self._closed = True
+ logger.debug("pool %r closed", self.name)
+
+ # Take waiting client and pool connections out of the state
+ waiting = list(self._waiting)
+ self._waiting.clear()
+ connections = list(self._pool)
+ self._pool.clear()
+
+ # Now that the flag _closed is set, getconn will fail immediately,
+ # putconn will just close the returned connection.
+ await self._stop_workers(waiting, connections, timeout)
+
+ async def _stop_workers(
+ self,
+ waiting_clients: Sequence["AsyncClient"] = (),
+ connections: Sequence[AsyncConnection[Any]] = (),
+ timeout: float = 0.0,
+ ) -> None:
+ # Stop the scheduler
+ await self._sched.enter(0, None)
+
+ # Stop the worker tasks
+ workers, self._workers = self._workers[:], []
+ for w in workers:
+ self.run_task(StopWorker(self))
+
+ # Signal to eventual clients in the queue that business is closed.
+ for pos in waiting_clients:
+ await pos.fail(PoolClosed(f"the pool {self.name!r} is closed"))
+
+ # Close the connections still in the pool
+ for conn in connections:
+ await conn.close()
+
+ # Wait for the worker tasks to terminate
+ assert self._sched_runner is not None
+ sched_runner, self._sched_runner = self._sched_runner, None
+ wait = asyncio.gather(sched_runner, *workers)
+ try:
+ if timeout > 0:
+ await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
+ else:
+ await wait
+ except asyncio.TimeoutError:
+ logger.warning(
+ "couldn't stop pool %r tasks within %s seconds",
+ self.name,
+ timeout,
+ )
+
+ async def __aenter__(self) -> "AsyncConnectionPool":
+ await self.open()
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
+ async def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ ngrow = max(0, min_size - self._min_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ async with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+ self._nconns += ngrow
+
+ for i in range(ngrow):
+ self.run_task(AddConnection(self))
+
+ async def check(self) -> None:
+ async with self._lock:
+ conns = list(self._pool)
+ self._pool.clear()
+
+ # Give a chance to the pool to grow if it has no connection.
+ # In case there are enough connection, or the pool is already
+ # growing, this is a no-op.
+ self._maybe_grow_pool()
+
+ while conns:
+ conn = conns.pop()
+ try:
+ await conn.execute("SELECT 1")
+ if conn.pgconn.transaction_status == TransactionStatus.INTRANS:
+ await conn.rollback()
+ except Exception:
+ self._stats[self._CONNECTIONS_LOST] += 1
+ logger.warning("discarding broken connection: %s", conn)
+ self.run_task(AddConnection(self))
+ else:
+ await self._add_to_pool(conn)
+
+ def reconnect_failed(self) -> None:
+ """
+ Called when reconnection failed for longer than `reconnect_timeout`.
+ """
+ self._reconnect_failed(self)
+
+ def run_task(self, task: "MaintenanceTask") -> None:
+ """Run a maintenance task in a worker."""
+ self._tasks.put_nowait(task)
+
+ async def schedule_task(self, task: "MaintenanceTask", delay: float) -> None:
+ """Run a maintenance task in a worker in the future."""
+ await self._sched.enter(delay, task.tick)
+
+ @classmethod
+ async def worker(cls, q: "asyncio.Queue[MaintenanceTask]") -> None:
+ """Runner to execute pending maintenance task.
+
+ The function is designed to run as a task.
+
+ Block on the queue *q*, run a task received. Finish running if a
+ StopWorker is received.
+ """
+ while True:
+ task = await q.get()
+
+ if isinstance(task, StopWorker):
+ logger.debug("terminating working task")
+ return
+
+ # Run the task. Make sure don't die in the attempt.
+ try:
+ await task.run()
+ except Exception as ex:
+ logger.warning(
+ "task run %s failed: %s: %s",
+ task,
+ ex.__class__.__name__,
+ ex,
+ )
+
+ async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
+ self._stats[self._CONNECTIONS_NUM] += 1
+ kwargs = self.kwargs
+ if timeout:
+ kwargs = kwargs.copy()
+ kwargs["connect_timeout"] = max(round(timeout), 1)
+ t0 = monotonic()
+ try:
+ conn: AsyncConnection[Any]
+ conn = await self.connection_class.connect(self.conninfo, **kwargs)
+ except Exception:
+ self._stats[self._CONNECTIONS_ERRORS] += 1
+ raise
+ else:
+ t1 = monotonic()
+ self._stats[self._CONNECTIONS_MS] += int(1000.0 * (t1 - t0))
+
+ conn._pool = self
+
+ if self._configure:
+ await self._configure(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ sname = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {sname} by configure function"
+ f" {self._configure}: discarded"
+ )
+
+ # Set an expiry date, with some randomness to avoid mass reconnection
+ self._set_connection_expiry_date(conn)
+ return conn
+
+ async def _add_connection(
+ self, attempt: Optional[ConnectionAttempt], growing: bool = False
+ ) -> None:
+ """Try to connect and add the connection to the pool.
+
+ If failed, reschedule a new attempt in the future for a few times, then
+ give up, decrease the pool connections number and call
+ `self.reconnect_failed()`.
+
+ """
+ now = monotonic()
+ if not attempt:
+ attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout)
+
+ try:
+ conn = await self._connect()
+ except Exception as ex:
+ logger.warning(f"error connecting in {self.name!r}: {ex}")
+ if attempt.time_to_give_up(now):
+ logger.warning(
+ "reconnection attempt in pool %r failed after %s sec",
+ self.name,
+ self.reconnect_timeout,
+ )
+ async with self._lock:
+ self._nconns -= 1
+ # If we have given up with a growing attempt, allow a new one.
+ if growing and self._growing:
+ self._growing = False
+ self.reconnect_failed()
+ else:
+ attempt.update_delay(now)
+ await self.schedule_task(
+ AddConnection(self, attempt, growing=growing),
+ attempt.delay,
+ )
+ return
+
+ logger.info("adding new connection to the pool")
+ await self._add_to_pool(conn)
+ if growing:
+ async with self._lock:
+ # Keep on growing if the pool is not full yet, or if there are
+ # clients waiting and the pool can extend.
+ if self._nconns < self._min_size or (
+ self._nconns < self._max_size and self._waiting
+ ):
+ self._nconns += 1
+ logger.info("growing pool %r to %s", self.name, self._nconns)
+ self.run_task(AddConnection(self, growing=True))
+ else:
+ self._growing = False
+
+ async def _return_connection(self, conn: AsyncConnection[Any]) -> None:
+ """
+ Return a connection to the pool after usage.
+ """
+ await self._reset_connection(conn)
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+ # Connection no more in working state: create a new one.
+ self.run_task(AddConnection(self))
+ logger.warning("discarding closed connection: %s", conn)
+ return
+
+ # Check if the connection is past its best before date
+ if conn._expire_at <= monotonic():
+ self.run_task(AddConnection(self))
+ logger.info("discarding expired connection")
+ await conn.close()
+ return
+
+ await self._add_to_pool(conn)
+
+ async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+ """
+ Add a connection to the pool.
+
+ The connection can be a fresh one or one already used in the pool.
+
+ If a client is already waiting for a connection pass it on, otherwise
+ put it back into the pool
+ """
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid to create a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ async with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if await pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: put it back into the pool
+ self._pool.append(conn)
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is full.
+ if self._pool_full_event and len(self._pool) >= self._min_size:
+ self._pool_full_event.set()
+
+ async def _reset_connection(self, conn: AsyncConnection[Any]) -> None:
+ """
+ Bring a connection to IDLE state or close it.
+ """
+ status = conn.pgconn.transaction_status
+ if status == TransactionStatus.IDLE:
+ pass
+
+ elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+ # Connection returned with an active transaction
+ logger.warning("rolling back returned connection: %s", conn)
+ try:
+ await conn.rollback()
+ except Exception as ex:
+ logger.warning(
+ "rollback failed: %s: %s. Discarding connection %s",
+ ex.__class__.__name__,
+ ex,
+ conn,
+ )
+ await conn.close()
+
+ elif status == TransactionStatus.ACTIVE:
+ # Connection returned during an operation. Bad... just close it.
+ logger.warning("closing returned connection: %s", conn)
+ await conn.close()
+
+ if not conn.closed and self._reset:
+ try:
+ await self._reset(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ sname = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {sname} by reset function"
+ f" {self._reset}: discarded"
+ )
+ except Exception as ex:
+ logger.warning(f"error resetting connection: {ex}")
+ await conn.close()
+
+ async def _shrink_pool(self) -> None:
+ to_close: Optional[AsyncConnection[Any]] = None
+
+ async with self._lock:
+ # Reset the min number of connections used
+ nconns_min = self._nconns_min
+ self._nconns_min = len(self._pool)
+
+ # If the pool can shrink and connections were unused, drop one
+ if self._nconns > self._min_size and nconns_min > 0:
+ to_close = self._pool.popleft()
+ self._nconns -= 1
+ self._nconns_min -= 1
+
+ if to_close:
+ logger.info(
+ "shrinking pool %r to %s because %s unused connections"
+ " in the last %s sec",
+ self.name,
+ self._nconns,
+ nconns_min,
+ self.max_idle,
+ )
+ await to_close.close()
+
+ def _get_measures(self) -> Dict[str, int]:
+ rv = super()._get_measures()
+ rv[self._REQUESTS_WAITING] = len(self._waiting)
+ return rv
+
+
+class AsyncClient:
+ """A position in a queue for a client waiting for a connection."""
+
+ __slots__ = ("conn", "error", "_cond")
+
+ def __init__(self) -> None:
+ self.conn: Optional[AsyncConnection[Any]] = None
+ self.error: Optional[Exception] = None
+
+ # The AsyncClient behaves in a way similar to an Event, but we need
+ # to notify reliably the flagger that the waiter has "accepted" the
+ # message and it hasn't timed out yet, otherwise the pool may give a
+ # connection to a client that has already timed out getconn(), which
+ # will be lost.
+ self._cond = asyncio.Condition()
+
+ async def wait(self, timeout: float) -> AsyncConnection[Any]:
+ """Wait for a connection to be set and return it.
+
+ Raise an exception if the wait times out or if fail() is called.
+ """
+ async with self._cond:
+ if not (self.conn or self.error):
+ try:
+ await asyncio.wait_for(self._cond.wait(), timeout)
+ except asyncio.TimeoutError:
+ self.error = PoolTimeout(
+ f"couldn't get a connection after {timeout} sec"
+ )
+
+ if self.conn:
+ return self.conn
+ else:
+ assert self.error
+ raise self.error
+
+ async def set(self, conn: AsyncConnection[Any]) -> bool:
+ """Signal the client waiting that a connection is ready.
+
+ Return True if the client has "accepted" the connection, False
+ otherwise (typically because wait() has timed out).
+ """
+ async with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.conn = conn
+ self._cond.notify_all()
+ return True
+
+ async def fail(self, error: Exception) -> bool:
+ """Signal the client that, alas, they won't have a connection today.
+
+ Return True if the client has "accepted" the error, False otherwise
+ (typically because wait() has timed out).
+ """
+ async with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.error = error
+ self._cond.notify_all()
+ return True
+
+
+class MaintenanceTask(ABC):
+ """A task to run asynchronously to maintain the pool state."""
+
+ def __init__(self, pool: "AsyncConnectionPool"):
+ self.pool = ref(pool)
+
+ def __repr__(self) -> str:
+ pool = self.pool()
+ name = repr(pool.name) if pool else "<pool is gone>"
+ return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>"
+
+ async def run(self) -> None:
+ """Run the task.
+
+ This usually happens in a worker. Call the concrete _run()
+ implementation, if the pool is still alive.
+ """
+ pool = self.pool()
+ if not pool or pool.closed:
+ # Pool is no more working. Quietly discard the operation.
+ logger.debug("task run discarded: %s", self)
+ return
+
+ await self._run(pool)
+
+ async def tick(self) -> None:
+ """Run the scheduled task
+
+ This function is called by the scheduler task. Use a worker to
+ run the task for real in order to free the scheduler immediately.
+ """
+ pool = self.pool()
+ if not pool or pool.closed:
+ # Pool is no more working. Quietly discard the operation.
+ logger.debug("task tick discarded: %s", self)
+ return
+
+ pool.run_task(self)
+
+ @abstractmethod
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ ...
+
+
+class StopWorker(MaintenanceTask):
+ """Signal the maintenance worker to terminate."""
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ pass
+
+
+class AddConnection(MaintenanceTask):
+ def __init__(
+ self,
+ pool: "AsyncConnectionPool",
+ attempt: Optional["ConnectionAttempt"] = None,
+ growing: bool = False,
+ ):
+ super().__init__(pool)
+ self.attempt = attempt
+ self.growing = growing
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ await pool._add_connection(self.attempt, growing=self.growing)
+
+
+class ReturnConnection(MaintenanceTask):
+ """Clean up and return a connection to the pool."""
+
+ def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"):
+ super().__init__(pool)
+ self.conn = conn
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ await pool._return_connection(self.conn)
+
+
+class ShrinkPool(MaintenanceTask):
+ """If the pool can shrink, remove one connection.
+
+ Re-schedule periodically and also reset the minimum number of connections
+ in the pool.
+ """
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ # Reschedule the task now so that in case of any error we don't lose
+ # the periodic run.
+ await pool.schedule_task(self, pool.max_idle)
+ await pool._shrink_pool()
+
+
+class Schedule(MaintenanceTask):
+ """Schedule a task in the pool scheduler.
+
+ This task is a trampoline to allow to use a sync call (pool.run_task)
+ to execute an async one (pool.schedule_task).
+ """
+
+ def __init__(
+ self,
+ pool: "AsyncConnectionPool",
+ task: MaintenanceTask,
+ delay: float,
+ ):
+ super().__init__(pool)
+ self.task = task
+ self.delay = delay
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ await pool.schedule_task(self.task, self.delay)
diff --git a/psycopg_pool/psycopg_pool/py.typed b/psycopg_pool/psycopg_pool/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/py.typed
diff --git a/psycopg_pool/psycopg_pool/sched.py b/psycopg_pool/psycopg_pool/sched.py
new file mode 100644
index 0000000..ca26007
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/sched.py
@@ -0,0 +1,177 @@
+"""
+A minimal scheduler to schedule tasks run in the future.
+
+Inspired to the standard library `sched.scheduler`, but designed for
+multi-thread usage ground up, not as an afterthought. Tasks can be scheduled in
+front of the one currently running and `Scheduler.run()` can be left running
+without any task scheduled.
+
+Tasks are called "Task", not "Event", here, because we actually make use of
+`threading.Event` and the two would be confusing.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import asyncio
+import logging
+import threading
+from time import monotonic
+from heapq import heappush, heappop
+from typing import Any, Callable, List, Optional, NamedTuple
+
+logger = logging.getLogger(__name__)
+
+
+class Task(NamedTuple):
+ time: float
+ action: Optional[Callable[[], Any]]
+
+ def __eq__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time == other.time
+
+ def __lt__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time < other.time
+
+ def __le__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time <= other.time
+
+ def __gt__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time > other.time
+
+ def __ge__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time >= other.time
+
+
+class Scheduler:
+ def __init__(self) -> None:
+ """Initialize a new instance, passing the time and delay functions."""
+ self._queue: List[Task] = []
+ self._lock = threading.RLock()
+ self._event = threading.Event()
+
+ EMPTY_QUEUE_TIMEOUT = 600.0
+
+ def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue delayed in the future.
+
+ Schedule a `!None` to stop the execution.
+ """
+ time = monotonic() + delay
+ return self.enterabs(time, action)
+
+ def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue at an absolute time.
+
+ Schedule a `!None` to stop the execution.
+ """
+ task = Task(time, action)
+ with self._lock:
+ heappush(self._queue, task)
+ first = self._queue[0] is task
+
+ if first:
+ self._event.set()
+
+ return task
+
+ def run(self) -> None:
+ """Execute the events scheduled."""
+ q = self._queue
+ while True:
+ with self._lock:
+ now = monotonic()
+ task = q[0] if q else None
+ if task:
+ if task.time <= now:
+ heappop(q)
+ else:
+ delay = task.time - now
+ task = None
+ else:
+ delay = self.EMPTY_QUEUE_TIMEOUT
+ self._event.clear()
+
+ if task:
+ if not task.action:
+ break
+ try:
+ task.action()
+ except Exception as e:
+ logger.warning(
+ "scheduled task run %s failed: %s: %s",
+ task.action,
+ e.__class__.__name__,
+ e,
+ )
+ else:
+ # Block for the expected timeout or until a new task scheduled
+ self._event.wait(timeout=delay)
+
+
+class AsyncScheduler:
+ def __init__(self) -> None:
+ """Initialize a new instance, passing the time and delay functions."""
+ self._queue: List[Task] = []
+ self._lock = asyncio.Lock()
+ self._event = asyncio.Event()
+
+ EMPTY_QUEUE_TIMEOUT = 600.0
+
+ async def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue delayed in the future.
+
+ Schedule a `!None` to stop the execution.
+ """
+ time = monotonic() + delay
+ return await self.enterabs(time, action)
+
+ async def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue at an absolute time.
+
+ Schedule a `!None` to stop the execution.
+ """
+ task = Task(time, action)
+ async with self._lock:
+ heappush(self._queue, task)
+ first = self._queue[0] is task
+
+ if first:
+ self._event.set()
+
+ return task
+
+ async def run(self) -> None:
+ """Execute the events scheduled."""
+ q = self._queue
+ while True:
+ async with self._lock:
+ now = monotonic()
+ task = q[0] if q else None
+ if task:
+ if task.time <= now:
+ heappop(q)
+ else:
+ delay = task.time - now
+ task = None
+ else:
+ delay = self.EMPTY_QUEUE_TIMEOUT
+ self._event.clear()
+
+ if task:
+ if not task.action:
+ break
+ try:
+ await task.action()
+ except Exception as e:
+ logger.warning(
+ "scheduled task run %s failed: %s: %s",
+ task.action,
+ e.__class__.__name__,
+ e,
+ )
+ else:
+ # Block for the expected timeout or until a new task scheduled
+ try:
+ await asyncio.wait_for(self._event.wait(), delay)
+ except asyncio.TimeoutError:
+ pass
diff --git a/psycopg_pool/psycopg_pool/version.py b/psycopg_pool/psycopg_pool/version.py
new file mode 100644
index 0000000..fc99bbd
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/version.py
@@ -0,0 +1,13 @@
+"""
+psycopg pool version file.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+# Use a versioning scheme as defined in
+# https://www.python.org/dev/peps/pep-0440/
+
+# STOP AND READ! if you change:
+__version__ = "3.1.5"
+# also change:
+# - `docs/news_pool.rst` to declare this version current or unreleased
diff --git a/psycopg_pool/setup.cfg b/psycopg_pool/setup.cfg
new file mode 100644
index 0000000..1a3274e
--- /dev/null
+++ b/psycopg_pool/setup.cfg
@@ -0,0 +1,45 @@
+[metadata]
+name = psycopg-pool
+description = Connection Pool for Psycopg
+url = https://psycopg.org/psycopg3/
+author = Daniele Varrazzo
+author_email = daniele.varrazzo@gmail.com
+license = GNU Lesser General Public License v3 (LGPLv3)
+
+project_urls =
+ Homepage = https://psycopg.org/
+ Code = https://github.com/psycopg/psycopg
+ Issue Tracker = https://github.com/psycopg/psycopg/issues
+ Download = https://pypi.org/project/psycopg-pool/
+
+classifiers =
+ Development Status :: 5 - Production/Stable
+ Intended Audience :: Developers
+ License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Operating System :: MacOS :: MacOS X
+ Operating System :: Microsoft :: Windows
+ Operating System :: POSIX
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Topic :: Database
+ Topic :: Database :: Front-Ends
+ Topic :: Software Development
+ Topic :: Software Development :: Libraries :: Python Modules
+
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+license_files = LICENSE.txt
+
+[options]
+python_requires = >= 3.7
+packages = find:
+zip_safe = False
+install_requires =
+ typing-extensions >= 3.10
+
+[options.package_data]
+psycopg_pool = py.typed
diff --git a/psycopg_pool/setup.py b/psycopg_pool/setup.py
new file mode 100644
index 0000000..771847d
--- /dev/null
+++ b/psycopg_pool/setup.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+"""
+PostgreSQL database adapter for Python - Connection Pool
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+from setuptools import setup
+
+# Move to the directory of setup.py: executing this file from another location
+# (e.g. from the project root) will fail
+here = os.path.abspath(os.path.dirname(__file__))
+if os.path.abspath(os.getcwd()) != here:
+ os.chdir(here)
+
+with open("psycopg_pool/version.py") as f:
+ data = f.read()
+ m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data)
+ if not m:
+ raise Exception(f"cannot find version in {f.name}")
+ version = m.group(1)
+
+
+setup(version=version)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..14f3c9e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,55 @@
+[build-system]
+requires = ["setuptools>=49.2.0", "wheel>=0.37"]
+build-backend = "setuptools.build_meta"
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+filterwarnings = [
+ "error",
+]
+testpaths=[
+ "tests",
+]
+# Note: On Travis they these options seem to leak objects
+# log_format = "%(asctime)s.%(msecs)03d %(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
+# log_level = "DEBUG"
+
+[tool.coverage.run]
+source = [
+ "psycopg/psycopg",
+ "psycopg_pool/psycopg_pool",
+]
+[tool.coverage.report]
+exclude_lines = [
+ "if TYPE_CHECKING:",
+ '\.\.\.$',
+]
+
+[tool.mypy]
+files = [
+ "psycopg/psycopg",
+ "psycopg_pool/psycopg_pool",
+ "psycopg_c/psycopg_c",
+ "tests",
+]
+warn_unused_ignores = true
+show_error_codes = true
+disable_bytearray_promotion = true
+disable_memoryview_promotion = true
+strict = true
+
+[[tool.mypy.overrides]]
+module = [
+ "shapely.*",
+]
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "uvloop"
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "tests.*"
+check_untyped_defs = true
+disallow_untyped_defs = false
+disallow_untyped_calls = false
diff --git a/tests/README.rst b/tests/README.rst
new file mode 100644
index 0000000..63c7238
--- /dev/null
+++ b/tests/README.rst
@@ -0,0 +1,94 @@
+psycopg test suite
+===================
+
+Quick version
+-------------
+
+To run tests on the current code you can install the `test` extra of the
+package, specify a connection string in the `PSYCOPG_TEST_DSN` env var to
+connect to a test database, and run ``pytest``::
+
+ $ pip install -e "psycopg[test]"
+ $ export PSYCOPG_TEST_DSN="host=localhost dbname=psycopg_test"
+ $ pytest
+
+
+Test options
+------------
+
+- The tests output header shows additional psycopg related information,
+ on top of the one normally displayed by ``pytest`` and the extensions used::
+
+ $ pytest
+ ========================= test session starts =========================
+ platform linux -- Python 3.8.5, pytest-6.0.2, py-1.10.0, pluggy-0.13.1
+ Using --randomly-seed=2416596601
+ libpq available: 130002
+ libpq wrapper implementation: c
+
+
+- By default the tests run using the ``pq`` implementation that psycopg would
+ choose (the C module if installed, else the Python one). In order to test a
+ different implementation, use the normal `pq module selection mechanism`__
+ of the ``PSYCOPG_IMPL`` env var::
+
+ $ PSYCOPG_IMPL=python pytest
+ ========================= test session starts =========================
+ [...]
+ libpq available: 130002
+ libpq wrapper implementation: python
+
+ .. __: https://www.psycopg.org/psycopg/docs/api/pq.html#pq-module-implementations
+
+
+- Slow tests have a ``slow`` marker which can be selected to reduce test
+ runtime to a few seconds only. Please add a ``@pytest.mark.slow`` marker to
+ any test needing an arbitrary wait. At the time of writing::
+
+ $ pytest
+ ========================= test session starts =========================
+ [...]
+ ======= 1983 passed, 3 skipped, 110 xfailed in 78.21s (0:01:18) =======
+
+ $ pytest -m "not slow"
+ ========================= test session starts =========================
+ [...]
+ ==== 1877 passed, 2 skipped, 169 deselected, 48 xfailed in 13.47s =====
+
+- ``pytest`` option ``--pq-trace={TRACEFILE,STDERR}`` can be used to capture
+ libpq trace. When using ``stderr``, the output will only be shown for
+ failing or in-error tests, unless ``-s/--capture=no`` option is used.
+
+- ``pytest`` option ``--pq-debug`` can be used to log access to libpq's
+ ``PGconn`` functions.
+
+
+Testing in docker
+-----------------
+
+Useful to test different Python versions without installing them. Can be used
+to replicate GitHub actions failures, specifying the ``--randomly-seed`` used
+in the test run. The following ``PG*`` env vars are an example to adjust the
+test dsn in order to connect to a database running on the docker host: specify
+a set of env vars working for your setup::
+
+ $ docker run -ti --rm --volume `pwd`:/src --workdir /src \
+ -e PSYCOPG_TEST_DSN -e PGHOST=172.17.0.1 -e PGUSER=`whoami` \
+ python:3.7 bash
+
+ # pip install -e "./psycopg[test]" ./psycopg_pool ./psycopg_c
+ # pytest
+
+
+Testing with CockroachDB
+========================
+
+You can run CRDB in a docker container using::
+
+ docker run -p 26257:26257 --name crdb --rm \
+ cockroachdb/cockroach:v22.1.3 start-single-node --insecure
+
+And use the following connection string to run the tests::
+
+ export PSYCOPG_TEST_DSN="host=localhost port=26257 user=root dbname=defaultdb"
+ pytest ...
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/adapters_example.py b/tests/adapters_example.py
new file mode 100644
index 0000000..a184e6a
--- /dev/null
+++ b/tests/adapters_example.py
@@ -0,0 +1,45 @@
+from typing import Optional
+
+from psycopg import pq
+from psycopg.abc import Dumper, Loader, AdaptContext, PyFormat, Buffer
+
+
+def f() -> None:
+ d: Dumper = MyStrDumper(str, None)
+ assert d.dump("abc") == b"abcabc"
+ assert d.quote("abc") == b"'abcabc'"
+
+ lo: Loader = MyTextLoader(0, None)
+ assert lo.load(b"abc") == "abcabc"
+
+
+class MyStrDumper:
+ format = pq.Format.TEXT
+ oid = 25 # text
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ self._cls = cls
+
+ def dump(self, obj: str) -> bytes:
+ return (obj * 2).encode()
+
+ def quote(self, obj: str) -> bytes:
+ value = self.dump(obj)
+ esc = pq.Escaping()
+ return b"'%s'" % esc.escape_string(value.replace(b"h", b"q"))
+
+ def get_key(self, obj: str, format: PyFormat) -> type:
+ return self._cls
+
+ def upgrade(self, obj: str, format: PyFormat) -> "MyStrDumper":
+ return self
+
+
+class MyTextLoader:
+ format = pq.Format.TEXT
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ pass
+
+ def load(self, data: Buffer) -> str:
+ return (bytes(data) * 2).decode()
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..15bcf40
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,92 @@
+import sys
+import asyncio
+import selectors
+from typing import List
+
+pytest_plugins = (
+ "tests.fix_db",
+ "tests.fix_pq",
+ "tests.fix_mypy",
+ "tests.fix_faker",
+ "tests.fix_proxy",
+ "tests.fix_psycopg",
+ "tests.fix_crdb",
+ "tests.pool.fix_pool",
+)
+
+
+def pytest_configure(config):
+ markers = [
+ "slow: this test is kinda slow (skip with -m 'not slow')",
+ "flakey(reason): this test may fail unpredictably')",
+ # There are troubles on travis with these kind of tests and I cannot
+ # catch the exception for my life.
+ "subprocess: the test import psycopg after subprocess",
+ "timing: the test is timing based and can fail on cheese hardware",
+ "dns: the test requires dnspython to run",
+ "postgis: the test requires the PostGIS extension to run",
+ ]
+
+ for marker in markers:
+ config.addinivalue_line("markers", marker)
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--loop",
+ choices=["default", "uvloop"],
+ default="default",
+ help="The asyncio loop to use for async tests.",
+ )
+
+
+def pytest_report_header(config):
+ rv = []
+
+ rv.append(f"default selector: {selectors.DefaultSelector.__name__}")
+ loop = config.getoption("--loop")
+ if loop != "default":
+ rv.append(f"asyncio loop: {loop}")
+
+ return rv
+
+
+def pytest_sessionstart(session):
+ # Detect if there was a segfault in the previous run.
+ #
+ # In case of segfault, pytest doesn't get a chance to write failed tests
+ # in the cache. As a consequence, retries would find no test failed and
+ # assume that all tests passed in the previous run, making the whole test pass.
+ cache = session.config.cache
+ if cache.get("segfault", False):
+ session.warn(Warning("Previous run resulted in segfault! Not running any test"))
+ session.warn(Warning("(delete '.pytest_cache/v/segfault' to clear this state)"))
+ raise session.Failed
+ cache.set("segfault", True)
+
+ # Configure the async loop.
+ loop = session.config.getoption("--loop")
+ if loop == "uvloop":
+ import uvloop
+
+ uvloop.install()
+ else:
+ assert loop == "default"
+
+ if sys.platform == "win32":
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+
+allow_fail_messages: List[str] = []
+
+
+def pytest_sessionfinish(session, exitstatus):
+ # Mark the test run successful (in the sense -weak- that we didn't segfault).
+ session.config.cache.set("segfault", False)
+
+
+def pytest_terminal_summary(terminalreporter, exitstatus, config):
+ if allow_fail_messages:
+ terminalreporter.section("failed tests ignored")
+ for msg in allow_fail_messages:
+ terminalreporter.line(msg)
diff --git a/tests/constraints.txt b/tests/constraints.txt
new file mode 100644
index 0000000..ef03ba1
--- /dev/null
+++ b/tests/constraints.txt
@@ -0,0 +1,32 @@
+# This is a constraint file forcing the minimum allowed version to be
+# installed.
+#
+# https://pip.pypa.io/en/stable/user_guide/#constraints-files
+
+# From install_requires
+backports.zoneinfo == 0.2.0
+typing-extensions == 4.1.0
+
+# From the 'test' extra
+mypy == 0.981
+pproxy == 2.7.0
+pytest == 6.2.5
+pytest-asyncio == 0.17.0
+pytest-cov == 3.0.0
+pytest-randomly == 3.10.0
+
+# From the 'dev' extra
+black == 22.3.0
+dnspython == 2.1.0
+flake8 == 4.0.0
+mypy == 0.981
+types-setuptools == 57.4.0
+wheel == 0.37
+
+# From the 'docs' extra
+Sphinx == 4.2.0
+furo == 2021.11.23
+sphinx-autobuild == 2021.3.14
+sphinx-autodoc-typehints == 1.12.0
+dnspython == 2.1.0
+shapely == 1.7.0
diff --git a/tests/crdb/__init__.py b/tests/crdb/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/crdb/__init__.py
diff --git a/tests/crdb/test_adapt.py b/tests/crdb/test_adapt.py
new file mode 100644
index 0000000..ce5bacf
--- /dev/null
+++ b/tests/crdb/test_adapt.py
@@ -0,0 +1,78 @@
+from copy import deepcopy
+
+import pytest
+
+from psycopg.crdb import adapters, CrdbConnection
+
+from psycopg.adapt import PyFormat, Transformer
+from psycopg.types.array import ListDumper
+from psycopg.postgres import types as builtins
+
+from ..test_adapt import MyStr, make_dumper, make_bin_dumper
+from ..test_adapt import make_loader, make_bin_loader
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_return_untyped(conn, fmt_in):
+ # Analyze and check for changes using strings in untyped/typed contexts
+ cur = conn.cursor()
+ # Currently string are passed as text oid to CockroachDB, unlike Postgres,
+ # to which strings are passed as unknown. This is because CRDB doesn't
+ # allow the unknown oid to be emitted; execute("SELECT %s", ["str"]) raises
+ # an error. However, unlike PostgreSQL, text can be cast to any other type.
+ cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10])
+ assert cur.fetchone() == ("hello", 10)
+
+ cur.execute("create table testjson(data jsonb)")
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+ assert cur.execute("select data from testjson").fetchone() == ({},)
+
+
+def test_str_list_dumper_text(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.TEXT)
+ assert isinstance(dstr, ListDumper)
+ assert dstr.oid == builtins["text"].array_oid
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
+@pytest.fixture
+def crdb_adapters():
+ """Restore the crdb adapters after a test has changed them."""
+ dumpers = deepcopy(adapters._dumpers)
+ dumpers_by_oid = deepcopy(adapters._dumpers_by_oid)
+ loaders = deepcopy(adapters._loaders)
+ types = list(adapters.types)
+
+ yield None
+
+ adapters._dumpers = dumpers
+ adapters._dumpers_by_oid = dumpers_by_oid
+ adapters._loaders = loaders
+ adapters.types.clear()
+ for t in types:
+ adapters.types.add(t)
+
+
+def test_dump_global_ctx(dsn, crdb_adapters, pgconn):
+ adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+ adapters.register_dumper(MyStr, make_dumper("gt"))
+ with CrdbConnection.connect(dsn) as conn:
+ cur = conn.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogb",)
+ cur = conn.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+
+
+def test_load_global_ctx(dsn, crdb_adapters):
+ adapters.register_loader("text", make_loader("gt"))
+ adapters.register_loader("text", make_bin_loader("gb"))
+ with CrdbConnection.connect(dsn) as conn:
+ cur = conn.cursor(binary=False).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.cursor(binary=True).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogb",)
diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py
new file mode 100644
index 0000000..b2a69ef
--- /dev/null
+++ b/tests/crdb/test_connection.py
@@ -0,0 +1,86 @@
+import time
+import threading
+
+import psycopg.crdb
+from psycopg import errors as e
+from psycopg.crdb import CrdbConnection
+
+import pytest
+
+pytestmark = pytest.mark.crdb
+
+
+def test_is_crdb(conn):
+ assert CrdbConnection.is_crdb(conn)
+ assert CrdbConnection.is_crdb(conn.pgconn)
+
+
+def test_connect(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ assert isinstance(conn, CrdbConnection)
+
+ with psycopg.crdb.connect(dsn) as conn:
+ assert isinstance(conn, CrdbConnection)
+
+
+def test_xid(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.xid(1, "gtrid", "bqual")
+
+
+def test_tpc_begin(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.tpc_begin("foo")
+
+
+def test_tpc_recover(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.tpc_recover()
+
+
+@pytest.mark.slow
+def test_broken_connection(conn):
+ cur = conn.cursor()
+ (session_id,) = cur.execute("select session_id from [show session_id]").fetchone()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("cancel session %s", [session_id])
+ assert conn.closed
+
+
+@pytest.mark.slow
+def test_broken(conn):
+ (session_id,) = conn.execute("show session_id").fetchone()
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("cancel session %s", [session_id])
+
+ assert conn.closed
+ assert conn.broken
+ conn.close()
+ assert conn.closed
+ assert conn.broken
+
+
+@pytest.mark.slow
+def test_identify_closure(conn_cls, dsn):
+ with conn_cls.connect(dsn, autocommit=True) as conn:
+ with conn_cls.connect(dsn, autocommit=True) as conn2:
+ (session_id,) = conn.execute("show session_id").fetchone()
+
+ def closer():
+ time.sleep(0.2)
+ conn2.execute("cancel session %s", [session_id])
+
+ t = threading.Thread(target=closer)
+ t.start()
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("select pg_sleep(3.0)")
+ dt = time.time() - t0
+ # CRDB seems to take not less than 1s
+ assert 0.2 < dt < 2
+ finally:
+ t.join()
diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py
new file mode 100644
index 0000000..b568e42
--- /dev/null
+++ b/tests/crdb/test_connection_async.py
@@ -0,0 +1,85 @@
+import time
+import asyncio
+
+import psycopg.crdb
+from psycopg import errors as e
+from psycopg.crdb import AsyncCrdbConnection
+from psycopg._compat import create_task
+
+import pytest
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+async def test_is_crdb(aconn):
+ assert AsyncCrdbConnection.is_crdb(aconn)
+ assert AsyncCrdbConnection.is_crdb(aconn.pgconn)
+
+
+async def test_connect(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection)
+
+
+async def test_xid(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.xid(1, "gtrid", "bqual")
+
+
+async def test_tpc_begin(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ await conn.tpc_begin("foo")
+
+
+async def test_tpc_recover(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ await conn.tpc_recover()
+
+
+@pytest.mark.slow
+async def test_broken_connection(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select session_id from [show session_id]")
+ (session_id,) = await cur.fetchone()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("cancel session %s", [session_id])
+ assert aconn.closed
+
+
+@pytest.mark.slow
+async def test_broken(aconn):
+ cur = await aconn.execute("show session_id")
+ (session_id,) = await cur.fetchone()
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.execute("cancel session %s", [session_id])
+
+ assert aconn.closed
+ assert aconn.broken
+ await aconn.close()
+ assert aconn.closed
+ assert aconn.broken
+
+
+@pytest.mark.slow
+async def test_identify_closure(aconn_cls, dsn):
+ async with await aconn_cls.connect(dsn) as conn:
+ async with await aconn_cls.connect(dsn) as conn2:
+ cur = await conn.execute("show session_id")
+ (session_id,) = await cur.fetchone()
+
+ async def closer():
+ await asyncio.sleep(0.2)
+ await conn2.execute("cancel session %s", [session_id])
+
+ t = create_task(closer())
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ await conn.execute("select pg_sleep(3.0)")
+ dt = time.time() - t0
+ assert 0.2 < dt < 2
+ finally:
+ await asyncio.gather(t)
diff --git a/tests/crdb/test_conninfo.py b/tests/crdb/test_conninfo.py
new file mode 100644
index 0000000..274a0c0
--- /dev/null
+++ b/tests/crdb/test_conninfo.py
@@ -0,0 +1,21 @@
+import pytest
+
+pytestmark = pytest.mark.crdb
+
+
+def test_vendor(conn):
+ assert conn.info.vendor == "CockroachDB"
+
+
+def test_server_version(conn):
+ assert conn.info.server_version > 200000
+
+
+@pytest.mark.crdb("< 22")
+def test_backend_pid_pre_22(conn):
+ assert conn.info.backend_pid == 0
+
+
+@pytest.mark.crdb(">= 22")
+def test_backend_pid(conn):
+ assert conn.info.backend_pid > 0
diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py
new file mode 100644
index 0000000..b7d26aa
--- /dev/null
+++ b/tests/crdb/test_copy.py
@@ -0,0 +1,233 @@
+import pytest
+import string
+from random import randrange, choice
+
+from psycopg import sql, errors as e
+from psycopg.pq import Format
+from psycopg.adapt import PyFormat
+from psycopg.types.numeric import Int4
+
+from ..utils import eur, gc_collect, gc_count
+from ..test_copy import sample_text, sample_binary # noqa
+from ..test_copy import ensure_table, sample_records
+from ..test_copy import sample_tabledef as sample_tabledef_pg
+
+# CRDB int/serial are int8
+sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4")
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_copy_in_buffers(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_buffers_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_str(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text.decode())
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
+def test_copy_in_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ with cur.copy("copy copy_in from stdin with binary") as copy:
+ copy.write(sample_text.decode())
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+def test_copy_big_size_record(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "id serial primary key, data text")
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write_row([data])
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.slow
+def test_copy_big_size_block(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "id serial primary key, data text")
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n"
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write(copy_data)
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+def test_copy_in_buffers_with_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_set_types(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_binary(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+
+ with cur.copy(f"copy copy_in (col2, data) from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ copy.write_row((None, row[2]))
+
+ data = cur.execute("select col2, data from copy_in order by 2").fetchall()
+ assert data == [(None, "hello"), (None, "world")]
+
+
+@pytest.mark.crdb_skip("copy canceled")
+def test_copy_in_buffers_with_py_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_allchars(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 int primary key, col2 int, data text")
+
+ with cur.copy("copy copy_in from stdin") as copy:
+ for i in range(1, 256):
+ copy.write_row((i, None, chr(i)))
+ copy.write_row((ord(eur), None, eur))
+
+ data = cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ ).fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.crdb_skip("copy array")
+def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin {}").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL("with binary" if fmt else ""),
+ )
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ copy.write_row(row)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+def copyopt(format):
+ return "with binary" if format == Format.BINARY else ""
diff --git a/tests/crdb/test_copy_async.py b/tests/crdb/test_copy_async.py
new file mode 100644
index 0000000..d5fbf50
--- /dev/null
+++ b/tests/crdb/test_copy_async.py
@@ -0,0 +1,235 @@
+import pytest
+import string
+from random import randrange, choice
+
+from psycopg.pq import Format
+from psycopg import sql, errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types.numeric import Int4
+
+from ..utils import eur, gc_collect, gc_count
+from ..test_copy import sample_text, sample_binary # noqa
+from ..test_copy import sample_records
+from ..test_copy_async import ensure_table
+from .test_copy import sample_tabledef, copyopt
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_in_buffers(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ await copy.write(globals()[buffer])
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_copy_in_buffers_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_str(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text.decode())
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
+async def test_copy_in_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ async with cur.copy("copy copy_in from stdin with binary") as copy:
+ await copy.write(sample_text.decode())
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_empty(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
+ pass
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+async def test_copy_big_size_record(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "id serial primary key, data text")
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write_row([data])
+
+ await cur.execute("select data from copy_in limit 1")
+ assert (await cur.fetchone())[0] == data
+
+
+@pytest.mark.slow
+async def test_copy_big_size_block(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "id serial primary key, data text")
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n"
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write(copy_data)
+
+ await cur.execute("select data from copy_in limit 1")
+ assert (await cur.fetchone())[0] == data
+
+
+async def test_copy_in_buffers_with_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_set_types(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_binary(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+
+ async with cur.copy(
+ f"copy copy_in (col2, data) from stdin {copyopt(format)}"
+ ) as copy:
+ for row in sample_records:
+ await copy.write_row((None, row[2]))
+
+ await cur.execute("select col2, data from copy_in order by 2")
+ data = await cur.fetchall()
+ assert data == [(None, "hello"), (None, "world")]
+
+
+@pytest.mark.crdb_skip("copy canceled")
+async def test_copy_in_buffers_with_py_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_allchars(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 int primary key, col2 int, data text")
+
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for i in range(1, 256):
+ await copy.write_row((i, None, chr(i)))
+ await copy.write_row((ord(eur), None, eur))
+
+ await cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ )
+ data = await cur.fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.crdb_skip("copy array")
+async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn:
+ async with conn.cursor(binary=fmt) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin {}").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL("with binary" if fmt else ""),
+ )
+ async with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ await copy.write_row(row)
+
+ await cur.execute(faker.select_stmt)
+ recs = await cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
diff --git a/tests/crdb/test_cursor.py b/tests/crdb/test_cursor.py
new file mode 100644
index 0000000..991b084
--- /dev/null
+++ b/tests/crdb/test_cursor.py
@@ -0,0 +1,65 @@
+import json
+import threading
+from uuid import uuid4
+from queue import Queue
+from typing import Any
+
+import pytest
+from psycopg import pq, errors as e
+from psycopg.rows import namedtuple_row
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.fixture
+def testfeed(svcconn):
+ name = f"test_feed_{str(uuid4()).replace('-', '')}"
+ svcconn.execute("set cluster setting kv.rangefeed.enabled to true")
+ svcconn.execute(f"create table {name} (id serial primary key, data text)")
+ yield name
+ svcconn.execute(f"drop table {name}")
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out):
+ conn.autocommit = True
+ q: "Queue[Any]" = Queue()
+
+ def worker():
+ try:
+ with conn_cls.connect(dsn, autocommit=True) as conn:
+ cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row)
+ try:
+ for row in cur.stream(f"experimental changefeed for {testfeed}"):
+ q.put(row)
+ except e.QueryCanceled:
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ q.put(None)
+ except Exception as ex:
+ q.put(ex)
+
+ t = threading.Thread(target=worker)
+ t.start()
+
+ cur = conn.cursor()
+ cur.execute(f"insert into {testfeed} (data) values ('hello') returning id")
+ (key,) = cur.fetchone()
+ row = q.get(timeout=1)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] == {"id": key, "data": "hello"}
+
+ cur.execute(f"delete from {testfeed} where id = %s", [key])
+ row = q.get(timeout=1)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] is None
+
+ cur.execute("select query_id from [show statements] where query !~ 'show'")
+ (qid,) = cur.fetchone()
+ cur.execute("cancel query %s", [qid])
+ assert cur.statusmessage == "CANCEL QUERIES 1"
+
+ assert q.get(timeout=1) is None
+ t.join()
diff --git a/tests/crdb/test_cursor_async.py b/tests/crdb/test_cursor_async.py
new file mode 100644
index 0000000..229295d
--- /dev/null
+++ b/tests/crdb/test_cursor_async.py
@@ -0,0 +1,61 @@
+import json
+import asyncio
+from typing import Any
+from asyncio.queues import Queue
+
+import pytest
+from psycopg import pq, errors as e
+from psycopg.rows import namedtuple_row
+from psycopg._compat import create_task
+
+from .test_cursor import testfeed
+
+testfeed # fixture
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", pq.Format)
+async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out):
+ await aconn.set_autocommit(True)
+ q: "Queue[Any]" = Queue()
+
+ async def worker():
+ try:
+ async with await aconn_cls.connect(dsn, autocommit=True) as conn:
+ cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row)
+ try:
+ async for row in cur.stream(
+ f"experimental changefeed for {testfeed}"
+ ):
+ q.put_nowait(row)
+ except e.QueryCanceled:
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ q.put_nowait(None)
+ except Exception as ex:
+ q.put_nowait(ex)
+
+ t = create_task(worker())
+
+ cur = aconn.cursor()
+ await cur.execute(f"insert into {testfeed} (data) values ('hello') returning id")
+ (key,) = await cur.fetchone()
+ row = await asyncio.wait_for(q.get(), 1.0)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] == {"id": key, "data": "hello"}
+
+ await cur.execute(f"delete from {testfeed} where id = %s", [key])
+ row = await asyncio.wait_for(q.get(), 1.0)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] is None
+
+ await cur.execute("select query_id from [show statements] where query !~ 'show'")
+ (qid,) = await cur.fetchone()
+ await cur.execute("cancel query %s", [qid])
+ assert cur.statusmessage == "CANCEL QUERIES 1"
+
+ assert await asyncio.wait_for(q.get(), 1.0) is None
+ await asyncio.gather(t)
diff --git a/tests/crdb/test_no_crdb.py b/tests/crdb/test_no_crdb.py
new file mode 100644
index 0000000..df43f3b
--- /dev/null
+++ b/tests/crdb/test_no_crdb.py
@@ -0,0 +1,34 @@
+from psycopg.pq import TransactionStatus
+from psycopg.crdb import CrdbConnection
+
+import pytest
+
+pytestmark = pytest.mark.crdb("skip")
+
+
+def test_is_crdb(conn):
+ assert not CrdbConnection.is_crdb(conn)
+ assert not CrdbConnection.is_crdb(conn.pgconn)
+
+
+def test_tpc_on_pg_connection(conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
diff --git a/tests/crdb/test_typing.py b/tests/crdb/test_typing.py
new file mode 100644
index 0000000..2cff0a7
--- /dev/null
+++ b/tests/crdb/test_typing.py
@@ -0,0 +1,49 @@
+import pytest
+
+from ..test_typing import _test_reveal
+
+
+@pytest.mark.parametrize(
+ "conn, type",
+ [
+ (
+ "psycopg.crdb.connect()",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect()",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.crdb.AsyncCrdbConnection.connect()",
+ "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_connection_type(conn, type, mypy):
+ stmts = f"obj = {conn}"
+ _test_reveal_crdb(stmts, type, mypy)
+
+
+def _test_reveal_crdb(stmts, type, mypy):
+ stmts = f"""\
+import psycopg.crdb
+{stmts}
+"""
+ _test_reveal(stmts, type, mypy)
diff --git a/tests/dbapi20.py b/tests/dbapi20.py
new file mode 100644
index 0000000..c873a4e
--- /dev/null
+++ b/tests/dbapi20.py
@@ -0,0 +1,870 @@
+#!/usr/bin/env python
+# flake8: noqa
+# fmt: off
+''' Python DB API 2.0 driver compliance unit test suite.
+
+ This software is Public Domain and may be used without restrictions.
+
+ "Now we have booze and barflies entering the discussion, plus rumours of
+ DBAs on drugs... and I won't tell you what flashes through my mind each
+ time I read the subject line with 'Anal Compliance' in it. All around
+ this is turning out to be a thoroughly unwholesome unit test."
+
+ -- Ian Bicking
+'''
+
+__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $'
+__version__ = '$Revision: 1.12 $'[11:-2]
+__author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
+
+import unittest
+import time
+import sys
+from typing import Any, Dict
+
+
+# Revision 1.12 2009/02/06 03:35:11 kf7xm
+# Tested okay with Python 3.0, includes last minute patches from Mark H.
+#
+# Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole
+# Include latest changes from main branch
+# Updates for py3k
+#
+# Revision 1.11 2005/01/02 02:41:01 zenzen
+# Update author email address
+#
+# Revision 1.10 2003/10/09 03:14:14 zenzen
+# Add test for DB API 2.0 optional extension, where database exceptions
+# are exposed as attributes on the Connection object.
+#
+# Revision 1.9 2003/08/13 01:16:36 zenzen
+# Minor tweak from Stefan Fleiter
+#
+# Revision 1.8 2003/04/10 00:13:25 zenzen
+# Changes, as per suggestions by M.-A. Lemburg
+# - Add a table prefix, to ensure namespace collisions can always be avoided
+#
+# Revision 1.7 2003/02/26 23:33:37 zenzen
+# Break out DDL into helper functions, as per request by David Rushby
+#
+# Revision 1.6 2003/02/21 03:04:33 zenzen
+# Stuff from Henrik Ekelund:
+# added test_None
+# added test_nextset & hooks
+#
+# Revision 1.5 2003/02/17 22:08:43 zenzen
+# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
+# defaults to 1 & generic cursor.callproc test added
+#
+# Revision 1.4 2003/02/15 00:16:33 zenzen
+# Changes, as per suggestions and bug reports by M.-A. Lemburg,
+# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
+# - Class renamed
+# - Now a subclass of TestCase, to avoid requiring the driver stub
+# to use multiple inheritance
+# - Reversed the polarity of buggy test in test_description
+# - Test exception hierarchy correctly
+# - self.populate is now self._populate(), so if a driver stub
+# overrides self.ddl1 this change propagates
+# - VARCHAR columns now have a width, which will hopefully make the
+# DDL even more portible (this will be reversed if it causes more problems)
+# - cursor.rowcount being checked after various execute and fetchXXX methods
+# - Check for fetchall and fetchmany returning empty lists after results
+# are exhausted (already checking for empty lists if select retrieved
+# nothing
+# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
+#
+
+class DatabaseAPI20Test(unittest.TestCase):
+ ''' Test a database self.driver for DB API 2.0 compatibility.
+ This implementation tests Gadfly, but the TestCase
+ is structured so that other self.drivers can subclass this
+ test case to ensure compiliance with the DB-API. It is
+ expected that this TestCase may be expanded in the future
+ if ambiguities or edge conditions are discovered.
+
+ The 'Optional Extensions' are not yet being tested.
+
+ self.drivers should subclass this test, overriding setUp, tearDown,
+ self.driver, connect_args and connect_kw_args. Class specification
+ should be as follows:
+
+ from . import dbapi20
+ class mytest(dbapi20.DatabaseAPI20Test):
+ [...]
+
+ Don't 'from .dbapi20 import DatabaseAPI20Test', or you will
+ confuse the unit tester - just 'from . import dbapi20'.
+ '''
+
+ # The self.driver module. This should be the module where the 'connect'
+ # method is to be found
+ driver: Any = None
+ connect_args = () # List of arguments to pass to connect
+ connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect
+ table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
+
+ ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
+ ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
+ xddl1 = 'drop table %sbooze' % table_prefix
+ xddl2 = 'drop table %sbarflys' % table_prefix
+
+ lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
+
+ # Some drivers may need to override these helpers, for example adding
+ # a 'commit' after the execute.
+ def executeDDL1(self,cursor):
+ cursor.execute(self.ddl1)
+
+ def executeDDL2(self,cursor):
+ cursor.execute(self.ddl2)
+
+ def setUp(self):
+ ''' self.drivers should override this method to perform required setup
+ if any is necessary, such as creating the database.
+ '''
+ pass
+
+ def tearDown(self):
+ ''' self.drivers should override this method to perform required cleanup
+ if any is necessary, such as deleting the test database.
+ The default drops the tables that may be created.
+ '''
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ for ddl in (self.xddl1,self.xddl2):
+ try:
+ cur.execute(ddl)
+ con.commit()
+ except self.driver.Error:
+ # Assume table didn't exist. Other tests will check if
+ # execute is busted.
+ pass
+ finally:
+ con.close()
+
+ def _connect(self):
+ try:
+ return self.driver.connect(
+ *self.connect_args,**self.connect_kw_args
+ )
+ except AttributeError:
+ self.fail("No connect method found in self.driver module")
+
+ def test_connect(self):
+ con = self._connect()
+ con.close()
+
+ def test_apilevel(self):
+ try:
+ # Must exist
+ apilevel = self.driver.apilevel
+ # Must equal 2.0
+ self.assertEqual(apilevel,'2.0')
+ except AttributeError:
+ self.fail("Driver doesn't define apilevel")
+
+ def test_threadsafety(self):
+ try:
+ # Must exist
+ threadsafety = self.driver.threadsafety
+ # Must be a valid value
+ self.failUnless(threadsafety in (0,1,2,3))
+ except AttributeError:
+ self.fail("Driver doesn't define threadsafety")
+
+ def test_paramstyle(self):
+ try:
+ # Must exist
+ paramstyle = self.driver.paramstyle
+ # Must be a valid value
+ self.failUnless(paramstyle in (
+ 'qmark','numeric','named','format','pyformat'
+ ))
+ except AttributeError:
+ self.fail("Driver doesn't define paramstyle")
+
+ def test_Exceptions(self):
+ # Make sure required exceptions exist, and are in the
+ # defined hierarchy.
+ if sys.version[0] == '3': #under Python 3 StardardError no longer exists
+ self.failUnless(issubclass(self.driver.Warning,Exception))
+ self.failUnless(issubclass(self.driver.Error,Exception))
+ else:
+ self.failUnless(issubclass(self.driver.Warning,StandardError)) # type: ignore[name-defined]
+ self.failUnless(issubclass(self.driver.Error,StandardError)) # type: ignore[name-defined]
+
+ self.failUnless(
+ issubclass(self.driver.InterfaceError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.DatabaseError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.OperationalError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.IntegrityError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.InternalError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.ProgrammingError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.NotSupportedError,self.driver.Error)
+ )
+
+ def test_ExceptionsAsConnectionAttributes(self):
+ # OPTIONAL EXTENSION
+ # Test for the optional DB API 2.0 extension, where the exceptions
+ # are exposed as attributes on the Connection object
+ # I figure this optional extension will be implemented by any
+ # driver author who is using this test suite, so it is enabled
+ # by default.
+ con = self._connect()
+ drv = self.driver
+ self.failUnless(con.Warning is drv.Warning)
+ self.failUnless(con.Error is drv.Error)
+ self.failUnless(con.InterfaceError is drv.InterfaceError)
+ self.failUnless(con.DatabaseError is drv.DatabaseError)
+ self.failUnless(con.OperationalError is drv.OperationalError)
+ self.failUnless(con.IntegrityError is drv.IntegrityError)
+ self.failUnless(con.InternalError is drv.InternalError)
+ self.failUnless(con.ProgrammingError is drv.ProgrammingError)
+ self.failUnless(con.NotSupportedError is drv.NotSupportedError)
+ con.close()
+
+
+ def test_commit(self):
+ con = self._connect()
+ try:
+ # Commit must work, even if it doesn't do anything
+ con.commit()
+ finally:
+ con.close()
+
+ def test_rollback(self):
+ con = self._connect()
+ # If rollback is defined, it should either work or throw
+ # the documented exception
+ if hasattr(con,'rollback'):
+ try:
+ con.rollback()
+ except self.driver.NotSupportedError:
+ pass
+ con.close()
+
+ def test_cursor(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ finally:
+ con.close()
+
+ def test_cursor_isolation(self):
+ con = self._connect()
+ try:
+ # Make sure cursors created from the same connection have
+ # the documented transaction isolation level
+ cur1 = con.cursor()
+ cur2 = con.cursor()
+ self.executeDDL1(cur1)
+ cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ cur2.execute("select name from %sbooze" % self.table_prefix)
+ booze = cur2.fetchall()
+ self.assertEqual(len(booze),1)
+ self.assertEqual(len(booze[0]),1)
+ self.assertEqual(booze[0][0],'Victoria Bitter')
+ finally:
+ con.close()
+
+ def test_description(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ self.assertEqual(cur.description,None,
+ 'cursor.description should be none after executing a '
+ 'statement that can return no rows (such as DDL)'
+ )
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(len(cur.description),1,
+ 'cursor.description describes too many columns'
+ )
+ self.assertEqual(len(cur.description[0]),7,
+ 'cursor.description[x] tuples must have 7 elements'
+ )
+ self.assertEqual(cur.description[0][0].lower(),'name',
+ 'cursor.description[x][0] must return column name'
+ )
+ self.assertEqual(cur.description[0][1],self.driver.STRING,
+ 'cursor.description[x][1] must return column type. Got %r'
+ % cur.description[0][1]
+ )
+
+ # Make sure self.description gets reset
+ self.executeDDL2(cur)
+ self.assertEqual(cur.description,None,
+ 'cursor.description not being set to None when executing '
+ 'no-result statements (eg. DDL)'
+ )
+ finally:
+ con.close()
+
+ def test_rowcount(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ self.assertEqual(cur.rowcount,-1,
+ 'cursor.rowcount should be -1 after executing no-result '
+ 'statements'
+ )
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.failUnless(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number or rows inserted, or '
+ 'set to -1 after executing an insert statement'
+ )
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.failUnless(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number of rows returned, or '
+ 'set to -1 after executing a select statement'
+ )
+ self.executeDDL2(cur)
+ self.assertEqual(cur.rowcount,-1,
+ 'cursor.rowcount not being reset to -1 after executing '
+ 'no-result statements'
+ )
+ finally:
+ con.close()
+
+ lower_func = 'lower'
+ def test_callproc(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if self.lower_func and hasattr(cur,'callproc'):
+ r = cur.callproc(self.lower_func,('FOO',))
+ self.assertEqual(len(r),1)
+ self.assertEqual(r[0],'FOO')
+ r = cur.fetchall()
+ self.assertEqual(len(r),1,'callproc produced no result set')
+ self.assertEqual(len(r[0]),1,
+ 'callproc produced invalid result set'
+ )
+ self.assertEqual(r[0][0],'foo',
+ 'callproc produced invalid results'
+ )
+ finally:
+ con.close()
+
+ def test_close(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ finally:
+ con.close()
+
+ # cursor.execute should raise an Error if called after connection
+ # closed
+ self.assertRaises(self.driver.Error,self.executeDDL1,cur)
+
+ # connection.commit should raise an Error if called after connection'
+ # closed.'
+ self.assertRaises(self.driver.Error,con.commit)
+
+ # connection.close should raise an Error if called more than once
+ # Issue discussed on DB-SIG: consensus seem that close() should not
+ # raised if called on closed objects. Issue reported back to Stuart.
+ # self.assertRaises(self.driver.Error,con.close)
+
+ def test_execute(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self._paraminsert(cur)
+ finally:
+ con.close()
+
+ def _paraminsert(self,cur):
+ self.executeDDL1(cur)
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.failUnless(cur.rowcount in (-1,1))
+
+ if self.driver.paramstyle == 'qmark':
+ cur.execute(
+ 'insert into %sbooze values (?)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'numeric':
+ cur.execute(
+ 'insert into %sbooze values (:1)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'named':
+ cur.execute(
+ 'insert into %sbooze values (:beer)' % self.table_prefix,
+ {'beer':"Cooper's"}
+ )
+ elif self.driver.paramstyle == 'format':
+ cur.execute(
+ 'insert into %sbooze values (%%s)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'pyformat':
+ cur.execute(
+ 'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
+ {'beer':"Cooper's"}
+ )
+ else:
+ self.fail('Invalid paramstyle')
+ self.failUnless(cur.rowcount in (-1,1))
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ res = cur.fetchall()
+ self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
+ beers = [res[0][0],res[1][0]]
+ beers.sort()
+ self.assertEqual(beers[0],"Cooper's",
+ 'cursor.fetchall retrieved incorrect data, or data inserted '
+ 'incorrectly'
+ )
+ self.assertEqual(beers[1],"Victoria Bitter",
+ 'cursor.fetchall retrieved incorrect data, or data inserted '
+ 'incorrectly'
+ )
+
+ def test_executemany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ largs = [ ("Cooper's",) , ("Boag's",) ]
+ margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
+ if self.driver.paramstyle == 'qmark':
+ cur.executemany(
+ 'insert into %sbooze values (?)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'numeric':
+ cur.executemany(
+ 'insert into %sbooze values (:1)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'named':
+ cur.executemany(
+ 'insert into %sbooze values (:beer)' % self.table_prefix,
+ margs
+ )
+ elif self.driver.paramstyle == 'format':
+ cur.executemany(
+ 'insert into %sbooze values (%%s)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'pyformat':
+ cur.executemany(
+ 'insert into %sbooze values (%%(beer)s)' % (
+ self.table_prefix
+ ),
+ margs
+ )
+ else:
+ self.fail('Unknown paramstyle')
+ self.failUnless(cur.rowcount in (-1,2),
+ 'insert using cursor.executemany set cursor.rowcount to '
+ 'incorrect value %r' % cur.rowcount
+ )
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ res = cur.fetchall()
+ self.assertEqual(len(res),2,
+ 'cursor.fetchall retrieved incorrect number of rows'
+ )
+ beers = [res[0][0],res[1][0]]
+ beers.sort()
+ self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
+ self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
+ finally:
+ con.close()
+
+ def test_fetchone(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchone should raise an Error if called before
+ # executing a select-type query
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ self.executeDDL1(cur)
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if a query retrieves '
+ 'no rows'
+ )
+ self.failUnless(cur.rowcount in (-1,0))
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchone()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchone should have retrieved a single row'
+ )
+ self.assertEqual(r[0],'Victoria Bitter',
+ 'cursor.fetchone retrieved incorrect data'
+ )
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if no more rows available'
+ )
+ self.failUnless(cur.rowcount in (-1,1))
+ finally:
+ con.close()
+
+ samples = [
+ 'Carlton Cold',
+ 'Carlton Draft',
+ 'Mountain Goat',
+ 'Redback',
+ 'Victoria Bitter',
+ 'XXXX'
+ ]
+
+ def _populate(self):
+ ''' Return a list of sql commands to setup the DB for the fetch
+ tests.
+ '''
+ populate = [
+ "insert into %sbooze values ('%s')" % (self.table_prefix,s)
+ for s in self.samples
+ ]
+ return populate
+
+ def test_fetchmany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchmany should raise an Error if called without
+ #issuing a query
+ self.assertRaises(self.driver.Error,cur.fetchmany,4)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchmany retrieved incorrect number of rows, '
+ 'default of arraysize is one.'
+ )
+ cur.arraysize=10
+ r = cur.fetchmany(3) # Should get 3 rows
+ self.assertEqual(len(r),3,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should get 2 more
+ self.assertEqual(len(r),2,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should be an empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence after '
+ 'results are exhausted'
+ )
+ self.failUnless(cur.rowcount in (-1,6))
+
+ # Same as above, using cursor.arraysize
+ cur.arraysize=4
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany() # Should get 4 rows
+ self.assertEqual(len(r),4,
+ 'cursor.arraysize not being honoured by fetchmany'
+ )
+ r = cur.fetchmany() # Should get 2 more
+ self.assertEqual(len(r),2)
+ r = cur.fetchmany() # Should be an empty sequence
+ self.assertEqual(len(r),0)
+ self.failUnless(cur.rowcount in (-1,6))
+
+ cur.arraysize=6
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchmany() # Should get all rows
+ self.failUnless(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows),6)
+ self.assertEqual(len(rows),6)
+ rows = [r[0] for r in rows]
+ rows.sort()
+
+ # Make sure we get the right data back out
+ for i in range(0,6):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved by cursor.fetchmany'
+ )
+
+ rows = cur.fetchmany() # Should return an empty list
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'called after the whole result set has been fetched'
+ )
+ self.failUnless(cur.rowcount in (-1,6))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ r = cur.fetchmany() # Should get empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'query retrieved no rows'
+ )
+ self.failUnless(cur.rowcount in (-1,0))
+
+ finally:
+ con.close()
+
+ def test_fetchall(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ # cursor.fetchall should raise an Error if called
+ # without executing a query that may return rows (such
+ # as a select)
+ self.assertRaises(self.driver.Error, cur.fetchall)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ # cursor.fetchall should raise an Error if called
+ # after executing a a statement that cannot return rows
+ self.assertRaises(self.driver.Error,cur.fetchall)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchall()
+ self.failUnless(cur.rowcount in (-1,len(self.samples)))
+ self.assertEqual(len(rows),len(self.samples),
+ 'cursor.fetchall did not retrieve all rows'
+ )
+ rows = [r[0] for r in rows]
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'cursor.fetchall retrieved incorrect rows'
+ )
+ rows = cur.fetchall()
+ self.assertEqual(
+ len(rows),0,
+ 'cursor.fetchall should return an empty list if called '
+ 'after the whole result set has been fetched'
+ )
+ self.failUnless(cur.rowcount in (-1,len(self.samples)))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ rows = cur.fetchall()
+ self.failUnless(cur.rowcount in (-1,0))
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchall should return an empty list if '
+ 'a select query returns no rows'
+ )
+
+ finally:
+ con.close()
+
+ def test_mixedfetch(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows1 = cur.fetchone()
+ rows23 = cur.fetchmany(2)
+ rows4 = cur.fetchone()
+ rows56 = cur.fetchall()
+ self.failUnless(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows23),2,
+ 'fetchmany returned incorrect number of rows'
+ )
+ self.assertEqual(len(rows56),2,
+ 'fetchall returned incorrect number of rows'
+ )
+
+ rows = [rows1[0]]
+ rows.extend([rows23[0][0],rows23[1][0]])
+ rows.append(rows4[0])
+ rows.extend([rows56[0][0],rows56[1][0]])
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved or inserted'
+ )
+ finally:
+ con.close()
+
+ def help_nextset_setUp(self,cur):
+ ''' Should create a procedure called deleteme
+ that returns two result sets, first the
+ number of rows in booze then "name from booze"
+ '''
+ raise NotImplementedError('Helper not implemented')
+ #sql="""
+ # create procedure deleteme as
+ # begin
+ # select count(*) from booze
+ # select name from booze
+ # end
+ #"""
+ #cur.execute(sql)
+
+ def help_nextset_tearDown(self,cur):
+ 'If cleaning up is needed after nextSetTest'
+ raise NotImplementedError('Helper not implemented')
+ #cur.execute("drop procedure deleteme")
+
+ def test_nextset(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if not hasattr(cur,'nextset'):
+ return
+
+ try:
+ self.executeDDL1(cur)
+ sql=self._populate()
+ for sql in self._populate():
+ cur.execute(sql)
+
+ self.help_nextset_setUp(cur)
+
+ cur.callproc('deleteme')
+ numberofrows=cur.fetchone()
+ assert numberofrows[0]== len(self.samples)
+ assert cur.nextset()
+ names=cur.fetchall()
+ assert len(names) == len(self.samples)
+ s=cur.nextset()
+ assert s is None, 'No more return sets, should return None'
+ finally:
+ self.help_nextset_tearDown(cur)
+
+ finally:
+ con.close()
+
+ def test_arraysize(self):
+ # Not much here - rest of the tests for this are in test_fetchmany
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.failUnless(hasattr(cur,'arraysize'),
+ 'cursor.arraysize must be defined'
+ )
+ finally:
+ con.close()
+
+ def test_setinputsizes(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setinputsizes( (25,) )
+ self._paraminsert(cur) # Make sure cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize_basic(self):
+ # Basic test is to make sure setoutputsize doesn't blow up
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setoutputsize(1000)
+ cur.setoutputsize(2000,0)
+ self._paraminsert(cur) # Make sure the cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize(self):
+ # Real test for setoutputsize is driver dependent
+ raise NotImplementedError('Driver needed to override this test')
+
+ def test_None(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchall()
+ self.assertEqual(len(r),1)
+ self.assertEqual(len(r[0]),1)
+ self.assertEqual(r[0][0],None,'NULL value not returned as None')
+ finally:
+ con.close()
+
+ def test_Date(self):
+ d1 = self.driver.Date(2002,12,25)
+ d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(d1),str(d2))
+
+ def test_Time(self):
+ t1 = self.driver.Time(13,45,30)
+ t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(t1),str(t2))
+
+ def test_Timestamp(self):
+ t1 = self.driver.Timestamp(2002,12,25,13,45,30)
+ t2 = self.driver.TimestampFromTicks(
+ time.mktime((2002,12,25,13,45,30,0,0,0))
+ )
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(t1),str(t2))
+
+ def test_Binary(self):
+ b = self.driver.Binary(b'Something')
+ b = self.driver.Binary(b'')
+
+ def test_STRING(self):
+ self.failUnless(hasattr(self.driver,'STRING'),
+ 'module.STRING must be defined'
+ )
+
+ def test_BINARY(self):
+ self.failUnless(hasattr(self.driver,'BINARY'),
+ 'module.BINARY must be defined.'
+ )
+
+ def test_NUMBER(self):
+ self.failUnless(hasattr(self.driver,'NUMBER'),
+ 'module.NUMBER must be defined.'
+ )
+
+ def test_DATETIME(self):
+ self.failUnless(hasattr(self.driver,'DATETIME'),
+ 'module.DATETIME must be defined.'
+ )
+
+ def test_ROWID(self):
+ self.failUnless(hasattr(self.driver,'ROWID'),
+ 'module.ROWID must be defined.'
+ )
+# fmt: on
diff --git a/tests/dbapi20_tpc.py b/tests/dbapi20_tpc.py
new file mode 100644
index 0000000..7254294
--- /dev/null
+++ b/tests/dbapi20_tpc.py
@@ -0,0 +1,151 @@
+# flake8: noqa
+# fmt: off
+
+""" Python DB API 2.0 driver Two Phase Commit compliance test suite.
+
+"""
+
+import unittest
+from typing import Any
+
+
+class TwoPhaseCommitTests(unittest.TestCase):
+
+ driver: Any = None
+
+ def connect(self):
+ """Make a database connection."""
+ raise NotImplementedError
+
+ _last_id = 0
+ _global_id_prefix = "dbapi20_tpc:"
+
+ def make_xid(self, con):
+ id = TwoPhaseCommitTests._last_id
+ TwoPhaseCommitTests._last_id += 1
+ return con.xid(42, f"{self._global_id_prefix}{id}", "qualifier")
+
+ def test_xid(self):
+ con = self.connect()
+ try:
+ try:
+ xid = con.xid(42, "global", "bqual")
+ except self.driver.NotSupportedError:
+ self.fail("Driver does not support transaction IDs.")
+
+ self.assertEquals(xid[0], 42)
+ self.assertEquals(xid[1], "global")
+ self.assertEquals(xid[2], "bqual")
+
+ # Try some extremes for the transaction ID:
+ xid = con.xid(0, "", "")
+ self.assertEquals(tuple(xid), (0, "", ""))
+ xid = con.xid(0x7fffffff, "a" * 64, "b" * 64)
+ self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64))
+ finally:
+ con.close()
+
+ def test_tpc_begin(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ try:
+ con.tpc_begin(xid)
+ except self.driver.NotSupportedError:
+ self.fail("Driver does not support tpc_begin()")
+ finally:
+ con.close()
+
+ def test_tpc_commit_without_prepare(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ con.tpc_commit()
+ finally:
+ con.close()
+
+ def test_tpc_rollback_without_prepare(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ con.tpc_rollback()
+ finally:
+ con.close()
+
+ def test_tpc_commit_with_prepare(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ con.tpc_prepare()
+ con.tpc_commit()
+ finally:
+ con.close()
+
+ def test_tpc_rollback_with_prepare(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ con.tpc_prepare()
+ con.tpc_rollback()
+ finally:
+ con.close()
+
+ def test_tpc_begin_in_transaction_fails(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ self.assertRaises(self.driver.ProgrammingError,
+ con.tpc_begin, xid)
+ finally:
+ con.close()
+
+ def test_tpc_begin_in_tpc_transaction_fails(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ self.assertRaises(self.driver.ProgrammingError,
+ con.tpc_begin, xid)
+ finally:
+ con.close()
+
+ def test_commit_in_tpc_fails(self):
+ # calling commit() within a TPC transaction fails with
+ # ProgrammingError.
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+
+ self.assertRaises(self.driver.ProgrammingError, con.commit)
+ finally:
+ con.close()
+
+ def test_rollback_in_tpc_fails(self):
+ # calling rollback() within a TPC transaction fails with
+ # ProgrammingError.
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+
+ self.assertRaises(self.driver.ProgrammingError, con.rollback)
+ finally:
+ con.close()
diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py
new file mode 100644
index 0000000..88ab504
--- /dev/null
+++ b/tests/fix_crdb.py
@@ -0,0 +1,131 @@
+from typing import Optional
+
+import pytest
+
+from .utils import VersionCheck
+from psycopg.crdb import CrdbConnection
+
+
+def pytest_configure(config):
+ # register libpq marker
+ config.addinivalue_line(
+ "markers",
+ "crdb(version_expr, reason=detail): run/skip the test with matching CockroachDB"
+ " (e.g. '>= 21.2.10', '< 22.1', 'skip < 22')",
+ )
+ config.addinivalue_line(
+ "markers",
+ "crdb_skip(reason): skip the test for known CockroachDB reasons",
+ )
+
+
+def check_crdb_version(got, mark):
+ if mark.name == "crdb":
+ assert len(mark.args) <= 1
+ assert not (set(mark.kwargs) - {"reason"})
+ spec = mark.args[0] if mark.args else "only"
+ reason = mark.kwargs.get("reason")
+ elif mark.name == "crdb_skip":
+ assert len(mark.args) == 1
+ assert not mark.kwargs
+ reason = mark.args[0]
+ assert reason in _crdb_reasons, reason
+ spec = _crdb_reason_version.get(reason, "skip")
+ else:
+ assert False, mark.name
+
+ pred = VersionCheck.parse(spec)
+ pred.whose = "CockroachDB"
+
+ msg = pred.get_skip_message(got)
+ if not msg:
+ return None
+
+ reason = crdb_skip_message(reason)
+ if reason:
+ msg = f"{msg}: {reason}"
+
+ return msg
+
+
+# Utility functions which can be imported in the test suite
+
+is_crdb = CrdbConnection.is_crdb
+
+
+def crdb_skip_message(reason: Optional[str]) -> str:
+ msg = ""
+ if reason:
+ msg = reason
+ if _crdb_reasons.get(reason):
+ url = (
+ "https://github.com/cockroachdb/cockroach/"
+ f"issues/{_crdb_reasons[reason]}"
+ )
+ msg = f"{msg} ({url})"
+
+ return msg
+
+
+def skip_crdb(*args, reason=None):
+ return pytest.param(*args, marks=pytest.mark.crdb("skip", reason=reason))
+
+
+def crdb_encoding(*args):
+ """Mark tests that fail on CockroachDB because of missing encodings"""
+ return skip_crdb(*args, reason="encoding")
+
+
+def crdb_time_precision(*args):
+ """Mark tests that fail on CockroachDB because time doesn't support precision"""
+ return skip_crdb(*args, reason="time precision")
+
+
+def crdb_scs_off(*args):
+ return skip_crdb(*args, reason="standard_conforming_strings=off")
+
+
+# mapping from reason description to ticket number
+_crdb_reasons = {
+ "2-phase commit": 22329,
+ "backend pid": 35897,
+ "batch statements": 44803,
+ "begin_read_only": 87012,
+ "binary decimal": 82492,
+ "cancel": 41335,
+ "cast adds tz": 51692,
+ "cidr": 18846,
+ "composite": 27792,
+ "copy array": 82792,
+ "copy canceled": 81559,
+ "copy": 41608,
+ "cursor invalid name": 84261,
+ "cursor with hold": 77101,
+ "deferrable": 48307,
+ "do": 17511,
+ "encoding": 35882,
+ "geometric types": 21286,
+ "hstore": 41284,
+ "infinity date": 41564,
+ "interval style": 35807,
+ "json array": 23468,
+ "large objects": 243,
+ "negative interval": 81577,
+ "nested array": 32552,
+ "no col query": None,
+ "notify": 41522,
+ "password_encryption": 42519,
+ "pg_terminate_backend": 35897,
+ "range": 41282,
+ "scroll cursor": 77102,
+ "server-side cursor": 41412,
+ "severity_nonlocalized": 81794,
+ "stored procedure": 1751,
+}
+
+_crdb_reason_version = {
+ "backend pid": "skip < 22",
+ "cancel": "skip < 22",
+ "server-side cursor": "skip < 22.1.3",
+ "severity_nonlocalized": "skip < 22.1.3",
+}
diff --git a/tests/fix_db.py b/tests/fix_db.py
new file mode 100644
index 0000000..3a37aa1
--- /dev/null
+++ b/tests/fix_db.py
@@ -0,0 +1,358 @@
+import io
+import os
+import sys
+import pytest
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg._compat import cache
+from psycopg.pq._debug import PGconnDebug
+
+from .utils import check_postgres_version
+
+# Set by warm_up_database() the first time the dsn fixture is used
+pg_version: int
+crdb_version: Optional[int]
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--test-dsn",
+ metavar="DSN",
+ default=os.environ.get("PSYCOPG_TEST_DSN"),
+ help=(
+ "Connection string to run database tests requiring a connection"
+ " [you can also use the PSYCOPG_TEST_DSN env var]."
+ ),
+ )
+ parser.addoption(
+ "--pq-trace",
+ metavar="{TRACEFILE,STDERR}",
+ default=None,
+ help="Generate a libpq trace to TRACEFILE or STDERR.",
+ )
+ parser.addoption(
+ "--pq-debug",
+ action="store_true",
+ default=False,
+ help="Log PGconn access. (Requires PSYCOPG_IMPL=python.)",
+ )
+
+
+def pytest_report_header(config):
+ dsn = config.getoption("--test-dsn")
+ if dsn is None:
+ return []
+
+ try:
+ with psycopg.connect(dsn, connect_timeout=10) as conn:
+ server_version = conn.execute("select version()").fetchall()[0][0]
+ except Exception as ex:
+ server_version = f"unknown ({ex})"
+
+ return [
+ f"Server version: {server_version}",
+ ]
+
+
+def pytest_collection_modifyitems(items):
+ for item in items:
+ for name in item.fixturenames:
+ if name in ("pipeline", "apipeline"):
+ item.add_marker(pytest.mark.pipeline)
+ break
+
+
+def pytest_runtest_setup(item):
+ for m in item.iter_markers(name="pipeline"):
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+
+
+def pytest_configure(config):
+ # register pg marker
+ markers = [
+ "pg(version_expr): run the test only with matching server version"
+ " (e.g. '>= 10', '< 9.6')",
+ "pipeline: the test runs with connection in pipeline mode",
+ ]
+ for marker in markers:
+ config.addinivalue_line("markers", marker)
+
+
+@pytest.fixture(scope="session")
+def session_dsn(request):
+ """
+ Return the dsn used to connect to the `--test-dsn` database (session-wide).
+ """
+ dsn = request.config.getoption("--test-dsn")
+ if dsn is None:
+ pytest.skip("skipping test as no --test-dsn")
+
+ warm_up_database(dsn)
+ return dsn
+
+
+@pytest.fixture
+def dsn(session_dsn, request):
+ """Return the dsn used to connect to the `--test-dsn` database."""
+ check_connection_version(request.node)
+ return session_dsn
+
+
+@pytest.fixture(scope="session")
+def tracefile(request):
+ """Open and yield a file for libpq client/server communication traces if
+ --pq-tracefile option is set.
+ """
+ tracefile = request.config.getoption("--pq-trace")
+ if not tracefile:
+ yield None
+ return
+
+ if tracefile.lower() == "stderr":
+ try:
+ sys.stderr.fileno()
+ except io.UnsupportedOperation:
+ raise pytest.UsageError(
+ "cannot use stderr for --pq-trace (in-memory file?)"
+ ) from None
+
+ yield sys.stderr
+ return
+
+ with open(tracefile, "w") as f:
+ yield f
+
+
+@contextmanager
+def maybe_trace(pgconn, tracefile, function):
+ """Handle libpq client/server communication traces for a single test
+ function.
+ """
+ if tracefile is None:
+ yield None
+ return
+
+ if tracefile != sys.stderr:
+ title = f" {function.__module__}::{function.__qualname__} ".center(80, "=")
+ tracefile.write(title + "\n")
+ tracefile.flush()
+
+ pgconn.trace(tracefile.fileno())
+ try:
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ except psycopg.NotSupportedError:
+ pass
+ try:
+ yield None
+ finally:
+ pgconn.untrace()
+
+
+@pytest.fixture(autouse=True)
+def pgconn_debug(request):
+ if not request.config.getoption("--pq-debug"):
+ return
+ if pq.__impl__ != "python":
+ raise pytest.UsageError("set PSYCOPG_IMPL=python to use --pq-debug")
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
+ logger = logging.getLogger("psycopg.debug")
+ logger.setLevel(logging.INFO)
+ pq.PGconn = PGconnDebug
+
+
+@pytest.fixture
+def pgconn(dsn, request, tracefile):
+ """Return a PGconn connection open to `--test-dsn`."""
+ check_connection_version(request.node)
+
+ conn = pq.PGconn.connect(dsn.encode())
+ if conn.status != pq.ConnStatus.OK:
+ pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}")
+
+ with maybe_trace(conn, tracefile, request.function):
+ yield conn
+
+ conn.finish()
+
+
+@pytest.fixture
+def conn(conn_cls, dsn, request, tracefile):
+ """Return a `Connection` connected to the ``--test-dsn`` database."""
+ check_connection_version(request.node)
+
+ conn = conn_cls.connect(dsn)
+ with maybe_trace(conn.pgconn, tracefile, request.function):
+ yield conn
+ conn.close()
+
+
+@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"])
+def pipeline(request, conn):
+ if request.param:
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ with conn.pipeline() as p:
+ yield p
+ return
+ else:
+ yield None
+
+
+@pytest.fixture
+async def aconn(dsn, aconn_cls, request, tracefile):
+ """Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
+ check_connection_version(request.node)
+
+ conn = await aconn_cls.connect(dsn)
+ with maybe_trace(conn.pgconn, tracefile, request.function):
+ yield conn
+ await conn.close()
+
+
+@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"])
+async def apipeline(request, aconn):
+ if request.param:
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ async with aconn.pipeline() as p:
+ yield p
+ return
+ else:
+ yield None
+
+
+@pytest.fixture(scope="session")
+def conn_cls(session_dsn):
+ cls = psycopg.Connection
+ if crdb_version:
+ from psycopg.crdb import CrdbConnection
+
+ cls = CrdbConnection
+
+ return cls
+
+
+@pytest.fixture(scope="session")
+def aconn_cls(session_dsn):
+ cls = psycopg.AsyncConnection
+ if crdb_version:
+ from psycopg.crdb import AsyncCrdbConnection
+
+ cls = AsyncCrdbConnection
+
+ return cls
+
+
+@pytest.fixture(scope="session")
+def svcconn(conn_cls, session_dsn):
+ """
+ Return a session `Connection` connected to the ``--test-dsn`` database.
+ """
+ conn = conn_cls.connect(session_dsn, autocommit=True)
+ yield conn
+ conn.close()
+
+
+@pytest.fixture
+def commands(conn, monkeypatch):
+ """The list of commands issued internally by the test connection."""
+ yield patch_exec(conn, monkeypatch)
+
+
+@pytest.fixture
+def acommands(aconn, monkeypatch):
+ """The list of commands issued internally by the test async connection."""
+ yield patch_exec(aconn, monkeypatch)
+
+
+def patch_exec(conn, monkeypatch):
+ """Helper to implement the commands fixture both sync and async."""
+ _orig_exec_command = conn._exec_command
+ L = ListPopAll()
+
+ def _exec_command(command, *args, **kwargs):
+ cmdcopy = command
+ if isinstance(cmdcopy, bytes):
+ cmdcopy = cmdcopy.decode(conn.info.encoding)
+ elif isinstance(cmdcopy, sql.Composable):
+ cmdcopy = cmdcopy.as_string(conn)
+
+ L.append(cmdcopy)
+ return _orig_exec_command(command, *args, **kwargs)
+
+ monkeypatch.setattr(conn, "_exec_command", _exec_command)
+ return L
+
+
+class ListPopAll(list): # type: ignore[type-arg]
+ """A list, with a popall() method."""
+
+ def popall(self):
+ out = self[:]
+ del self[:]
+ return out
+
+
+def check_connection_version(node):
+ try:
+ pg_version
+ except NameError:
+ # First connection creation failed. Let the tests fail.
+ pytest.fail("server version not available")
+
+ for mark in node.iter_markers():
+ if mark.name == "pg":
+ assert len(mark.args) == 1
+ msg = check_postgres_version(pg_version, mark.args[0])
+ if msg:
+ pytest.skip(msg)
+
+ elif mark.name in ("crdb", "crdb_skip"):
+ from .fix_crdb import check_crdb_version
+
+ msg = check_crdb_version(crdb_version, mark)
+ if msg:
+ pytest.skip(msg)
+
+
+@pytest.fixture
+def hstore(svcconn):
+ try:
+ with svcconn.transaction():
+ svcconn.execute("create extension if not exists hstore")
+ except psycopg.Error as e:
+ pytest.skip(str(e))
+
+
+@cache
+def warm_up_database(dsn: str) -> None:
+ """Connect to the database before returning a connection.
+
+ In the CI sometimes, the first test fails with a timeout, probably because
+ the server hasn't started yet. Absorb the delay before the test.
+
+ In case of error, abort the test run entirely, to avoid failing downstream
+ hundreds of times.
+ """
+ global pg_version, crdb_version
+
+ try:
+ with psycopg.connect(dsn, connect_timeout=10) as conn:
+ conn.execute("select 1")
+
+ pg_version = conn.info.server_version
+
+ crdb_version = None
+ param = conn.info.parameter_status("crdb_version")
+ if param:
+ from psycopg.crdb import CrdbConnectionInfo
+
+ crdb_version = CrdbConnectionInfo.parse_crdb_version(param)
+ except Exception as exc:
+ pytest.exit(f"failed to connect to the test database: {exc}")
diff --git a/tests/fix_faker.py b/tests/fix_faker.py
new file mode 100644
index 0000000..5289d8f
--- /dev/null
+++ b/tests/fix_faker.py
@@ -0,0 +1,868 @@
+import datetime as dt
+import importlib
+import ipaddress
+from math import isnan
+from uuid import UUID
+from random import choice, random, randrange
+from typing import Any, List, Set, Tuple, Union
+from decimal import Decimal
+from contextlib import contextmanager, asynccontextmanager
+
+import pytest
+
+import psycopg
+from psycopg import sql
+from psycopg.adapt import PyFormat
+from psycopg._compat import Deque
+from psycopg.types.range import Range
+from psycopg.types.json import Json, Jsonb
+from psycopg.types.numeric import Int4, Int8
+from psycopg.types.multirange import Multirange
+
+
+@pytest.fixture
+def faker(conn):
+ return Faker(conn)
+
+
+class Faker:
+ """
+ An object to generate random records.
+ """
+
+ json_max_level = 3
+ json_max_length = 10
+ str_max_length = 100
+ list_max_length = 20
+ tuple_max_length = 15
+
+ def __init__(self, connection):
+ self.conn = connection
+ self.format = PyFormat.BINARY
+ self.records = []
+
+ self._schema = None
+ self._types = None
+ self._types_names = None
+ self._makers = {}
+ self.table_name = sql.Identifier("fake_table")
+
+ @property
+ def schema(self):
+ if not self._schema:
+ self.schema = self.choose_schema()
+ return self._schema
+
+ @schema.setter
+ def schema(self, schema):
+ self._schema = schema
+ self._types_names = None
+
+ @property
+ def fields_names(self):
+ return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))]
+
+ @property
+ def types(self):
+ if not self._types:
+
+ def key(cls: type) -> str:
+ return cls.__name__
+
+ self._types = sorted(self.get_supported_types(), key=key)
+ return self._types
+
+ @property
+ def types_names_sql(self):
+ if self._types_names:
+ return self._types_names
+
+ record = self.make_record(nulls=0)
+ tx = psycopg.adapt.Transformer(self.conn)
+ types = [
+ self._get_type_name(tx, schema, value)
+ for schema, value in zip(self.schema, record)
+ ]
+ self._types_names = types
+ return types
+
+ @property
+ def types_names(self):
+ types = [t.as_string(self.conn).replace('"', "") for t in self.types_names_sql]
+ return types
+
+ def _get_type_name(self, tx, schema, value):
+ # Special case it as it is passed as unknown so is returned as text
+ if schema == (list, str):
+ return sql.SQL("text[]")
+
+ registry = self.conn.adapters.types
+ dumper = tx.get_dumper(value, self.format)
+ dumper.dump(value) # load the oid if it's dynamic (e.g. array)
+ info = registry.get(dumper.oid) or registry.get("text")
+ if dumper.oid == info.array_oid:
+ return sql.SQL("{}[]").format(sql.Identifier(info.name))
+ else:
+ return sql.Identifier(info.name)
+
+ @property
+ def drop_stmt(self):
+ return sql.SQL("drop table if exists {}").format(self.table_name)
+
+ @property
+ def create_stmt(self):
+ field_values = []
+ for name, type in zip(self.fields_names, self.types_names_sql):
+ field_values.append(sql.SQL("{} {}").format(name, type))
+
+ fields = sql.SQL(", ").join(field_values)
+ return sql.SQL("create table {table} (id serial primary key, {fields})").format(
+ table=self.table_name, fields=fields
+ )
+
+ @property
+ def insert_stmt(self):
+ phs = [sql.Placeholder(format=self.format) for i in range(len(self.schema))]
+ return sql.SQL("insert into {} ({}) values ({})").format(
+ self.table_name,
+ sql.SQL(", ").join(self.fields_names),
+ sql.SQL(", ").join(phs),
+ )
+
+ @property
+ def select_stmt(self):
+ fields = sql.SQL(", ").join(self.fields_names)
+ return sql.SQL("select {} from {} order by id").format(fields, self.table_name)
+
+ @contextmanager
+ def find_insert_problem(self, conn):
+ """Context manager to help finding a problematic value."""
+ try:
+ with conn.transaction():
+ yield
+ except psycopg.DatabaseError:
+ cur = conn.cursor()
+ # Repeat insert one field at time, until finding the wrong one
+ cur.execute(self.drop_stmt)
+ cur.execute(self.create_stmt)
+ for i, rec in enumerate(self.records):
+ for j, val in enumerate(rec):
+ try:
+ cur.execute(self._insert_field_stmt(j), (val,))
+ except psycopg.DatabaseError as e:
+ r = repr(val)
+ if len(r) > 200:
+ r = f"{r[:200]}... ({len(r)} chars)"
+ raise Exception(
+ f"value {r!r} at record {i} column0 {j} failed insert: {e}"
+ ) from None
+
+ # just in case, but hopefully we should have triggered the problem
+ raise
+
+ @asynccontextmanager
+ async def find_insert_problem_async(self, aconn):
+ try:
+ async with aconn.transaction():
+ yield
+ except psycopg.DatabaseError:
+ acur = aconn.cursor()
+ # Repeat insert one field at time, until finding the wrong one
+ await acur.execute(self.drop_stmt)
+ await acur.execute(self.create_stmt)
+ for i, rec in enumerate(self.records):
+ for j, val in enumerate(rec):
+ try:
+ await acur.execute(self._insert_field_stmt(j), (val,))
+ except psycopg.DatabaseError as e:
+ r = repr(val)
+ if len(r) > 200:
+ r = f"{r[:200]}... ({len(r)} chars)"
+ raise Exception(
+ f"value {r!r} at record {i} column0 {j} failed insert: {e}"
+ ) from None
+
+ # just in case, but hopefully we should have triggered the problem
+ raise
+
+ def _insert_field_stmt(self, i):
+ ph = sql.Placeholder(format=self.format)
+ return sql.SQL("insert into {} ({}) values ({})").format(
+ self.table_name, self.fields_names[i], ph
+ )
+
+ def choose_schema(self, ncols=20):
+ schema: List[Union[Tuple[type, ...], type]] = []
+ while len(schema) < ncols:
+ s = self.make_schema(choice(self.types))
+ if s is not None:
+ schema.append(s)
+ self.schema = schema
+ return schema
+
+ def make_records(self, nrecords):
+ self.records = [self.make_record(nulls=0.05) for i in range(nrecords)]
+
+ def make_record(self, nulls=0):
+ if not nulls:
+ return tuple(self.example(spec) for spec in self.schema)
+ else:
+ return tuple(
+ self.make(spec) if random() > nulls else None for spec in self.schema
+ )
+
+ def assert_record(self, got, want):
+ for spec, g, w in zip(self.schema, got, want):
+ if g is None and w is None:
+ continue
+ m = self.get_matcher(spec)
+ m(spec, g, w)
+
+ def get_supported_types(self) -> Set[type]:
+ dumpers = self.conn.adapters._dumpers[self.format]
+ rv = set()
+ for cls in dumpers.keys():
+ if isinstance(cls, str):
+ cls = deep_import(cls)
+ if issubclass(cls, Multirange) and self.conn.info.server_version < 140000:
+ continue
+
+ rv.add(cls)
+
+ # check all the types are handled
+ for cls in rv:
+ self.get_maker(cls)
+
+ return rv
+
+ def make_schema(self, cls: type) -> Union[Tuple[type, ...], type, None]:
+ """Create a schema spec from a Python type.
+
+ A schema specifies what Postgres type to generate when a Python type
+ maps to more than one (e.g. tuple -> composite, list -> array[],
+ datetime -> timestamp[tz]).
+
+ A schema for a type is represented by a tuple (type, ...) which the
+ matching make_*() method can interpret, or just type if the type
+ doesn't require further specification.
+
+ A `None` means that the type is not supported.
+ """
+ meth = self._get_method("schema", cls)
+ return meth(cls) if meth else cls
+
+ def get_maker(self, spec):
+ cls = spec if isinstance(spec, type) else spec[0]
+
+ try:
+ return self._makers[cls]
+ except KeyError:
+ pass
+
+ meth = self._get_method("make", cls)
+ if meth:
+ self._makers[cls] = meth
+ return meth
+ else:
+ raise NotImplementedError(f"cannot make fake objects of class {cls}")
+
+ def get_matcher(self, spec):
+ cls = spec if isinstance(spec, type) else spec[0]
+ meth = self._get_method("match", cls)
+ return meth if meth else self.match_any
+
+ def _get_method(self, prefix, cls):
+ name = cls.__name__
+ if cls.__module__ != "builtins":
+ name = f"{cls.__module__}.{name}"
+
+ parts = name.split(".")
+ for i in range(len(parts)):
+ mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}"
+ meth = getattr(self, mname, None)
+ if meth:
+ return meth
+
+ return None
+
+ def make(self, spec):
+ # spec can be a type or a tuple (type, options)
+ return self.get_maker(spec)(spec)
+
+ def example(self, spec):
+ # A good representative of the object - no degenerate case
+ cls = spec if isinstance(spec, type) else spec[0]
+ meth = self._get_method("example", cls)
+ if meth:
+ return meth(spec)
+ else:
+ return self.make(spec)
+
+ def match_any(self, spec, got, want):
+ assert got == want
+
+ # methods to generate samples of specific types
+
+ def make_Binary(self, spec):
+ return self.make_bytes(spec)
+
+ def match_Binary(self, spec, got, want):
+ return want.obj == got
+
+ def make_bool(self, spec):
+ return choice((True, False))
+
+ def make_bytearray(self, spec):
+ return self.make_bytes(spec)
+
+ def make_bytes(self, spec):
+ length = randrange(self.str_max_length)
+ return spec(bytes([randrange(256) for i in range(length)]))
+
+ def make_date(self, spec):
+ day = randrange(dt.date.max.toordinal())
+ return dt.date.fromordinal(day + 1)
+
+ def schema_datetime(self, cls):
+ return self.schema_time(cls)
+
+ def make_datetime(self, spec):
+ # Add a day because with timezone we might go BC
+ dtmin = dt.datetime.min + dt.timedelta(days=1)
+ delta = dt.datetime.max - dtmin
+ micros = randrange((delta.days + 1) * 24 * 60 * 60 * 1_000_000)
+ rv = dtmin + dt.timedelta(microseconds=micros)
+ if spec[1]:
+ rv = rv.replace(tzinfo=self._make_tz(spec))
+ return rv
+
+ def match_datetime(self, spec, got, want):
+ # Comparisons with different timezones is unreliable: certain pairs
+ # are reported different but their delta is 0
+ # https://bugs.python.org/issue45347
+ assert not (got - want)
+
+ def make_Decimal(self, spec):
+ if random() >= 0.99:
+ return Decimal(choice(self._decimal_special_values()))
+
+ sign = choice("+-")
+ num = choice(["0.zd", "d", "d.d"])
+ while "z" in num:
+ ndigits = randrange(1, 20)
+ num = num.replace("z", "0" * ndigits, 1)
+ while "d" in num:
+ ndigits = randrange(1, 20)
+ num = num.replace(
+ "d", "".join([str(randrange(10)) for i in range(ndigits)]), 1
+ )
+ expsign = choice(["e+", "e-", ""])
+ exp = randrange(20) if expsign else ""
+ rv = Decimal(f"{sign}{num}{expsign}{exp}")
+ return rv
+
+ def match_Decimal(self, spec, got, want):
+ if got is not None and got.is_nan():
+ assert want.is_nan()
+ else:
+ assert got == want
+
+ def _decimal_special_values(self):
+ values = ["NaN", "sNaN"]
+
+ if self.conn.info.vendor == "PostgreSQL":
+ if self.conn.info.server_version >= 140000:
+ values.extend(["Inf", "-Inf"])
+ elif self.conn.info.vendor == "CockroachDB":
+ if self.conn.info.server_version >= 220100:
+ values.extend(["Inf", "-Inf"])
+ else:
+ pytest.fail(f"unexpected vendor: {self.conn.info.vendor}")
+
+ return values
+
+ def schema_Enum(self, cls):
+ # TODO: can't fake those as we would need to create temporary types
+ return None
+
+ def make_Enum(self, spec):
+ return None
+
+ def make_float(self, spec, double=True):
+ if random() <= 0.99:
+ # These exponents should generate no inf
+ return float(
+ f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}"
+ if double
+ else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}"
+ )
+ else:
+ return choice((0.0, -0.0, float("-inf"), float("inf"), float("nan")))
+
+ def match_float(self, spec, got, want, approx=False, rel=None):
+ if got is not None and isnan(got):
+ assert isnan(want)
+ else:
+ if approx or self._server_rounds():
+ assert got == pytest.approx(want, rel=rel)
+ else:
+ assert got == want
+
+ def _server_rounds(self):
+ """Return True if the connected server perform float rounding"""
+ if self.conn.info.vendor == "CockroachDB":
+ return True
+ else:
+ # Versions older than 12 make some rounding. e.g. in Postgres 10.4
+ # select '-1.409006204063909e+112'::float8
+ # -> -1.40900620406391e+112
+ return self.conn.info.server_version < 120000
+
+ def make_Float4(self, spec):
+ return spec(self.make_float(spec, double=False))
+
+ def match_Float4(self, spec, got, want):
+ self.match_float(spec, got, want, approx=True, rel=1e-5)
+
+ def make_Float8(self, spec):
+ return spec(self.make_float(spec))
+
+ match_Float8 = match_float
+
+ def make_int(self, spec):
+ return randrange(-(1 << 90), 1 << 90)
+
+ def make_Int2(self, spec):
+ return spec(randrange(-(1 << 15), 1 << 15))
+
+ def make_Int4(self, spec):
+ return spec(randrange(-(1 << 31), 1 << 31))
+
+ def make_Int8(self, spec):
+ return spec(randrange(-(1 << 63), 1 << 63))
+
+ def make_IntNumeric(self, spec):
+ return spec(randrange(-(1 << 100), 1 << 100))
+
+ def make_IPv4Address(self, spec):
+ return ipaddress.IPv4Address(bytes(randrange(256) for _ in range(4)))
+
+ def make_IPv4Interface(self, spec):
+ prefix = randrange(32)
+ return ipaddress.IPv4Interface(
+ (bytes(randrange(256) for _ in range(4)), prefix)
+ )
+
+ def make_IPv4Network(self, spec):
+ return self.make_IPv4Interface(spec).network
+
+ def make_IPv6Address(self, spec):
+ return ipaddress.IPv6Address(bytes(randrange(256) for _ in range(16)))
+
+ def make_IPv6Interface(self, spec):
+ prefix = randrange(128)
+ return ipaddress.IPv6Interface(
+ (bytes(randrange(256) for _ in range(16)), prefix)
+ )
+
+ def make_IPv6Network(self, spec):
+ return self.make_IPv6Interface(spec).network
+
+ def make_Json(self, spec):
+ return spec(self._make_json())
+
+ def match_Json(self, spec, got, want):
+ if want is not None:
+ want = want.obj
+ assert got == want
+
+ def make_Jsonb(self, spec):
+ return spec(self._make_json())
+
+ def match_Jsonb(self, spec, got, want):
+ self.match_Json(spec, got, want)
+
+ def make_JsonFloat(self, spec):
+ # A float limited to what json accepts
+ # this exponent should generate no inf
+ return float(f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}")
+
+ def schema_list(self, cls):
+ while True:
+ scls = choice(self.types)
+ if scls is cls:
+ continue
+ if scls is float:
+ # TODO: float lists are currently adapted as decimal.
+ # There may be rounding errors or problems with inf.
+ continue
+
+ # CRDB doesn't support arrays of json
+ # https://github.com/cockroachdb/cockroach/issues/23468
+ if self.conn.info.vendor == "CockroachDB" and scls in (Json, Jsonb):
+ continue
+
+ schema = self.make_schema(scls)
+ if schema is not None:
+ break
+
+ return (cls, schema)
+
+ def make_list(self, spec):
+ # don't make empty lists because they regularly fail cast
+ length = randrange(1, self.list_max_length)
+ spec = spec[1]
+ while True:
+ rv = [self.make(spec) for i in range(length)]
+
+ # TODO multirange lists fail binary dump if the last element is
+ # empty and there is no type annotation. See xfail in
+ # test_multirange::test_dump_builtin_array
+ if rv and isinstance(rv[-1], Multirange) and not rv[-1]:
+ continue
+
+ return rv
+
+ def example_list(self, spec):
+ return [self.example(spec[1])]
+
+ def match_list(self, spec, got, want):
+ assert len(got) == len(want)
+ m = self.get_matcher(spec[1])
+ for g, w in zip(got, want):
+ m(spec[1], g, w)
+
+ def make_memoryview(self, spec):
+ return self.make_bytes(spec)
+
+ def schema_Multirange(self, cls):
+ return self.schema_Range(cls)
+
+ def make_Multirange(self, spec, length=None, **kwargs):
+ if length is None:
+ length = randrange(0, self.list_max_length)
+
+ def overlap(r1, r2):
+ l1, u1 = r1.lower, r1.upper
+ l2, u2 = r2.lower, r2.upper
+ if l1 is None and l2 is None:
+ return True
+ elif l1 is None:
+ l1 = l2
+ elif l2 is None:
+ l2 = l1
+
+ if u1 is None and u2 is None:
+ return True
+ elif u1 is None:
+ u1 = u2
+ elif u2 is None:
+ u2 = u1
+
+ return l1 <= u2 and l2 <= u1
+
+ out: List[Range[Any]] = []
+ for i in range(length):
+ r = self.make_Range((Range, spec[1]), **kwargs)
+ if r.isempty:
+ continue
+ for r2 in out:
+ if overlap(r, r2):
+ insert = False
+ break
+ else:
+ insert = True
+ if insert:
+ out.append(r) # alternatively, we could merge
+
+ return spec[0](sorted(out))
+
+ def example_Multirange(self, spec):
+ return self.make_Multirange(spec, length=1, empty_chance=0, no_bound_chance=0)
+
+ def make_Int4Multirange(self, spec):
+ return self.make_Multirange((spec, Int4))
+
+ def make_Int8Multirange(self, spec):
+ return self.make_Multirange((spec, Int8))
+
+ def make_NumericMultirange(self, spec):
+ return self.make_Multirange((spec, Decimal))
+
+ def make_DateMultirange(self, spec):
+ return self.make_Multirange((spec, dt.date))
+
+ def make_TimestampMultirange(self, spec):
+ return self.make_Multirange((spec, (dt.datetime, False)))
+
+ def make_TimestamptzMultirange(self, spec):
+ return self.make_Multirange((spec, (dt.datetime, True)))
+
+ def match_Multirange(self, spec, got, want):
+ assert len(got) == len(want)
+ for ig, iw in zip(got, want):
+ self.match_Range(spec, ig, iw)
+
+ def match_Int4Multirange(self, spec, got, want):
+ return self.match_Multirange((spec, Int4), got, want)
+
+ def match_Int8Multirange(self, spec, got, want):
+ return self.match_Multirange((spec, Int8), got, want)
+
+ def match_NumericMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, Decimal), got, want)
+
+ def match_DateMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, dt.date), got, want)
+
+ def match_TimestampMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, (dt.datetime, False)), got, want)
+
+ def match_TimestamptzMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, (dt.datetime, True)), got, want)
+
+ def schema_NoneType(self, cls):
+ return None
+
+ def make_NoneType(self, spec):
+ return None
+
+ def make_Oid(self, spec):
+ return spec(randrange(1 << 32))
+
+ def schema_Range(self, cls):
+ subtypes = [
+ Decimal,
+ Int4,
+ Int8,
+ dt.date,
+ (dt.datetime, True),
+ (dt.datetime, False),
+ ]
+
+ return (cls, choice(subtypes))
+
+ def make_Range(self, spec, empty_chance=0.02, no_bound_chance=0.05):
+ # TODO: drop format check after fixing binary dumping of empty ranges
+ # (an array starting with an empty range will get the wrong type currently)
+ if (
+ random() < empty_chance
+ and spec[0] is Range
+ and self.format == PyFormat.TEXT
+ ):
+ return spec[0](empty=True)
+
+ while True:
+ bounds: List[Union[Any, None]] = []
+ while len(bounds) < 2:
+ if random() < no_bound_chance:
+ bounds.append(None)
+ continue
+
+ val = self.make(spec[1])
+ # NaN are allowed in a range, but comparison in Python get tricky.
+ if spec[1] is Decimal and val.is_nan():
+ continue
+
+ bounds.append(val)
+
+ if bounds[0] is not None and bounds[1] is not None:
+ if bounds[0] == bounds[1]:
+ # It would come out empty
+ continue
+
+ if bounds[0] > bounds[1]:
+ bounds.reverse()
+
+ # avoid generating ranges with no type info if dumping in binary
+ # TODO: lift this limitation after test_copy_in_empty xfail is fixed
+ if spec[0] is Range and self.format == PyFormat.BINARY:
+ if bounds[0] is bounds[1] is None:
+ continue
+
+ break
+
+ r = spec[0](bounds[0], bounds[1], choice("[(") + choice("])"))
+ return r
+
+ def example_Range(self, spec):
+ return self.make_Range(spec, empty_chance=0, no_bound_chance=0)
+
+ def make_Int4Range(self, spec):
+ return self.make_Range((spec, Int4))
+
+ def make_Int8Range(self, spec):
+ return self.make_Range((spec, Int8))
+
+ def make_NumericRange(self, spec):
+ return self.make_Range((spec, Decimal))
+
+ def make_DateRange(self, spec):
+ return self.make_Range((spec, dt.date))
+
+ def make_TimestampRange(self, spec):
+ return self.make_Range((spec, (dt.datetime, False)))
+
+ def make_TimestamptzRange(self, spec):
+ return self.make_Range((spec, (dt.datetime, True)))
+
+ def match_Range(self, spec, got, want):
+ # normalise the bounds of unbounded ranges
+ if want.lower is None and want.lower_inc:
+ want = type(want)(want.lower, want.upper, "(" + want.bounds[1])
+ if want.upper is None and want.upper_inc:
+ want = type(want)(want.lower, want.upper, want.bounds[0] + ")")
+
+ # Normalise discrete ranges
+ unit: Union[dt.timedelta, int, None]
+ if spec[1] is dt.date:
+ unit = dt.timedelta(days=1)
+ elif type(spec[1]) is type and issubclass(spec[1], int):
+ unit = 1
+ else:
+ unit = None
+
+ if unit is not None:
+ if want.lower is not None and not want.lower_inc:
+ want = type(want)(want.lower + unit, want.upper, "[" + want.bounds[1])
+ if want.upper_inc:
+ want = type(want)(want.lower, want.upper + unit, want.bounds[0] + ")")
+
+ if spec[1] == (dt.datetime, True) and not want.isempty:
+ # work around https://bugs.python.org/issue45347
+ def fix_dt(x):
+ return x.astimezone(dt.timezone.utc) if x is not None else None
+
+ def fix_range(r):
+ return type(r)(fix_dt(r.lower), fix_dt(r.upper), r.bounds)
+
+ want = fix_range(want)
+ got = fix_range(got)
+
+ assert got == want
+
+ def match_Int4Range(self, spec, got, want):
+ return self.match_Range((spec, Int4), got, want)
+
+ def match_Int8Range(self, spec, got, want):
+ return self.match_Range((spec, Int8), got, want)
+
+ def match_NumericRange(self, spec, got, want):
+ return self.match_Range((spec, Decimal), got, want)
+
+ def match_DateRange(self, spec, got, want):
+ return self.match_Range((spec, dt.date), got, want)
+
+ def match_TimestampRange(self, spec, got, want):
+ return self.match_Range((spec, (dt.datetime, False)), got, want)
+
+ def match_TimestamptzRange(self, spec, got, want):
+ return self.match_Range((spec, (dt.datetime, True)), got, want)
+
+ def make_str(self, spec, length=0):
+ if not length:
+ length = randrange(self.str_max_length)
+
+ rv: List[int] = []
+ while len(rv) < length:
+ c = randrange(1, 128) if random() < 0.5 else randrange(1, 0x110000)
+ if not (0xD800 <= c <= 0xDBFF or 0xDC00 <= c <= 0xDFFF):
+ rv.append(c)
+
+ return "".join(map(chr, rv))
+
+ def schema_time(self, cls):
+ # Choose timezone yes/no
+ return (cls, choice([True, False]))
+
+ def make_time(self, spec):
+ val = randrange(24 * 60 * 60 * 1_000_000)
+ val, ms = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+ tz = self._make_tz(spec) if spec[1] else None
+ return dt.time(h, m, s, ms, tz)
+
+ CRDB_TIMEDELTA_MAX = dt.timedelta(days=1281239)
+
+ def make_timedelta(self, spec):
+ if self.conn.info.vendor == "CockroachDB":
+ rng = [-self.CRDB_TIMEDELTA_MAX, self.CRDB_TIMEDELTA_MAX]
+ else:
+ rng = [dt.timedelta.min, dt.timedelta.max]
+
+ return choice(rng) * random()
+
+ def schema_tuple(self, cls):
+ # TODO: this is a complicated matter as it would involve creating
+ # temporary composite types.
+ # length = randrange(1, self.tuple_max_length)
+ # return (cls, self.make_random_schema(ncols=length))
+ return None
+
+ def make_tuple(self, spec):
+ return tuple(self.make(s) for s in spec[1])
+
+ def match_tuple(self, spec, got, want):
+ assert len(got) == len(want) == len(spec[1])
+ for g, w, s in zip(got, want, spec):
+ if g is None or w is None:
+ assert g is w
+ else:
+ m = self.get_matcher(s)
+ m(s, g, w)
+
+ def make_UUID(self, spec):
+ return UUID(bytes=bytes([randrange(256) for i in range(16)]))
+
+ def _make_json(self, container_chance=0.66):
+ rec_types = [list, dict]
+ scal_types = [type(None), int, JsonFloat, bool, str]
+ if random() < container_chance:
+ cls = choice(rec_types)
+ if cls is list:
+ return [
+ self._make_json(container_chance=container_chance / 2.0)
+ for i in range(randrange(self.json_max_length))
+ ]
+ elif cls is dict:
+ return {
+ self.make_str(str, 15): self._make_json(
+ container_chance=container_chance / 2.0
+ )
+ for i in range(randrange(self.json_max_length))
+ }
+ else:
+ assert False, f"unknown rec type: {cls}"
+
+ else:
+ cls = choice(scal_types) # type: ignore[assignment]
+ return self.make(cls)
+
+ def _make_tz(self, spec):
+ minutes = randrange(-12 * 60, 12 * 60 + 1)
+ return dt.timezone(dt.timedelta(minutes=minutes))
+
+
+class JsonFloat:
+ pass
+
+
+def deep_import(name):
+ parts = Deque(name.split("."))
+ seen = []
+ if not parts:
+ raise ValueError("name must be a dot-separated name")
+
+ seen.append(parts.popleft())
+ thing = importlib.import_module(seen[-1])
+ while parts:
+ attr = parts.popleft()
+ seen.append(attr)
+
+ if hasattr(thing, attr):
+ thing = getattr(thing, attr)
+ else:
+ thing = importlib.import_module(".".join(seen))
+
+ return thing
diff --git a/tests/fix_mypy.py b/tests/fix_mypy.py
new file mode 100644
index 0000000..b860a32
--- /dev/null
+++ b/tests/fix_mypy.py
@@ -0,0 +1,54 @@
+import re
+import subprocess as sp
+
+import pytest
+
+
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers",
+ "mypy: the test uses mypy (the marker is set automatically"
+ " on tests using the fixture)",
+ )
+
+
+def pytest_collection_modifyitems(items):
+ for item in items:
+ if "mypy" in item.fixturenames:
+ # add a mypy tag so we can address these tests only
+ item.add_marker(pytest.mark.mypy)
+
+ # All the tests using mypy are slow
+ item.add_marker(pytest.mark.slow)
+
+
+@pytest.fixture(scope="session")
+def mypy(tmp_path_factory):
+ cache_dir = tmp_path_factory.mktemp(basename="mypy_cache")
+ src_dir = tmp_path_factory.mktemp("source")
+
+ class MypyRunner:
+ def run_on_file(self, filename):
+ cmdline = f"""
+ mypy
+ --strict
+ --show-error-codes --no-color-output --no-error-summary
+ --config-file= --cache-dir={cache_dir}
+ """.split()
+ cmdline.append(filename)
+ return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT)
+
+ def run_on_source(self, source):
+ fn = src_dir / "tmp.py"
+ with fn.open("w") as f:
+ f.write(source)
+
+ return self.run_on_file(str(fn))
+
+ def get_revealed(self, line):
+ """return the type from an output of reveal_type"""
+ return re.sub(
+ r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line
+ ).replace("*", "")
+
+ return MypyRunner()
diff --git a/tests/fix_pq.py b/tests/fix_pq.py
new file mode 100644
index 0000000..6811a26
--- /dev/null
+++ b/tests/fix_pq.py
@@ -0,0 +1,141 @@
+import os
+import sys
+import ctypes
+from typing import Iterator, List, NamedTuple
+from tempfile import TemporaryFile
+
+import pytest
+
+from .utils import check_libpq_version
+
+try:
+ from psycopg import pq
+except ImportError:
+ pq = None # type: ignore
+
+
+def pytest_report_header(config):
+ try:
+ from psycopg import pq
+ except ImportError:
+ return []
+
+ return [
+ f"libpq wrapper implementation: {pq.__impl__}",
+ f"libpq used: {pq.version()}",
+ f"libpq compiled: {pq.__build_version__}",
+ ]
+
+
+def pytest_configure(config):
+ # register libpq marker
+ config.addinivalue_line(
+ "markers",
+ "libpq(version_expr): run the test only with matching libpq"
+ " (e.g. '>= 10', '< 9.6')",
+ )
+
+
+def pytest_runtest_setup(item):
+ for m in item.iter_markers(name="libpq"):
+ assert len(m.args) == 1
+ msg = check_libpq_version(pq.version(), m.args[0])
+ if msg:
+ pytest.skip(msg)
+
+
+@pytest.fixture
+def libpq():
+ """Return a ctypes wrapper to access the libpq."""
+ try:
+ from psycopg.pq.misc import find_libpq_full_path
+
+ # Not available when testing the binary package
+ libname = find_libpq_full_path()
+ assert libname, "libpq libname not found"
+ return ctypes.pydll.LoadLibrary(libname)
+ except Exception as e:
+ if pq.__impl__ == "binary":
+ pytest.skip(f"can't load libpq for testing: {e}")
+ else:
+ raise
+
+
+@pytest.fixture
+def setpgenv(monkeypatch):
+ """Replace the PG* env vars with the vars provided."""
+
+ def setpgenv_(env):
+ ks = [k for k in os.environ if k.startswith("PG")]
+ for k in ks:
+ monkeypatch.delenv(k)
+
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+
+ return setpgenv_
+
+
+@pytest.fixture
+def trace(libpq):
+ pqver = pq.__build_version__
+ if pqver < 140000:
+ pytest.skip(f"trace not available on libpq {pqver}")
+ if sys.platform != "linux":
+ pytest.skip(f"trace not available on {sys.platform}")
+
+ yield Tracer()
+
+
+class Tracer:
+ def trace(self, conn):
+ pgconn: "pq.abc.PGconn"
+
+ if hasattr(conn, "exec_"):
+ pgconn = conn
+ elif hasattr(conn, "cursor"):
+ pgconn = conn.pgconn
+ else:
+ raise Exception()
+
+ return TraceLog(pgconn)
+
+
+class TraceLog:
+ def __init__(self, pgconn: "pq.abc.PGconn"):
+ self.pgconn = pgconn
+ self.tempfile = TemporaryFile(buffering=0)
+ pgconn.trace(self.tempfile.fileno())
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS)
+
+ def __del__(self):
+ if self.pgconn.status == pq.ConnStatus.OK:
+ self.pgconn.untrace()
+ self.tempfile.close()
+
+ def __iter__(self) -> "Iterator[TraceEntry]":
+ self.tempfile.seek(0)
+ data = self.tempfile.read()
+ for entry in self._parse_entries(data):
+ yield entry
+
+ def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]":
+ for line in data.splitlines():
+ direction, length, type, *content = line.split(b"\t")
+ yield TraceEntry(
+ direction=direction.decode(),
+ length=int(length.decode()),
+ type=type.decode(),
+ # Note: the items encoding is not very solid: no escaped
+ # backslash, no escaped quotes.
+ # At the moment we don't need a proper parser.
+ content=[content[0]] if content else [],
+ )
+
+
+class TraceEntry(NamedTuple):
+ direction: str
+ length: int
+ type: str
+ content: List[bytes]
diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py
new file mode 100644
index 0000000..e50f5ec
--- /dev/null
+++ b/tests/fix_proxy.py
@@ -0,0 +1,127 @@
+import os
+import time
+import socket
+import logging
+import subprocess as sp
+from shutil import which
+
+import pytest
+
+import psycopg
+from psycopg import conninfo
+
+
+def pytest_collection_modifyitems(items):
+ for item in items:
+ # TODO: there is a race condition on macOS and Windows in the CI:
+ # listen returns before really listening and tests based on 'deaf_port'
+ # fail 50% of the times. Just add the 'proxy' mark on these tests
+ # because they are already skipped in the CI.
+ if "proxy" in item.fixturenames or "deaf_port" in item.fixturenames:
+ item.add_marker(pytest.mark.proxy)
+
+
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers",
+ "proxy: the test uses pproxy (the marker is set automatically"
+ " on tests using the fixture)",
+ )
+
+
+@pytest.fixture
+def proxy(dsn):
+ """Return a proxy to the --test-dsn database"""
+ p = Proxy(dsn)
+ yield p
+ p.stop()
+
+
+@pytest.fixture
+def deaf_port(dsn):
+ """Return a port number with a socket open but not answering"""
+ with socket.socket(socket.AF_INET) as s:
+ s.bind(("", 0))
+ port = s.getsockname()[1]
+ s.listen(0)
+ yield port
+
+
+class Proxy:
+ """
+ Proxy a Postgres service for testing purpose.
+
+ Allow to lose connectivity and restart it using stop/start.
+ """
+
+ def __init__(self, server_dsn):
+ cdict = conninfo.conninfo_to_dict(server_dsn)
+
+ # Get server params
+ host = cdict.get("host") or os.environ.get("PGHOST")
+ self.server_host = host if host and not host.startswith("/") else "localhost"
+ self.server_port = cdict.get("port", "5432")
+
+ # Get client params
+ self.client_host = "localhost"
+ self.client_port = self._get_random_port()
+
+ # Make a connection string to the proxy
+ cdict["host"] = self.client_host
+ cdict["port"] = self.client_port
+ cdict["sslmode"] = "disable" # not supported by the proxy
+ self.client_dsn = conninfo.make_conninfo(**cdict)
+
+ # The running proxy process
+ self.proc = None
+
+ def start(self):
+ if self.proc:
+ logging.info("proxy already started")
+ return
+
+ logging.info("starting proxy")
+ pproxy = which("pproxy")
+ if not pproxy:
+ raise ValueError("pproxy program not found")
+ cmdline = [pproxy, "--reuse"]
+ cmdline.extend(["-l", f"tunnel://:{self.client_port}"])
+ cmdline.extend(["-r", f"tunnel://{self.server_host}:{self.server_port}"])
+
+ self.proc = sp.Popen(cmdline, stdout=sp.DEVNULL)
+ logging.info("proxy started")
+ self._wait_listen()
+
+ # verify that the proxy works
+ try:
+ with psycopg.connect(self.client_dsn):
+ pass
+ except Exception as e:
+ pytest.fail(f"failed to create a working proxy: {e}")
+
+ def stop(self):
+ if not self.proc:
+ return
+
+ logging.info("stopping proxy")
+ self.proc.terminate()
+ self.proc.wait()
+ logging.info("proxy stopped")
+ self.proc = None
+
+ @classmethod
+ def _get_random_port(cls):
+ with socket.socket() as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+ def _wait_listen(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ for i in range(20):
+ if 0 == sock.connect_ex((self.client_host, self.client_port)):
+ break
+ time.sleep(0.1)
+ else:
+ raise ValueError("the proxy didn't start listening in time")
+
+ logging.info("proxy listening")
diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py
new file mode 100644
index 0000000..80e0c62
--- /dev/null
+++ b/tests/fix_psycopg.py
@@ -0,0 +1,98 @@
+from copy import deepcopy
+
+import pytest
+
+
+@pytest.fixture
+def global_adapters():
+ """Restore the global adapters after a test has changed them."""
+ from psycopg import adapters
+
+ dumpers = deepcopy(adapters._dumpers)
+ dumpers_by_oid = deepcopy(adapters._dumpers_by_oid)
+ loaders = deepcopy(adapters._loaders)
+ types = list(adapters.types)
+
+ yield None
+
+ adapters._dumpers = dumpers
+ adapters._dumpers_by_oid = dumpers_by_oid
+ adapters._loaders = loaders
+ adapters.types.clear()
+ for t in types:
+ adapters.types.add(t)
+
+
+@pytest.fixture
+@pytest.mark.crdb_skip("2-phase commit")
+def tpc(svcconn):
+ tpc = Tpc(svcconn)
+ tpc.check_tpc()
+ tpc.clear_test_xacts()
+ tpc.make_test_table()
+ yield tpc
+ tpc.clear_test_xacts()
+
+
+class Tpc:
+ """Helper object to test two-phase transactions"""
+
+ def __init__(self, conn):
+ assert conn.autocommit
+ self.conn = conn
+
+ def check_tpc(self):
+ from .fix_crdb import is_crdb, crdb_skip_message
+
+ if is_crdb(self.conn):
+ pytest.skip(crdb_skip_message("2-phase commit"))
+
+ val = int(self.conn.execute("show max_prepared_transactions").fetchone()[0])
+ if not val:
+ pytest.skip("prepared transactions disabled in the database")
+
+ def clear_test_xacts(self):
+ """Rollback all the prepared transaction in the testing db."""
+ from psycopg import sql
+
+ cur = self.conn.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (self.conn.info.dbname,),
+ )
+ gids = [r[0] for r in cur]
+ for gid in gids:
+ self.conn.execute(sql.SQL("rollback prepared {}").format(gid))
+
+ def make_test_table(self):
+ self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
+ self.conn.execute("TRUNCATE test_tpc")
+
+ def count_xacts(self):
+ """Return the number of prepared xacts currently in the test db."""
+ cur = self.conn.execute(
+ """
+ select count(*) from pg_prepared_xacts
+ where database = %s""",
+ (self.conn.info.dbname,),
+ )
+ return cur.fetchone()[0]
+
+ def count_test_records(self):
+ """Return the number of records in the test table."""
+ cur = self.conn.execute("select count(*) from test_tpc")
+ return cur.fetchone()[0]
+
+
+@pytest.fixture(scope="module")
+def generators():
+ """Return the 'generators' module for selected psycopg implementation."""
+ from psycopg import pq
+
+ if pq.__impl__ == "c":
+ from psycopg._cmodule import _psycopg
+
+ return _psycopg
+ else:
+ import psycopg.generators
+
+ return psycopg.generators
diff --git a/tests/pool/__init__.py b/tests/pool/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/pool/__init__.py
diff --git a/tests/pool/fix_pool.py b/tests/pool/fix_pool.py
new file mode 100644
index 0000000..12e4f39
--- /dev/null
+++ b/tests/pool/fix_pool.py
@@ -0,0 +1,12 @@
+import pytest
+
+
+def pytest_configure(config):
+ config.addinivalue_line("markers", "pool: test related to the psycopg_pool package")
+
+
+def pytest_collection_modifyitems(items):
+ # Add the pool markers to all the tests in the pool package
+ for item in items:
+ if "/pool/" in item.nodeid:
+ item.add_marker(pytest.mark.pool)
diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py
new file mode 100644
index 0000000..c0e8060
--- /dev/null
+++ b/tests/pool/test_null_pool.py
@@ -0,0 +1,896 @@
+import logging
+from time import sleep, time
+from threading import Thread, Event
+from typing import Any, List, Tuple
+
+import pytest
+from packaging.version import parse as ver # noqa: F401 # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+from .test_pool import delay_connection, ensure_waiting
+
+try:
+ from psycopg_pool import NullConnectionPool
+ from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
+except ImportError:
+ pass
+
+
+def test_defaults(dsn):
+ with NullConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 0
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+def test_min_size_max_size(dsn):
+ with NullConnectionPool(dsn, min_size=0, max_size=2) as p:
+ assert p.min_size == 0
+ assert p.max_size == 2
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ NullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+def test_connection_class(dsn):
+ class MyConn(psycopg.Connection[Any]):
+ pass
+
+ with NullConnectionPool(dsn, connection_class=MyConn) as p:
+ with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+def test_kwargs(dsn):
+ with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
+ with p.connection() as conn:
+ assert conn.autocommit
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_its_no_pool_at_all(dsn):
+ with NullConnectionPool(dsn, max_size=2) as p:
+ with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+
+ with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ with p.connection() as conn:
+ assert conn.info.backend_pid not in (pid1, pid2)
+
+
+def test_context(dsn):
+ with NullConnectionPool(dsn) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.2)
+ with pytest.raises(PoolTimeout):
+ with NullConnectionPool(dsn, num_workers=1) as p:
+ p.wait(0.1)
+
+ with NullConnectionPool(dsn, num_workers=1) as p:
+ p.wait(0.4)
+
+
+def test_wait_closed(dsn):
+ with NullConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(PoolClosed):
+ p.wait()
+
+
+@pytest.mark.slow
+def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(PoolTimeout):
+ with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ p.wait(0.2)
+
+ with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+
+def test_configure(dsn):
+ inits = 0
+
+ def configure(conn):
+ nonlocal inits
+ inits += 1
+ with conn.transaction():
+ conn.execute("set default_transaction_read_only to on")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with p.connection() as conn:
+ assert inits == 1
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+ with p.connection() as conn:
+ assert inits == 2
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+ conn.close()
+
+ with p.connection() as conn:
+ assert inits == 3
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ conn.execute("select 1")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset(dsn):
+ resets = 0
+
+ def setup(conn):
+ with conn.transaction():
+ conn.execute("set timezone to '+1:00'")
+
+ def reset(conn):
+ nonlocal resets
+ resets += 1
+ with conn.transaction():
+ conn.execute("set timezone to utc")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ assert resets == 1
+ with conn.execute("show timezone") as cur:
+ assert cur.fetchone() == ("UTC",)
+ pids.append(conn.info.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ # Queue the worker so it will take the same connection a second time
+ # instead of making a new one.
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ assert resets == 0
+ conn.execute("set timezone to '+2:00'")
+ pids.append(conn.info.backend_pid)
+
+ t.join()
+ p.wait()
+
+ assert resets == 1
+ assert pids[0] == pids[1]
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ conn.execute("reset all")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+def test_no_queue_timeout(deaf_port):
+ with NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p:
+ with pytest.raises(PoolTimeout):
+ with p.connection(timeout=1):
+ pass
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue(dsn):
+ def worker(n):
+ t0 = time()
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ with NullConnectionPool(dsn, max_size=2) as p:
+ p.wait()
+ ts = [Thread(target=worker, args=(i,)) for i in range(6)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+def test_queue_size(dsn):
+ def worker(t, ev=None):
+ try:
+ with p.connection():
+ if ev:
+ ev.set()
+ sleep(t)
+ except TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
+ p.wait()
+ ev = Event()
+ t = Thread(target=worker, args=(0.3, ev))
+ t.start()
+ ev.wait()
+
+ ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue_timeout(dsn):
+ def worker(n):
+ t0 = time()
+ try:
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_dead_client(dsn):
+ def worker(i, timeout):
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ results: List[int] = []
+
+ with NullConnectionPool(dsn, max_size=2) as p:
+ ts = [
+ Thread(target=worker, args=(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue_timeout_override(dsn):
+ def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_broken_reconnect(dsn):
+ with NullConnectionPool(dsn, max_size=1) as p:
+ with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+ conn.close()
+
+ with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ assert pid1 != pid2
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert not conn.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ ).fetchone()
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ conn.execute("create table test_intrans_rollback ()")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+ assert conn.info.transaction_status == TransactionStatus.ACTIVE
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker(p):
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ def bad_rollback():
+ conn.pgconn.finish()
+ orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ t = Thread(target=worker, args=(p,))
+ t.start()
+ ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+def test_close_no_threads(dsn):
+ p = NullConnectionPool(dsn)
+ assert p._sched_runner and p._sched_runner.is_alive()
+ workers = p._workers[:]
+ assert workers
+ for t in workers:
+ assert t.is_alive()
+
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert not t.is_alive()
+
+
+def test_putconn_no_pool(conn_cls, dsn):
+ with NullConnectionPool(dsn) as p:
+ conn = conn_cls.connect(dsn)
+ with pytest.raises(ValueError):
+ p.putconn(conn)
+
+ conn.close()
+
+
+def test_putconn_wrong_pool(dsn):
+ with NullConnectionPool(dsn) as p1:
+ with NullConnectionPool(dsn) as p2:
+ conn = p1.getconn()
+ with pytest.raises(ValueError):
+ p2.putconn(conn)
+
+
+@pytest.mark.slow
+def test_del_stop_threads(dsn):
+ p = NullConnectionPool(dsn)
+ assert p._sched_runner is not None
+ ts = [p._sched_runner] + p._workers
+ del p
+ sleep(0.1)
+ for t in ts:
+ assert not t.is_alive()
+
+
+def test_closed_getconn(dsn):
+ p = NullConnectionPool(dsn)
+ assert not p.closed
+ with p.connection():
+ pass
+
+ p.close()
+ assert p.closed
+
+ with pytest.raises(PoolClosed):
+ with p.connection():
+ pass
+
+
+def test_closed_putconn(dsn):
+ p = NullConnectionPool(dsn)
+
+ with p.connection() as conn:
+ pass
+ assert conn.closed
+
+ with p.connection() as conn:
+ p.close()
+ assert conn.closed
+
+
+def test_closed_queue(dsn):
+ def w1():
+ with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+ e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ def w2():
+ try:
+ with p.connection():
+ pass # unexpected
+ except PoolClosed:
+ success.append("w2")
+
+ e1 = Event()
+ e2 = Event()
+
+ p = NullConnectionPool(dsn, max_size=1)
+ p.wait()
+ success: List[str] = []
+
+ t1 = Thread(target=w1)
+ t1.start()
+ # Wait until w1 has received a connection
+ e1.wait()
+
+ t2 = Thread(target=w2)
+ t2.start()
+ # Wait until w2 is in the queue
+ ensure_waiting(p)
+
+ p.close(0)
+
+ # Wait for the workers to finish
+ e2.set()
+ t1.join()
+ t2.join()
+ assert len(success) == 2
+
+
+def test_open_explicit(dsn):
+ p = NullConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(PoolClosed, match="is not open yet"):
+ p.getconn()
+
+ with pytest.raises(PoolClosed):
+ with p.connection():
+ pass
+
+ p.open()
+ try:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+ with pytest.raises(PoolClosed, match="is already closed"):
+ p.getconn()
+
+
+def test_open_context(dsn):
+ p = NullConnectionPool(dsn, open=False)
+ assert p.closed
+
+ with p:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+def test_open_no_op(dsn):
+ p = NullConnectionPool(dsn)
+ try:
+ assert not p.closed
+ p.open()
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+
+def test_reopen(dsn):
+ p = NullConnectionPool(dsn)
+ with p.connection() as conn:
+ conn.execute("select 1")
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ p.open()
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+def test_bad_resize(dsn, min_size, max_size):
+ with NullConnectionPool() as p:
+ with pytest.raises(ValueError):
+ p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_max_lifetime(dsn):
+ pids = []
+
+ def worker(p):
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ sleep(0.1)
+
+ ts = []
+ with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+ for i in range(5):
+ ts.append(Thread(target=worker, args=(p,)))
+ ts[-1].start()
+
+ for t in ts:
+ t.join()
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+def test_check(dsn):
+ with NullConnectionPool(dsn) as p:
+ # No-op
+ p.check()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_measures(dsn):
+ def worker(n):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+
+ with NullConnectionPool(dsn, max_size=4) as p:
+ p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 0
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(3)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ p.wait(2.0)
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_usage(dsn):
+ def worker(n):
+ try:
+ with p.connection(timeout=0.3) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ except PoolTimeout:
+ pass
+
+ with NullConnectionPool(dsn, max_size=3) as p:
+ p.wait(2.0)
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ with p.connection() as conn:
+ conn.close()
+ p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ with NullConnectionPool(proxy.client_dsn, max_size=3) as p:
+ p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 1
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 200 <= stats["connections_ms"] < 300
diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py
new file mode 100644
index 0000000..23a1a52
--- /dev/null
+++ b/tests/pool/test_null_pool_async.py
@@ -0,0 +1,844 @@
+import asyncio
+import logging
+from time import time
+from typing import Any, List, Tuple
+
+import pytest
+from packaging.version import parse as ver # noqa: F401 # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg._compat import create_task
+from .test_pool_async import delay_connection, ensure_waiting
+
+pytestmark = [pytest.mark.asyncio]
+
+try:
+ from psycopg_pool import AsyncNullConnectionPool # noqa: F401
+ from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
+except ImportError:
+ pass
+
+
+async def test_defaults(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 0
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+async def test_min_size_max_size(dsn):
+ async with AsyncNullConnectionPool(dsn, min_size=0, max_size=2) as p:
+ assert p.min_size == 0
+ assert p.max_size == 2
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+async def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ AsyncNullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+async def test_connection_class(dsn):
+ class MyConn(psycopg.AsyncConnection[Any]):
+ pass
+
+ async with AsyncNullConnectionPool(dsn, connection_class=MyConn) as p:
+ async with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+async def test_kwargs(dsn):
+ async with AsyncNullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
+ async with p.connection() as conn:
+ assert conn.autocommit
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_its_no_pool_at_all(dsn):
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ async with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+
+ async with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ async with p.connection() as conn:
+ assert conn.info.backend_pid not in (pid1, pid2)
+
+
+async def test_context(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.2)
+ with pytest.raises(PoolTimeout):
+ async with AsyncNullConnectionPool(dsn, num_workers=1) as p:
+ await p.wait(0.1)
+
+ async with AsyncNullConnectionPool(dsn, num_workers=1) as p:
+ await p.wait(0.4)
+
+
+async def test_wait_closed(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(PoolClosed):
+ await p.wait()
+
+
+@pytest.mark.slow
+async def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(PoolTimeout):
+ async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ await p.wait(0.2)
+
+ async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ await asyncio.sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+
+async def test_configure(dsn):
+ inits = 0
+
+ async def configure(conn):
+ nonlocal inits
+ inits += 1
+ async with conn.transaction():
+ await conn.execute("set default_transaction_read_only to on")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ async with p.connection() as conn:
+ assert inits == 1
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+ async with p.connection() as conn:
+ assert inits == 2
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+ await conn.close()
+
+ async with p.connection() as conn:
+ assert inits == 3
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+async def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ await conn.execute("select 1")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset(dsn):
+ resets = 0
+
+ async def setup(conn):
+ async with conn.transaction():
+ await conn.execute("set timezone to '+1:00'")
+
+ async def reset(conn):
+ nonlocal resets
+ resets += 1
+ async with conn.transaction():
+ await conn.execute("set timezone to utc")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ assert resets == 1
+ cur = await conn.execute("show timezone")
+ assert (await cur.fetchone()) == ("UTC",)
+ pids.append(conn.info.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ # Queue the worker so it will take the same connection a second time
+ # instead of making a new one.
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ assert resets == 0
+ await conn.execute("set timezone to '+2:00'")
+ pids.append(conn.info.backend_pid)
+
+ await asyncio.gather(t)
+ await p.wait()
+
+ assert resets == 1
+ assert pids[0] == pids[1]
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ await conn.execute("reset all")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ await conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ await conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+async def test_no_queue_timeout(deaf_port):
+ async with AsyncNullConnectionPool(
+ kwargs={"host": "localhost", "port": deaf_port}
+ ) as p:
+ with pytest.raises(PoolTimeout):
+ async with p.connection(timeout=1):
+ pass
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue(dsn):
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ await p.wait()
+ ts = [create_task(worker(i)) for i in range(6)]
+ await asyncio.gather(*ts)
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+async def test_queue_size(dsn):
+ async def worker(t, ev=None):
+ try:
+ async with p.connection():
+ if ev:
+ ev.set()
+ await asyncio.sleep(t)
+ except TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
+ await p.wait()
+ ev = asyncio.Event()
+ create_task(worker(0.3, ev))
+ await ev.wait()
+
+ ts = [create_task(worker(0.1)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue_timeout(dsn):
+ async def worker(n):
+ t0 = time()
+ try:
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_dead_client(dsn):
+ async def worker(i, timeout):
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ results: List[int] = []
+ ts = [
+ create_task(worker(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue_timeout_override(dsn):
+ async def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_broken_reconnect(dsn):
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ async with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+ await conn.close()
+
+ async with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ assert pid1 != pid2
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ cur = await conn.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ )
+ assert not await cur.fetchone()
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ await conn.execute("create table test_intrans_rollback ()")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+async def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+ assert conn.info.transaction_status == TransactionStatus.ACTIVE
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ async def bad_rollback():
+ conn.pgconn.finish()
+ await orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pids.append(conn.info.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+async def test_close_no_tasks(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ assert p._sched_runner and not p._sched_runner.done()
+ assert p._workers
+ workers = p._workers[:]
+ for t in workers:
+ assert not t.done()
+
+ await p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert t.done()
+
+
+async def test_putconn_no_pool(aconn_cls, dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ conn = await aconn_cls.connect(dsn)
+ with pytest.raises(ValueError):
+ await p.putconn(conn)
+
+ await conn.close()
+
+
+async def test_putconn_wrong_pool(dsn):
+ async with AsyncNullConnectionPool(dsn) as p1:
+ async with AsyncNullConnectionPool(dsn) as p2:
+ conn = await p1.getconn()
+ with pytest.raises(ValueError):
+ await p2.putconn(conn)
+
+
+async def test_closed_getconn(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ assert not p.closed
+ async with p.connection():
+ pass
+
+ await p.close()
+ assert p.closed
+
+ with pytest.raises(PoolClosed):
+ async with p.connection():
+ pass
+
+
+async def test_closed_putconn(dsn):
+ p = AsyncNullConnectionPool(dsn)
+
+ async with p.connection() as conn:
+ pass
+ assert conn.closed
+
+ async with p.connection() as conn:
+ await p.close()
+ assert conn.closed
+
+
+async def test_closed_queue(dsn):
+ async def w1():
+ async with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ await e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ async def w2():
+ try:
+ async with p.connection():
+ pass # unexpected
+ except PoolClosed:
+ success.append("w2")
+
+ e1 = asyncio.Event()
+ e2 = asyncio.Event()
+
+ p = AsyncNullConnectionPool(dsn, max_size=1)
+ await p.wait()
+ success: List[str] = []
+
+ t1 = create_task(w1())
+ # Wait until w1 has received a connection
+ await e1.wait()
+
+ t2 = create_task(w2())
+ # Wait until w2 is in the queue
+ await ensure_waiting(p)
+ await p.close()
+
+ # Wait for the workers to finish
+ e2.set()
+ await asyncio.gather(t1, t2)
+ assert len(success) == 2
+
+
+async def test_open_explicit(dsn):
+ p = AsyncNullConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(PoolClosed):
+ await p.getconn()
+
+ with pytest.raises(PoolClosed, match="is not open yet"):
+ async with p.connection():
+ pass
+
+ await p.open()
+ try:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+ with pytest.raises(PoolClosed, match="is already closed"):
+ await p.getconn()
+
+
+async def test_open_context(dsn):
+ p = AsyncNullConnectionPool(dsn, open=False)
+ assert p.closed
+
+ async with p:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+async def test_open_no_op(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ try:
+ assert not p.closed
+ await p.open()
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+
+async def test_reopen(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ await p.close()
+ assert p._sched_runner is None
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ await p.open()
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+async def test_bad_resize(dsn, min_size, max_size):
+ async with AsyncNullConnectionPool() as p:
+ with pytest.raises(ValueError):
+ await p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_max_lifetime(dsn):
+ pids: List[int] = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ await asyncio.sleep(0.1)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+ ts = [create_task(worker()) for i in range(5)]
+ await asyncio.gather(*ts)
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+async def test_check(dsn):
+ # no.op
+ async with AsyncNullConnectionPool(dsn) as p:
+ await p.check()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_measures(dsn):
+ async def worker(n):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+
+ async with AsyncNullConnectionPool(dsn, max_size=4) as p:
+ await p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 0
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ ts = [create_task(worker(i)) for i in range(3)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ await p.wait(2.0)
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_usage(dsn):
+ async def worker(n):
+ try:
+ async with p.connection(timeout=0.3) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ except PoolTimeout:
+ pass
+
+ async with AsyncNullConnectionPool(dsn, max_size=3) as p:
+ await p.wait(2.0)
+
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.gather(*ts)
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ async with p.connection() as conn:
+ await conn.close()
+ await p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ async with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+async def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ async with AsyncNullConnectionPool(proxy.client_dsn, max_size=3) as p:
+ await p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 1
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 200 <= stats["connections_ms"] < 300
diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py
new file mode 100644
index 0000000..30c790b
--- /dev/null
+++ b/tests/pool/test_pool.py
@@ -0,0 +1,1265 @@
+import logging
+import weakref
+from time import sleep, time
+from threading import Thread, Event
+from typing import Any, List, Tuple
+
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg._compat import Counter
+
+try:
+ import psycopg_pool as pool
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+
+def test_package_version(mypy):
+ cp = mypy.run_on_source(
+ """\
+from psycopg_pool import __version__
+assert __version__
+"""
+ )
+ assert not cp.stdout
+
+
+def test_defaults(dsn):
+ with pool.ConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 4
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)])
+def test_min_size_max_size(dsn, min_size, max_size):
+ with pool.ConnectionPool(dsn, min_size=min_size, max_size=max_size) as p:
+ assert p.min_size == min_size
+ assert p.max_size == max_size if max_size is not None else min_size
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)])
+def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ pool.ConnectionPool(min_size=min_size, max_size=max_size)
+
+
+def test_connection_class(dsn):
+ class MyConn(psycopg.Connection[Any]):
+ pass
+
+ with pool.ConnectionPool(dsn, connection_class=MyConn, min_size=1) as p:
+ with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+def test_kwargs(dsn):
+ with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, min_size=1) as p:
+ with p.connection() as conn:
+ assert conn.autocommit
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_its_really_a_pool(dsn):
+ with pool.ConnectionPool(dsn, min_size=2) as p:
+ with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+
+ with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ with p.connection() as conn:
+ assert conn.info.backend_pid in (pid1, pid2)
+
+
+def test_context(dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_connection_not_lost(dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ with pytest.raises(ZeroDivisionError):
+ with p.connection() as conn:
+ pid = conn.info.backend_pid
+ 1 / 0
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_concurrent_filling(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+
+ def add_time(self, conn):
+ times.append(time() - t0)
+ add_orig(self, conn)
+
+ add_orig = pool.ConnectionPool._add_to_pool
+ monkeypatch.setattr(pool.ConnectionPool, "_add_to_pool", add_time)
+
+ times: List[float] = []
+ t0 = time()
+
+ with pool.ConnectionPool(dsn, min_size=5, num_workers=2) as p:
+ p.wait(1.0)
+ want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
+ assert len(times) == len(want_times)
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ p.wait(0.3)
+
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ p.wait(0.5)
+
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=2) as p:
+ p.wait(0.3)
+ p.wait(0.0001) # idempotent
+
+
+def test_wait_closed(dsn):
+ with pool.ConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(pool.PoolClosed):
+ p.wait()
+
+
+@pytest.mark.slow
+def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(pool.PoolTimeout):
+ with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p:
+ p.wait(0.2)
+
+ with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p:
+ sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+
+def test_configure(dsn):
+ inits = 0
+
+ def configure(conn):
+ nonlocal inits
+ inits += 1
+ with conn.transaction():
+ conn.execute("set default_transaction_read_only to on")
+
+ with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p:
+ p.wait()
+ with p.connection() as conn:
+ assert inits == 1
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+ with p.connection() as conn:
+ assert inits == 1
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+ conn.close()
+
+ with p.connection() as conn:
+ assert inits == 2
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ conn.execute("select 1")
+
+ with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p:
+ with pytest.raises(pool.PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p:
+ with pytest.raises(pool.PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+def test_reset(dsn):
+ resets = 0
+
+ def setup(conn):
+ with conn.transaction():
+ conn.execute("set timezone to '+1:00'")
+
+ def reset(conn):
+ nonlocal resets
+ resets += 1
+ with conn.transaction():
+ conn.execute("set timezone to utc")
+
+ with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p:
+ with p.connection() as conn:
+ assert resets == 0
+ conn.execute("set timezone to '+2:00'")
+
+ p.wait()
+ assert resets == 1
+
+ with p.connection() as conn:
+ with conn.execute("show timezone") as cur:
+ assert cur.fetchone() == ("UTC",)
+
+ p.wait()
+ assert resets == 2
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ conn.execute("reset all")
+
+ with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p:
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid1 = conn.info.backend_pid
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid2 = conn.info.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p:
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid1 = conn.info.backend_pid
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid2 = conn.info.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue(dsn):
+ def worker(n):
+ t0 = time()
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ with pool.ConnectionPool(dsn, min_size=2) as p:
+ p.wait()
+ ts = [Thread(target=worker, args=(i,)) for i in range(6)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+def test_queue_size(dsn):
+ def worker(t, ev=None):
+ try:
+ with p.connection():
+ if ev:
+ ev.set()
+ sleep(t)
+ except pool.TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ with pool.ConnectionPool(dsn, min_size=1, max_waiting=3) as p:
+ p.wait()
+ ev = Event()
+ t = Thread(target=worker, args=(0.3, ev))
+ t.start()
+ ev.wait()
+
+ ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], pool.TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue_timeout(dsn):
+ def worker(n):
+ t0 = time()
+ try:
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_dead_client(dsn):
+ def worker(i, timeout):
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except pool.PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ results: List[int] = []
+
+ with pool.ConnectionPool(dsn, min_size=2) as p:
+ ts = [
+ Thread(target=worker, args=(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+ assert len(p._pool) == 2 # no connection was lost
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue_timeout_override(dsn):
+ def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_broken_reconnect(dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+ conn.close()
+
+ with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ assert pid1 != pid2
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = p.getconn()
+ pid = conn.info.backend_pid
+ conn.execute("create table test_intrans_rollback ()")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+ assert not conn2.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ ).fetchone()
+
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = p.getconn()
+ pid = conn.info.backend_pid
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = p.getconn()
+ pid = conn.info.backend_pid
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+ assert conn.info.transaction_status == TransactionStatus.ACTIVE
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid != pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = p.getconn()
+
+ def bad_rollback():
+ conn.pgconn.finish()
+ orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pid = conn.info.backend_pid
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid != pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+def test_close_no_threads(dsn):
+ p = pool.ConnectionPool(dsn)
+ assert p._sched_runner and p._sched_runner.is_alive()
+ workers = p._workers[:]
+ assert workers
+ for t in workers:
+ assert t.is_alive()
+
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert not t.is_alive()
+
+
+def test_putconn_no_pool(conn_cls, dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = conn_cls.connect(dsn)
+ with pytest.raises(ValueError):
+ p.putconn(conn)
+
+ conn.close()
+
+
+def test_putconn_wrong_pool(dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p1:
+ with pool.ConnectionPool(dsn, min_size=1) as p2:
+ conn = p1.getconn()
+ with pytest.raises(ValueError):
+ p2.putconn(conn)
+
+
+def test_del_no_warning(dsn, recwarn):
+ p = pool.ConnectionPool(dsn, min_size=2)
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+ p.wait()
+ ref = weakref.ref(p)
+ del p
+ assert not ref()
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+@pytest.mark.slow
+def test_del_stop_threads(dsn):
+ p = pool.ConnectionPool(dsn)
+ assert p._sched_runner is not None
+ ts = [p._sched_runner] + p._workers
+ del p
+ sleep(0.1)
+ for t in ts:
+ assert not t.is_alive()
+
+
+def test_closed_getconn(dsn):
+ p = pool.ConnectionPool(dsn, min_size=1)
+ assert not p.closed
+ with p.connection():
+ pass
+
+ p.close()
+ assert p.closed
+
+ with pytest.raises(pool.PoolClosed):
+ with p.connection():
+ pass
+
+
+def test_closed_putconn(dsn):
+ p = pool.ConnectionPool(dsn, min_size=1)
+
+ with p.connection() as conn:
+ pass
+ assert not conn.closed
+
+ with p.connection() as conn:
+ p.close()
+ assert conn.closed
+
+
+def test_closed_queue(dsn):
+ def w1():
+ with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+ e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ def w2():
+ try:
+ with p.connection():
+ pass # unexpected
+ except pool.PoolClosed:
+ success.append("w2")
+
+ e1 = Event()
+ e2 = Event()
+
+ p = pool.ConnectionPool(dsn, min_size=1)
+ p.wait()
+ success: List[str] = []
+
+ t1 = Thread(target=w1)
+ t1.start()
+ # Wait until w1 has received a connection
+ e1.wait()
+
+ t2 = Thread(target=w2)
+ t2.start()
+ # Wait until w2 is in the queue
+ ensure_waiting(p)
+
+ p.close(0)
+
+ # Wait for the workers to finish
+ e2.set()
+ t1.join()
+ t2.join()
+ assert len(success) == 2
+
+
+def test_open_explicit(dsn):
+ p = pool.ConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(pool.PoolClosed, match="is not open yet"):
+ p.getconn()
+
+ with pytest.raises(pool.PoolClosed):
+ with p.connection():
+ pass
+
+ p.open()
+ try:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+ with pytest.raises(pool.PoolClosed, match="is already closed"):
+ p.getconn()
+
+
+def test_open_context(dsn):
+ p = pool.ConnectionPool(dsn, open=False)
+ assert p.closed
+
+ with p:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+def test_open_no_op(dsn):
+ p = pool.ConnectionPool(dsn)
+ try:
+ assert not p.closed
+ p.open()
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_open_wait(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False)
+ try:
+ p.open(wait=True, timeout=0.3)
+ finally:
+ p.close()
+
+ p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False)
+ try:
+ p.open(wait=True, timeout=0.5)
+ finally:
+ p.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_open_as_wait(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ p.open(wait=True, timeout=0.3)
+
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ p.open(wait=True, timeout=0.5)
+
+
+def test_reopen(dsn):
+ p = pool.ConnectionPool(dsn)
+ with p.connection() as conn:
+ conn.execute("select 1")
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ p.open()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize(
+ "min_size, want_times",
+ [
+ (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]),
+ (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]),
+ ],
+)
+def test_grow(dsn, monkeypatch, min_size, want_times):
+ delay_connection(monkeypatch, 0.1)
+
+ def worker(n):
+ t0 = time()
+ with p.connection() as conn:
+ conn.execute("select 1 from pg_sleep(0.25)")
+ t1 = time()
+ results.append((n, t1 - t0))
+
+ with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p:
+ p.wait(1.0)
+ results: List[Tuple[int, float]] = []
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(len(want_times))]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ times = [item[1] for item in results]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_shrink(dsn, monkeypatch):
+
+ from psycopg_pool.pool import ShrinkPool
+
+ results: List[Tuple[int, int]] = []
+
+ def run_hacked(self, pool):
+ n0 = pool._nconns
+ orig_run(self, pool)
+ n1 = pool._nconns
+ results.append((n0, n1))
+
+ orig_run = ShrinkPool._run
+ monkeypatch.setattr(ShrinkPool, "_run", run_hacked)
+
+ def worker(n):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.1)")
+
+ with pool.ConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p:
+ p.wait(5.0)
+ assert p.max_idle == 0.2
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ sleep(1)
+
+ assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
+
+
+@pytest.mark.slow
+def test_reconnect(proxy, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0
+ assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
+
+ caplog.clear()
+ proxy.start()
+ with pool.ConnectionPool(proxy.client_dsn, min_size=1) as p:
+ p.wait(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg.OperationalError):
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+ sleep(1.0)
+ proxy.start()
+ p.wait()
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+ assert "BAD" in caplog.messages[0]
+ times = [rec.created for rec in caplog.records]
+ assert times[1] - times[0] < 0.05
+ deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)]
+ assert len(deltas) == 3
+ want = 0.1
+ for delta in deltas:
+ assert delta == pytest.approx(want, 0.05), deltas
+ want *= 2
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_reconnect_failure(proxy):
+ proxy.start()
+
+ t1 = None
+
+ def failed(pool):
+ assert pool.name == "this-one"
+ nonlocal t1
+ t1 = time()
+
+ with pool.ConnectionPool(
+ proxy.client_dsn,
+ name="this-one",
+ min_size=1,
+ reconnect_timeout=1.0,
+ reconnect_failed=failed,
+ ) as p:
+ p.wait(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg.OperationalError):
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+ t0 = time()
+ sleep(1.5)
+ assert t1
+ assert t1 - t0 == pytest.approx(1.0, 0.1)
+ assert p._nconns == 0
+
+ proxy.start()
+ t0 = time()
+ with p.connection() as conn:
+ conn.execute("select 1")
+ t1 = time()
+ assert t1 - t0 < 0.2
+
+
+@pytest.mark.slow
+def test_reconnect_after_grow_failed(proxy):
+ # Retry reconnection after a failed connection attempt has put the pool
+ # in grow mode. See issue #370.
+ proxy.stop()
+
+ ev = Event()
+
+ def failed(pool):
+ ev.set()
+
+ with pool.ConnectionPool(
+ proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+ ) as p:
+ assert ev.wait(timeout=2)
+
+ with pytest.raises(pool.PoolTimeout):
+ with p.connection(timeout=0.5) as conn:
+ pass
+
+ ev.clear()
+ assert ev.wait(timeout=2)
+
+ proxy.start()
+
+ with p.connection(timeout=2) as conn:
+ conn.execute("select 1")
+
+ p.wait(timeout=3.0)
+ assert len(p._pool) == p.min_size == 4
+
+
+@pytest.mark.slow
+def test_refill_on_check(proxy):
+ proxy.start()
+ ev = Event()
+
+ def failed(pool):
+ ev.set()
+
+ with pool.ConnectionPool(
+ proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+ ) as p:
+ # The pool is full
+ p.wait(timeout=2)
+
+ # Break all the connection
+ proxy.stop()
+
+ # Checking the pool will empty it
+ p.check()
+ assert ev.wait(timeout=2)
+ assert len(p._pool) == 0
+
+ # Allow to connect again
+ proxy.start()
+
+ # Make sure that check has refilled the pool
+ p.check()
+ p.wait(timeout=2)
+ assert len(p._pool) == 4
+
+
+@pytest.mark.slow
+def test_uniform_use(dsn):
+ with pool.ConnectionPool(dsn, min_size=4) as p:
+ counts = Counter[int]()
+ for i in range(8):
+ with p.connection() as conn:
+ sleep(0.1)
+ counts[id(conn)] += 1
+
+ assert len(counts) == 4
+ assert set(counts.values()) == set([2])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_resize(dsn):
+ def sampler():
+ sleep(0.05) # ensure sampling happens after shrink check
+ while True:
+ sleep(0.2)
+ if p.closed:
+ break
+ size.append(len(p._pool))
+
+ def client(t):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(%s)", [t])
+
+ size: List[int] = []
+
+ with pool.ConnectionPool(dsn, min_size=2, max_idle=0.2) as p:
+ s = Thread(target=sampler)
+ s.start()
+
+ sleep(0.3)
+ c = Thread(target=client, args=(0.4,))
+ c.start()
+
+ sleep(0.2)
+ p.resize(4)
+ assert p.min_size == 4
+ assert p.max_size == 4
+
+ sleep(0.4)
+ p.resize(2)
+ assert p.min_size == 2
+ assert p.max_size == 2
+
+ sleep(0.6)
+
+ s.join()
+ assert size == [2, 1, 3, 4, 3, 2, 2]
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)])
+def test_bad_resize(dsn, min_size, max_size):
+ with pool.ConnectionPool() as p:
+ with pytest.raises(ValueError):
+ p.resize(min_size=min_size, max_size=max_size)
+
+
+def test_jitter():
+ rnds = [pool.ConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)]
+ assert 27 <= min(rnds) <= 28
+ assert 35 < max(rnds) < 36
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_max_lifetime(dsn):
+ with pool.ConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p:
+ sleep(0.1)
+ pids = []
+ for i in range(5):
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ sleep(0.2)
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_check(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ with pool.ConnectionPool(dsn, min_size=4) as p:
+ p.wait(1.0)
+ with p.connection() as conn:
+ pid = conn.info.backend_pid
+
+ p.wait(1.0)
+ pids = set(conn.info.backend_pid for conn in p._pool)
+ assert pid in pids
+ conn.close()
+
+ assert len(caplog.records) == 0
+ p.check()
+ assert len(caplog.records) == 1
+ p.wait(1.0)
+ pids2 = set(conn.info.backend_pid for conn in p._pool)
+ assert len(pids & pids2) == 3
+ assert pid not in pids2
+
+
+def test_check_idle(dsn):
+ with pool.ConnectionPool(dsn, min_size=2) as p:
+ p.wait(1.0)
+ p.check()
+ with p.connection() as conn:
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_measures(dsn):
+ def worker(n):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+
+ with pool.ConnectionPool(dsn, min_size=2, max_size=4) as p:
+ p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 2
+ assert stats["pool_available"] == 2
+ assert stats["requests_waiting"] == 0
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(3)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ p.wait(2.0)
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_usage(dsn):
+ def worker(n):
+ try:
+ with p.connection(timeout=0.3) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ except pool.PoolTimeout:
+ pass
+
+ with pool.ConnectionPool(dsn, min_size=3) as p:
+ p.wait(2.0)
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ with p.connection() as conn:
+ conn.close()
+ p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ with pool.ConnectionPool(proxy.client_dsn, min_size=3) as p:
+ p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 3
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 600 <= stats["connections_ms"] < 1200
+
+ proxy.stop()
+ p.check()
+ sleep(0.1)
+ stats = p.get_stats()
+ assert stats["connections_num"] > 3
+ assert stats["connections_errors"] > 0
+ assert stats["connections_lost"] == 3
+
+
+@pytest.mark.slow
+def test_spike(dsn, monkeypatch):
+ # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/
+ # documents/Welcome-To-The-Jungle.md
+ delay_connection(monkeypatch, 0.15)
+
+ def worker():
+ with p.connection():
+ sleep(0.002)
+
+ with pool.ConnectionPool(dsn, min_size=5, max_size=10) as p:
+ p.wait()
+
+ ts = [Thread(target=worker) for i in range(50)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ p.wait()
+
+ assert len(p._pool) < 7
+
+
+def test_debug_deadlock(dsn):
+ # https://github.com/psycopg/psycopg/issues/230
+ logger = logging.getLogger("psycopg")
+ handler = logging.StreamHandler()
+ old_level = logger.level
+ logger.setLevel(logging.DEBUG)
+ handler.setLevel(logging.DEBUG)
+ logger.addHandler(handler)
+ try:
+ with pool.ConnectionPool(dsn, min_size=4, open=True) as p:
+ try:
+ p.wait(timeout=2)
+ finally:
+ print(p.get_stats())
+ finally:
+ logger.removeHandler(handler)
+ logger.setLevel(old_level)
+
+
+def delay_connection(monkeypatch, sec):
+ """
+ Return a _connect_gen function delayed by the amount of seconds
+ """
+
+ def connect_delay(*args, **kwargs):
+ t0 = time()
+ rv = connect_orig(*args, **kwargs)
+ t1 = time()
+ sleep(max(0, sec - (t1 - t0)))
+ return rv
+
+ connect_orig = psycopg.Connection.connect
+ monkeypatch.setattr(psycopg.Connection, "connect", connect_delay)
+
+
+def ensure_waiting(p, num=1):
+ """
+ Wait until there are at least *num* clients waiting in the queue.
+ """
+ while len(p._waiting) < num:
+ sleep(0)
diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py
new file mode 100644
index 0000000..286a775
--- /dev/null
+++ b/tests/pool/test_pool_async.py
@@ -0,0 +1,1198 @@
+import asyncio
+import logging
+from time import time
+from typing import Any, List, Tuple
+
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg._compat import create_task, Counter
+
+try:
+ import psycopg_pool as pool
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+pytestmark = [pytest.mark.asyncio]
+
+
+async def test_defaults(dsn):
+ async with pool.AsyncConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 4
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)])
+async def test_min_size_max_size(dsn, min_size, max_size):
+ async with pool.AsyncConnectionPool(dsn, min_size=min_size, max_size=max_size) as p:
+ assert p.min_size == min_size
+ assert p.max_size == max_size if max_size is not None else min_size
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)])
+async def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ pool.AsyncConnectionPool(min_size=min_size, max_size=max_size)
+
+
+async def test_connection_class(dsn):
+ class MyConn(psycopg.AsyncConnection[Any]):
+ pass
+
+ async with pool.AsyncConnectionPool(dsn, connection_class=MyConn, min_size=1) as p:
+ async with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+async def test_kwargs(dsn):
+ async with pool.AsyncConnectionPool(
+ dsn, kwargs={"autocommit": True}, min_size=1
+ ) as p:
+ async with p.connection() as conn:
+ assert conn.autocommit
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_its_really_a_pool(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+ async with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+
+ async with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ async with p.connection() as conn:
+ assert conn.info.backend_pid in (pid1, pid2)
+
+
+async def test_context(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_connection_not_lost(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ with pytest.raises(ZeroDivisionError):
+ async with p.connection() as conn:
+ pid = conn.info.backend_pid
+ 1 / 0
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_concurrent_filling(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+
+ async def add_time(self, conn):
+ times.append(time() - t0)
+ await add_orig(self, conn)
+
+ add_orig = pool.AsyncConnectionPool._add_to_pool
+ monkeypatch.setattr(pool.AsyncConnectionPool, "_add_to_pool", add_time)
+
+ times: List[float] = []
+ t0 = time()
+
+ async with pool.AsyncConnectionPool(dsn, min_size=5, num_workers=2) as p:
+ await p.wait(1.0)
+ want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
+ assert len(times) == len(want_times)
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ await p.wait(0.3)
+
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ await p.wait(0.5)
+
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=2) as p:
+ await p.wait(0.3)
+ await p.wait(0.0001) # idempotent
+
+
+async def test_wait_closed(dsn):
+ async with pool.AsyncConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(pool.PoolClosed):
+ await p.wait()
+
+
+@pytest.mark.slow
+async def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(pool.PoolTimeout):
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn, min_size=1, num_workers=1
+ ) as p:
+ await p.wait(0.2)
+
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn, min_size=1, num_workers=1
+ ) as p:
+ await asyncio.sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+
+async def test_configure(dsn):
+ inits = 0
+
+ async def configure(conn):
+ nonlocal inits
+ inits += 1
+ async with conn.transaction():
+ await conn.execute("set default_transaction_read_only to on")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
+ await p.wait(timeout=1.0)
+ async with p.connection() as conn:
+ assert inits == 1
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+ async with p.connection() as conn:
+ assert inits == 1
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+ await conn.close()
+
+ async with p.connection() as conn:
+ assert inits == 2
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+async def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ await conn.execute("select 1")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
+ with pytest.raises(pool.PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
+ with pytest.raises(pool.PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+async def test_reset(dsn):
+ resets = 0
+
+ async def setup(conn):
+ async with conn.transaction():
+ await conn.execute("set timezone to '+1:00'")
+
+ async def reset(conn):
+ nonlocal resets
+ resets += 1
+ async with conn.transaction():
+ await conn.execute("set timezone to utc")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+ assert resets == 0
+ await conn.execute("set timezone to '+2:00'")
+
+ await p.wait()
+ assert resets == 1
+
+ async with p.connection() as conn:
+ cur = await conn.execute("show timezone")
+ assert (await cur.fetchone()) == ("UTC",)
+
+ await p.wait()
+ assert resets == 2
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ await conn.execute("reset all")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid1 = conn.info.backend_pid
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid2 = conn.info.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid1 = conn.info.backend_pid
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid2 = conn.info.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue(dsn):
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+ await p.wait()
+ ts = [create_task(worker(i)) for i in range(6)]
+ await asyncio.gather(*ts)
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+async def test_queue_size(dsn):
+ async def worker(t, ev=None):
+ try:
+ async with p.connection():
+ if ev:
+ ev.set()
+ await asyncio.sleep(t)
+ except pool.TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, max_waiting=3) as p:
+ await p.wait()
+ ev = asyncio.Event()
+ create_task(worker(0.3, ev))
+ await ev.wait()
+
+ ts = [create_task(worker(0.1)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], pool.TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue_timeout(dsn):
+ async def worker(n):
+ t0 = time()
+ try:
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_dead_client(dsn):
+ async def worker(i, timeout):
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except pool.PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+ results: List[int] = []
+ ts = [
+ create_task(worker(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+ assert len(p._pool) == 2 # no connection was lost
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue_timeout_override(dsn):
+ async def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_broken_reconnect(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ async with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+ await conn.close()
+
+ async with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ assert pid1 != pid2
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await p.getconn()
+ pid = conn.info.backend_pid
+ await conn.execute("create table test_intrans_rollback ()")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+ cur = await conn2.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ )
+ assert not await cur.fetchone()
+
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await p.getconn()
+ pid = conn.info.backend_pid
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+async def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await p.getconn()
+ pid = conn.info.backend_pid
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+ assert conn.info.transaction_status == TransactionStatus.ACTIVE
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid != pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await p.getconn()
+
+ async def bad_rollback():
+ conn.pgconn.finish()
+ await orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pid = conn.info.backend_pid
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid != pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+async def test_close_no_tasks(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ assert p._sched_runner and not p._sched_runner.done()
+ assert p._workers
+ workers = p._workers[:]
+ for t in workers:
+ assert not t.done()
+
+ await p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert t.done()
+
+
+async def test_putconn_no_pool(aconn_cls, dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await aconn_cls.connect(dsn)
+ with pytest.raises(ValueError):
+ await p.putconn(conn)
+
+ await conn.close()
+
+
+async def test_putconn_wrong_pool(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p1:
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p2:
+ conn = await p1.getconn()
+ with pytest.raises(ValueError):
+ await p2.putconn(conn)
+
+
+async def test_closed_getconn(dsn):
+ p = pool.AsyncConnectionPool(dsn, min_size=1)
+ assert not p.closed
+ async with p.connection():
+ pass
+
+ await p.close()
+ assert p.closed
+
+ with pytest.raises(pool.PoolClosed):
+ async with p.connection():
+ pass
+
+
+async def test_closed_putconn(dsn):
+ p = pool.AsyncConnectionPool(dsn, min_size=1)
+
+ async with p.connection() as conn:
+ pass
+ assert not conn.closed
+
+ async with p.connection() as conn:
+ await p.close()
+ assert conn.closed
+
+
+async def test_closed_queue(dsn):
+ async def w1():
+ async with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ await e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ async def w2():
+ try:
+ async with p.connection():
+ pass # unexpected
+ except pool.PoolClosed:
+ success.append("w2")
+
+ e1 = asyncio.Event()
+ e2 = asyncio.Event()
+
+ p = pool.AsyncConnectionPool(dsn, min_size=1)
+ await p.wait()
+ success: List[str] = []
+
+ t1 = create_task(w1())
+ # Wait until w1 has received a connection
+ await e1.wait()
+
+ t2 = create_task(w2())
+ # Wait until w2 is in the queue
+ await ensure_waiting(p)
+ await p.close()
+
+ # Wait for the workers to finish
+ e2.set()
+ await asyncio.gather(t1, t2)
+ assert len(success) == 2
+
+
+async def test_open_explicit(dsn):
+ p = pool.AsyncConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(pool.PoolClosed):
+ await p.getconn()
+
+ with pytest.raises(pool.PoolClosed, match="is not open yet"):
+ async with p.connection():
+ pass
+
+ await p.open()
+ try:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+ with pytest.raises(pool.PoolClosed, match="is already closed"):
+ await p.getconn()
+
+
+async def test_open_context(dsn):
+ p = pool.AsyncConnectionPool(dsn, open=False)
+ assert p.closed
+
+ async with p:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+async def test_open_no_op(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ try:
+ assert not p.closed
+ await p.open()
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_open_wait(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False)
+ try:
+ await p.open(wait=True, timeout=0.3)
+ finally:
+ await p.close()
+
+ p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False)
+ try:
+ await p.open(wait=True, timeout=0.5)
+ finally:
+ await p.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_open_as_wait(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ await p.open(wait=True, timeout=0.3)
+
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ await p.open(wait=True, timeout=0.5)
+
+
+async def test_reopen(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ await p.close()
+ assert p._sched_runner is None
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ await p.open()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize(
+ "min_size, want_times",
+ [
+ (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]),
+ (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]),
+ ],
+)
+async def test_grow(dsn, monkeypatch, min_size, want_times):
+ delay_connection(monkeypatch, 0.1)
+
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select 1 from pg_sleep(0.25)")
+ t1 = time()
+ results.append((n, t1 - t0))
+
+ async with pool.AsyncConnectionPool(
+ dsn, min_size=min_size, max_size=4, num_workers=3
+ ) as p:
+ await p.wait(1.0)
+ ts = []
+ results: List[Tuple[int, float]] = []
+
+ ts = [create_task(worker(i)) for i in range(len(want_times))]
+ await asyncio.gather(*ts)
+
+ times = [item[1] for item in results]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_shrink(dsn, monkeypatch):
+
+ from psycopg_pool.pool_async import ShrinkPool
+
+ results: List[Tuple[int, int]] = []
+
+ async def run_hacked(self, pool):
+ n0 = pool._nconns
+ await orig_run(self, pool)
+ n1 = pool._nconns
+ results.append((n0, n1))
+
+ orig_run = ShrinkPool._run
+ monkeypatch.setattr(ShrinkPool, "_run", run_hacked)
+
+ async def worker(n):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.1)")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p:
+ await p.wait(5.0)
+ assert p.max_idle == 0.2
+
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(1)
+
+ assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
+
+
+@pytest.mark.slow
+async def test_reconnect(proxy, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0
+ assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
+
+ caplog.clear()
+ proxy.start()
+ async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=1) as p:
+ await p.wait(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg.OperationalError):
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ await asyncio.sleep(1.0)
+ proxy.start()
+ await p.wait()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ assert "BAD" in caplog.messages[0]
+ times = [rec.created for rec in caplog.records]
+ assert times[1] - times[0] < 0.05
+ deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)]
+ assert len(deltas) == 3
+ want = 0.1
+ for delta in deltas:
+ assert delta == pytest.approx(want, 0.05), deltas
+ want *= 2
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_reconnect_failure(proxy):
+ proxy.start()
+
+ t1 = None
+
+ def failed(pool):
+ assert pool.name == "this-one"
+ nonlocal t1
+ t1 = time()
+
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn,
+ name="this-one",
+ min_size=1,
+ reconnect_timeout=1.0,
+ reconnect_failed=failed,
+ ) as p:
+ await p.wait(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg.OperationalError):
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ t0 = time()
+ await asyncio.sleep(1.5)
+ assert t1
+ assert t1 - t0 == pytest.approx(1.0, 0.1)
+ assert p._nconns == 0
+
+ proxy.start()
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ t1 = time()
+ assert t1 - t0 < 0.2
+
+
+@pytest.mark.slow
+async def test_reconnect_after_grow_failed(proxy):
+ # Retry reconnection after a failed connection attempt has put the pool
+ # in grow mode. See issue #370.
+ proxy.stop()
+
+ ev = asyncio.Event()
+
+ def failed(pool):
+ ev.set()
+
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+ ) as p:
+ await asyncio.wait_for(ev.wait(), 2.0)
+
+ with pytest.raises(pool.PoolTimeout):
+ async with p.connection(timeout=0.5) as conn:
+ pass
+
+ ev.clear()
+ await asyncio.wait_for(ev.wait(), 2.0)
+
+ proxy.start()
+
+ async with p.connection(timeout=2) as conn:
+ await conn.execute("select 1")
+
+ await p.wait(timeout=3.0)
+ assert len(p._pool) == p.min_size == 4
+
+
+@pytest.mark.slow
+async def test_refill_on_check(proxy):
+ proxy.start()
+ ev = asyncio.Event()
+
+ def failed(pool):
+ ev.set()
+
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+ ) as p:
+ # The pool is full
+ await p.wait(timeout=2)
+
+ # Break all the connection
+ proxy.stop()
+
+ # Checking the pool will empty it
+ await p.check()
+ await asyncio.wait_for(ev.wait(), 2.0)
+ assert len(p._pool) == 0
+
+ # Allow to connect again
+ proxy.start()
+
+ # Make sure that check has refilled the pool
+ await p.check()
+ await p.wait(timeout=2)
+ assert len(p._pool) == 4
+
+
+@pytest.mark.slow
+async def test_uniform_use(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=4) as p:
+ counts = Counter[int]()
+ for i in range(8):
+ async with p.connection() as conn:
+ await asyncio.sleep(0.1)
+ counts[id(conn)] += 1
+
+ assert len(counts) == 4
+ assert set(counts.values()) == set([2])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_resize(dsn):
+ async def sampler():
+ await asyncio.sleep(0.05) # ensure sampling happens after shrink check
+ while True:
+ await asyncio.sleep(0.2)
+ if p.closed:
+ break
+ size.append(len(p._pool))
+
+ async def client(t):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(%s)", [t])
+
+ size: List[int] = []
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, max_idle=0.2) as p:
+ s = create_task(sampler())
+
+ await asyncio.sleep(0.3)
+
+ c = create_task(client(0.4))
+
+ await asyncio.sleep(0.2)
+ await p.resize(4)
+ assert p.min_size == 4
+ assert p.max_size == 4
+
+ await asyncio.sleep(0.4)
+ await p.resize(2)
+ assert p.min_size == 2
+ assert p.max_size == 2
+
+ await asyncio.sleep(0.6)
+
+ await asyncio.gather(s, c)
+ assert size == [2, 1, 3, 4, 3, 2, 2]
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)])
+async def test_bad_resize(dsn, min_size, max_size):
+ async with pool.AsyncConnectionPool() as p:
+ with pytest.raises(ValueError):
+ await p.resize(min_size=min_size, max_size=max_size)
+
+
+async def test_jitter():
+ rnds = [pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)]
+ assert 27 <= min(rnds) <= 28
+ assert 35 < max(rnds) < 36
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_max_lifetime(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p:
+ await asyncio.sleep(0.1)
+ pids = []
+ for i in range(5):
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ await asyncio.sleep(0.2)
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_check(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ async with pool.AsyncConnectionPool(dsn, min_size=4) as p:
+ await p.wait(1.0)
+ async with p.connection() as conn:
+ pid = conn.info.backend_pid
+
+ await p.wait(1.0)
+ pids = set(conn.info.backend_pid for conn in p._pool)
+ assert pid in pids
+ await conn.close()
+
+ assert len(caplog.records) == 0
+ await p.check()
+ assert len(caplog.records) == 1
+ await p.wait(1.0)
+ pids2 = set(conn.info.backend_pid for conn in p._pool)
+ assert len(pids & pids2) == 3
+ assert pid not in pids2
+
+
+async def test_check_idle(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+ await p.wait(1.0)
+ await p.check()
+ async with p.connection() as conn:
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_measures(dsn):
+ async def worker(n):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4) as p:
+ await p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 2
+ assert stats["pool_available"] == 2
+ assert stats["requests_waiting"] == 0
+
+ ts = [create_task(worker(i)) for i in range(3)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ await p.wait(2.0)
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_usage(dsn):
+ async def worker(n):
+ try:
+ async with p.connection(timeout=0.3) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ except pool.PoolTimeout:
+ pass
+
+ async with pool.AsyncConnectionPool(dsn, min_size=3) as p:
+ await p.wait(2.0)
+
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.gather(*ts)
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ async with p.connection() as conn:
+ await conn.close()
+ await p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ async with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+async def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=3) as p:
+ await p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 3
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 580 <= stats["connections_ms"] < 1200
+
+ proxy.stop()
+ await p.check()
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ assert stats["connections_num"] > 3
+ assert stats["connections_errors"] > 0
+ assert stats["connections_lost"] == 3
+
+
+@pytest.mark.slow
+async def test_spike(dsn, monkeypatch):
+ # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/
+ # documents/Welcome-To-The-Jungle.md
+ delay_connection(monkeypatch, 0.15)
+
+ async def worker():
+ async with p.connection():
+ await asyncio.sleep(0.002)
+
+ async with pool.AsyncConnectionPool(dsn, min_size=5, max_size=10) as p:
+ await p.wait()
+
+ ts = [create_task(worker()) for i in range(50)]
+ await asyncio.gather(*ts)
+ await p.wait()
+
+ assert len(p._pool) < 7
+
+
+async def test_debug_deadlock(dsn):
+ # https://github.com/psycopg/psycopg/issues/230
+ logger = logging.getLogger("psycopg")
+ handler = logging.StreamHandler()
+ old_level = logger.level
+ logger.setLevel(logging.DEBUG)
+ handler.setLevel(logging.DEBUG)
+ logger.addHandler(handler)
+ try:
+ async with pool.AsyncConnectionPool(dsn, min_size=4, open=True) as p:
+ await p.wait(timeout=2)
+ finally:
+ logger.removeHandler(handler)
+ logger.setLevel(old_level)
+
+
+def delay_connection(monkeypatch, sec):
+ """
+ Return a _connect_gen function delayed by the amount of seconds
+ """
+
+ async def connect_delay(*args, **kwargs):
+ t0 = time()
+ rv = await connect_orig(*args, **kwargs)
+ t1 = time()
+ await asyncio.sleep(max(0, sec - (t1 - t0)))
+ return rv
+
+ connect_orig = psycopg.AsyncConnection.connect
+ monkeypatch.setattr(psycopg.AsyncConnection, "connect", connect_delay)
+
+
+async def ensure_waiting(p, num=1):
+ while len(p._waiting) < num:
+ await asyncio.sleep(0)
diff --git a/tests/pool/test_pool_async_noasyncio.py b/tests/pool/test_pool_async_noasyncio.py
new file mode 100644
index 0000000..f6e34e4
--- /dev/null
+++ b/tests/pool/test_pool_async_noasyncio.py
@@ -0,0 +1,78 @@
+# These tests relate to AsyncConnectionPool, but are not marked asyncio
+# because they rely on the pool initialization outside the asyncio loop.
+
+import asyncio
+
+import pytest
+
+from ..utils import gc_collect
+
+try:
+ import psycopg_pool as pool
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+
+@pytest.mark.slow
+def test_reconnect_after_max_lifetime(dsn, asyncio_run):
+ # See issue #219, pool created before the loop.
+ p = pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2, open=False)
+
+ async def test():
+ try:
+ await p.open()
+ ns = []
+ for i in range(5):
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ ns.append(await cur.fetchone())
+ await asyncio.sleep(0.2)
+ assert len(ns) == 5
+ finally:
+ await p.close()
+
+ asyncio_run(asyncio.wait_for(test(), timeout=2.0))
+
+
+@pytest.mark.slow
+def test_working_created_before_loop(dsn, asyncio_run):
+ p = pool.AsyncNullConnectionPool(dsn, open=False)
+
+ async def test():
+ try:
+ await p.open()
+ ns = []
+ for i in range(5):
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ ns.append(await cur.fetchone())
+ await asyncio.sleep(0.2)
+ assert len(ns) == 5
+ finally:
+ await p.close()
+
+ asyncio_run(asyncio.wait_for(test(), timeout=2.0))
+
+
+def test_cant_create_open_outside_loop(dsn):
+ with pytest.raises(RuntimeError):
+ pool.AsyncConnectionPool(dsn, open=True)
+
+
+@pytest.fixture
+def asyncio_run(recwarn):
+ """Fixture reuturning asyncio.run, but managing resources at exit.
+
+ In certain runs, fd objects are leaked and the error will only be caught
+ downstream, by some innocent test calling gc_collect().
+ """
+ recwarn.clear()
+ try:
+ yield asyncio.run
+ finally:
+ gc_collect()
+ if recwarn:
+ warn = recwarn.pop(ResourceWarning)
+ assert "unclosed event loop" in str(warn.message)
+ assert not recwarn
diff --git a/tests/pool/test_sched.py b/tests/pool/test_sched.py
new file mode 100644
index 0000000..b3d2572
--- /dev/null
+++ b/tests/pool/test_sched.py
@@ -0,0 +1,154 @@
+import logging
+from time import time, sleep
+from functools import partial
+from threading import Thread
+
+import pytest
+
+try:
+ from psycopg_pool.sched import Scheduler
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+pytestmark = [pytest.mark.timing]
+
+
+@pytest.mark.slow
+def test_sched():
+ s = Scheduler()
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, partial(worker, 3))
+ s.enter(0.3, None)
+ s.enter(0.2, partial(worker, 2))
+ s.run()
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.1)
+
+
+@pytest.mark.slow
+def test_sched_thread():
+ s = Scheduler()
+ t = Thread(target=s.run, daemon=True)
+ t.start()
+
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, partial(worker, 3))
+ s.enter(0.3, None)
+ s.enter(0.2, partial(worker, 2))
+
+ t.join()
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.3, 0.2)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.2)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.2)
+
+
+@pytest.mark.slow
+def test_sched_error(caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ s = Scheduler()
+ t = Thread(target=s.run, daemon=True)
+ t.start()
+
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ def error():
+ 1 / 0
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, None)
+ s.enter(0.3, partial(worker, 2))
+ s.enter(0.2, error)
+
+ t.join()
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.4, 0.1)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.3, 0.1)
+
+ assert len(caplog.records) == 1
+ assert "ZeroDivisionError" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_empty_queue_timeout():
+ s = Scheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ def wait_logging(timeout=None):
+ rv = wait_orig(timeout)
+ times.append(time() - t0)
+ return rv
+
+ setattr(s._event, "wait", wait_logging)
+ s.EMPTY_QUEUE_TIMEOUT = 0.2
+
+ t = Thread(target=s.run)
+ t.start()
+ sleep(0.5)
+ s.enter(0.5, None)
+ t.join()
+ times.append(time() - t0)
+ for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
+ assert got == pytest.approx(want, 0.2), times
+
+
+@pytest.mark.slow
+def test_first_task_rescheduling():
+ s = Scheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ def wait_logging(timeout=None):
+ rv = wait_orig(timeout)
+ times.append(time() - t0)
+ return rv
+
+ setattr(s._event, "wait", wait_logging)
+ s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+ s.enter(0.4, lambda: None)
+ t = Thread(target=s.run)
+ t.start()
+ s.enter(0.6, None) # this task doesn't trigger a reschedule
+ sleep(0.1)
+ s.enter(0.1, lambda: None) # this triggers a reschedule
+ t.join()
+ times.append(time() - t0)
+ for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+ assert got == pytest.approx(want, 0.2), times
diff --git a/tests/pool/test_sched_async.py b/tests/pool/test_sched_async.py
new file mode 100644
index 0000000..492d620
--- /dev/null
+++ b/tests/pool/test_sched_async.py
@@ -0,0 +1,159 @@
+import asyncio
+import logging
+from time import time
+from functools import partial
+
+import pytest
+
+from psycopg._compat import create_task
+
+try:
+ from psycopg_pool.sched import AsyncScheduler
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+pytestmark = [pytest.mark.asyncio, pytest.mark.timing]
+
+
+@pytest.mark.slow
+async def test_sched():
+ s = AsyncScheduler()
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, partial(worker, 3))
+ await s.enter(0.3, None)
+ await s.enter(0.2, partial(worker, 2))
+ await s.run()
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.1)
+
+
+@pytest.mark.slow
+async def test_sched_task():
+ s = AsyncScheduler()
+ t = create_task(s.run())
+
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, partial(worker, 3))
+ await s.enter(0.3, None)
+ await s.enter(0.2, partial(worker, 2))
+
+ await asyncio.gather(t)
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.3, 0.2)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.2)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.2)
+
+
+@pytest.mark.slow
+async def test_sched_error(caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ s = AsyncScheduler()
+ t = create_task(s.run())
+
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ async def error():
+ 1 / 0
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, None)
+ await s.enter(0.3, partial(worker, 2))
+ await s.enter(0.2, error)
+
+ await asyncio.gather(t)
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.4, 0.1)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.3, 0.1)
+
+ assert len(caplog.records) == 1
+ assert "ZeroDivisionError" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_empty_queue_timeout():
+ s = AsyncScheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ async def wait_logging():
+ try:
+ rv = await wait_orig()
+ finally:
+ times.append(time() - t0)
+ return rv
+
+ setattr(s._event, "wait", wait_logging)
+ s.EMPTY_QUEUE_TIMEOUT = 0.2
+
+ t = create_task(s.run())
+ await asyncio.sleep(0.5)
+ await s.enter(0.5, None)
+ await asyncio.gather(t)
+ times.append(time() - t0)
+ for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
+ assert got == pytest.approx(want, 0.2), times
+
+
+@pytest.mark.slow
+async def test_first_task_rescheduling():
+ s = AsyncScheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ async def wait_logging():
+ try:
+ rv = await wait_orig()
+ finally:
+ times.append(time() - t0)
+ return rv
+
+ setattr(s._event, "wait", wait_logging)
+ s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+ async def noop():
+ pass
+
+ await s.enter(0.4, noop)
+ t = create_task(s.run())
+ await s.enter(0.6, None) # this task doesn't trigger a reschedule
+ await asyncio.sleep(0.1)
+ await s.enter(0.1, noop) # this triggers a reschedule
+ await asyncio.gather(t)
+ times.append(time() - t0)
+ for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+ assert got == pytest.approx(want, 0.2), times
diff --git a/tests/pq/__init__.py b/tests/pq/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/pq/__init__.py
diff --git a/tests/pq/test_async.py b/tests/pq/test_async.py
new file mode 100644
index 0000000..2c3de98
--- /dev/null
+++ b/tests/pq/test_async.py
@@ -0,0 +1,210 @@
+from select import select
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg.generators import execute
+
+
+def execute_wait(pgconn):
+ return psycopg.waiting.wait(execute(pgconn), pgconn.socket)
+
+
+def test_send_query(pgconn):
+ # This test shows how to process an async query in all its glory
+ pgconn.nonblocking = 1
+
+ # Long query to make sure we have to wait on send
+ pgconn.send_query(
+ b"/* %s */ select 'x' as f from pg_sleep(0.01); select 1 as foo;"
+ % (b"x" * 1_000_000)
+ )
+
+ # send loop
+ waited_on_send = 0
+ while True:
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ waited_on_send += 1
+
+ rl, wl, xl = select([pgconn.socket], [pgconn.socket], [])
+ assert not (rl and wl)
+ if wl:
+ continue # call flush again()
+ if rl:
+ pgconn.consume_input()
+ continue
+
+ # TODO: this check is not reliable, it fails on travis sometimes
+ # assert waited_on_send
+
+ # read loop
+ results = []
+ while True:
+ pgconn.consume_input()
+ if pgconn.is_busy():
+ select([pgconn.socket], [], [])
+ continue
+ res = pgconn.get_result()
+ if res is None:
+ break
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ results.append(res)
+
+ assert len(results) == 2
+ assert results[0].nfields == 1
+ assert results[0].fname(0) == b"f"
+ assert results[0].get_value(0, 0) == b"x"
+ assert results[1].nfields == 1
+ assert results[1].fname(0) == b"foo"
+ assert results[1].get_value(0, 0) == b"1"
+
+
+def test_send_query_compact_test(pgconn):
+ # Like the above test but use psycopg facilities for compactness
+ pgconn.send_query(
+ b"/* %s */ select 'x' as f from pg_sleep(0.01); select 1 as foo;"
+ % (b"x" * 1_000_000)
+ )
+ results = execute_wait(pgconn)
+
+ assert len(results) == 2
+ assert results[0].nfields == 1
+ assert results[0].fname(0) == b"f"
+ assert results[0].get_value(0, 0) == b"x"
+ assert results[1].nfields == 1
+ assert results[1].fname(0) == b"foo"
+ assert results[1].get_value(0, 0) == b"1"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_query(b"select 1")
+
+
+def test_single_row_mode(pgconn):
+ pgconn.send_query(b"select generate_series(1,2)")
+ pgconn.set_single_row_mode()
+
+ results = execute_wait(pgconn)
+ assert len(results) == 3
+
+ res = results[0]
+ assert res.status == pq.ExecStatus.SINGLE_TUPLE
+ assert res.ntuples == 1
+ assert res.get_value(0, 0) == b"1"
+
+ res = results[1]
+ assert res.status == pq.ExecStatus.SINGLE_TUPLE
+ assert res.ntuples == 1
+ assert res.get_value(0, 0) == b"2"
+
+ res = results[2]
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.ntuples == 0
+
+
+def test_send_query_params(pgconn):
+ pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"8"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_query_params(b"select $1", [b"1"])
+
+
+def test_send_prepare(pgconn):
+ pgconn.send_prepare(b"prep", b"select $1::int + $2::int")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_query_prepared(b"prep", [b"3", b"5"])
+ (res,) = execute_wait(pgconn)
+ assert res.get_value(0, 0) == b"8"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_prepare(b"prep", b"select $1::int + $2::int")
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_query_prepared(b"prep", [b"3", b"5"])
+
+
+def test_send_prepare_types(pgconn):
+ pgconn.send_prepare(b"prep", b"select $1 + $2", [23, 23])
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_query_prepared(b"prep", [b"3", b"5"])
+ (res,) = execute_wait(pgconn)
+ assert res.get_value(0, 0) == b"8"
+
+
+def test_send_prepared_binary_in(pgconn):
+ val = b"foo\00bar"
+ pgconn.send_prepare(b"", b"select length($1::bytea), length($2::bytea)")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_query_prepared(b"", [val, val], param_formats=[0, 1])
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"3"
+ assert res.get_value(0, 1) == b"7"
+
+ with pytest.raises(ValueError):
+ pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
+
+
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
+def test_send_prepared_binary_out(pgconn, fmt, out):
+ val = b"foo\00bar"
+ pgconn.send_prepare(b"", b"select $1::bytea")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_query_prepared(b"", [val], param_formats=[1], result_format=fmt)
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == out
+
+
+def test_send_describe_prepared(pgconn):
+ pgconn.send_prepare(b"prep", b"select $1::int8 + $2::int8 as fld")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_describe_prepared(b"prep")
+ (res,) = execute_wait(pgconn)
+ assert res.nfields == 1
+ assert res.ntuples == 0
+ assert res.fname(0) == b"fld"
+ assert res.ftype(0) == 20
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_describe_prepared(b"prep")
+
+
+@pytest.mark.crdb_skip("server-side cursor")
+def test_send_describe_portal(pgconn):
+ res = pgconn.exec_(
+ b"""
+ begin;
+ declare cur cursor for select * from generate_series(1,10) foo;
+ """
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_describe_portal(b"cur")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ assert res.nfields == 1
+ assert res.fname(0) == b"foo"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_describe_portal(b"cur")
diff --git a/tests/pq/test_conninfo.py b/tests/pq/test_conninfo.py
new file mode 100644
index 0000000..64d8b8f
--- /dev/null
+++ b/tests/pq/test_conninfo.py
@@ -0,0 +1,48 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+
+def test_defaults(monkeypatch):
+ monkeypatch.setenv("PGPORT", "15432")
+ defs = pq.Conninfo.get_defaults()
+ assert len(defs) > 20
+ port = [d for d in defs if d.keyword == b"port"][0]
+ assert port.envvar == b"PGPORT"
+ assert port.compiled == b"5432"
+ assert port.val == b"15432"
+ assert port.label == b"Database-Port"
+ assert port.dispchar == b""
+ assert port.dispsize == 6
+
+
+@pytest.mark.libpq(">= 10")
+def test_conninfo_parse():
+ infos = pq.Conninfo.parse(
+ b"postgresql://host1:123,host2:456/somedb"
+ b"?target_session_attrs=any&application_name=myapp"
+ )
+ info = {i.keyword: i.val for i in infos if i.val is not None}
+ assert info[b"host"] == b"host1,host2"
+ assert info[b"port"] == b"123,456"
+ assert info[b"dbname"] == b"somedb"
+ assert info[b"application_name"] == b"myapp"
+
+
+@pytest.mark.libpq("< 10")
+def test_conninfo_parse_96():
+ conninfo = pq.Conninfo.parse(
+ b"postgresql://other@localhost/otherdb"
+ b"?connect_timeout=10&application_name=myapp"
+ )
+ info = {i.keyword: i.val for i in conninfo if i.val is not None}
+ assert info[b"host"] == b"localhost"
+ assert info[b"dbname"] == b"otherdb"
+ assert info[b"application_name"] == b"myapp"
+
+
+def test_conninfo_parse_bad():
+ with pytest.raises(psycopg.OperationalError) as e:
+ pq.Conninfo.parse(b"bad_conninfo=")
+ assert "bad_conninfo" in str(e.value)
diff --git a/tests/pq/test_copy.py b/tests/pq/test_copy.py
new file mode 100644
index 0000000..383d272
--- /dev/null
+++ b/tests/pq/test_copy.py
@@ -0,0 +1,174 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+pytestmark = pytest.mark.crdb_skip("copy")
+
+sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
+
+sample_tabledef = "col1 int primary key, col2 int, data text"
+
+sample_text = b"""\
+10\t20\thello
+40\t\\N\tworld
+"""
+
+sample_binary_value = """
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
+"""
+
+sample_binary_rows = [
+ bytes.fromhex("".join(row.split())) for row in sample_binary_value.split("\n\n")
+]
+
+sample_binary = b"".join(sample_binary_rows)
+
+
+def test_put_data_no_copy(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.put_copy_data(b"wat")
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.put_copy_data(b"wat")
+
+
+def test_put_end_no_copy(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.put_copy_end()
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.put_copy_end()
+
+
+def test_copy_out(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)}
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end()
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_(
+ b"select min(col1), max(col1), count(*), max(length(data)) from copy_in"
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+ assert res.get_value(0, 1) == b"199"
+ assert res.get_value(0, 2) == b"200"
+ assert res.get_value(0, 3) == b"199"
+
+
+def test_copy_out_err(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\thardly a number\tnope
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end()
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert b"hardly a number" in res.error_message
+
+ res = pgconn.exec_(b"select count(*) from copy_in")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+
+
+def test_copy_out_error_end(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)}
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end(b"nuttengoggenio")
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert b"nuttengoggenio" in res.error_message
+
+ res = pgconn.exec_(b"select count(*) from copy_in")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+
+
+def test_get_data_no_copy(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.get_copy_data(0)
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.get_copy_data(0)
+
+
+@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY])
+def test_copy_out_read(pgconn, format):
+ stmt = f"copy ({sample_values}) to stdout (format {format.name})"
+ res = pgconn.exec_(stmt.encode("ascii"))
+ assert res.status == pq.ExecStatus.COPY_OUT
+ assert res.binary_tuples == format
+
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ for row in want:
+ nbytes, data = pgconn.get_copy_data(0)
+ assert nbytes == len(data)
+ assert data == row
+
+ assert pgconn.get_copy_data(0) == (-1, b"")
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+
+def ensure_table(pgconn, tabledef, name="copy_in"):
+ pgconn.exec_(f"drop table if exists {name}".encode("ascii"))
+ pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii"))
diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py
new file mode 100644
index 0000000..ad88d8a
--- /dev/null
+++ b/tests/pq/test_escaping.py
@@ -0,0 +1,188 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+from ..fix_crdb import crdb_scs_off
+
+
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b"''"),
+ (b"hello", b"'hello'"),
+ (b"foo'bar", b"'foo''bar'"),
+ (b"foo\\bar", b" E'foo\\\\bar'"),
+ ],
+)
+def test_escape_literal(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_literal(data)
+ assert out == want
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_escape_literal_1char(pgconn, scs):
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ esc = pq.Escaping(pgconn)
+ special = {b"'": b"''''", b"\\": b" E'\\\\'"}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_literal(data)
+ exp = special.get(data) or b"'%s'" % data
+ assert rv == exp
+
+
+def test_escape_literal_noconn(pgconn):
+ esc = pq.Escaping()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_literal(b"hi")
+
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_literal(b"hi")
+
+
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b'""'),
+ (b"hello", b'"hello"'),
+ (b'foo"bar', b'"foo""bar"'),
+ (b"foo\\bar", b'"foo\\bar"'),
+ ],
+)
+def test_escape_identifier(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_identifier(data)
+ assert out == want
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_escape_identifier_1char(pgconn, scs):
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ esc = pq.Escaping(pgconn)
+ special = {b'"': b'""""', b"\\": b'"\\"'}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_identifier(data)
+ exp = special.get(data) or b'"%s"' % data
+ assert rv == exp
+
+
+def test_escape_identifier_noconn(pgconn):
+ esc = pq.Escaping()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_identifier(b"hi")
+
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_identifier(b"hi")
+
+
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b""),
+ (b"hello", b"hello"),
+ (b"foo'bar", b"foo''bar"),
+ (b"foo\\bar", b"foo\\bar"),
+ ],
+)
+def test_escape_string(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_string(data)
+ assert out == want
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_escape_string_1char(pgconn, scs):
+ esc = pq.Escaping(pgconn)
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ special = {b"'": b"''", b"\\": b"\\" if scs == "on" else b"\\\\"}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_string(data)
+ exp = special.get(data) or b"%s" % data
+ assert rv == exp
+
+
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b""),
+ (b"hello", b"hello"),
+ (b"foo'bar", b"foo''bar"),
+ # This libpq function behaves unpredictably when not passed a conn
+ (b"foo\\bar", (b"foo\\\\bar", b"foo\\bar")),
+ ],
+)
+def test_escape_string_noconn(data, want):
+ esc = pq.Escaping()
+ out = esc.escape_string(data)
+ if isinstance(want, bytes):
+ assert out == want
+ else:
+ assert out in want
+
+
+def test_escape_string_badconn(pgconn):
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_string(b"hi")
+
+
+def test_escape_string_badenc(pgconn):
+ res = pgconn.exec_(b"set client_encoding to 'UTF8'")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ data = "\u20ac".encode()[:-1]
+ esc = pq.Escaping(pgconn)
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_string(data)
+
+
+@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"])
+def test_escape_bytea(pgconn, data):
+ exp = rb"\x" + b"".join(b"%02x" % c for c in data)
+ esc = pq.Escaping(pgconn)
+ rv = esc.escape_bytea(data)
+ assert rv == exp
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_bytea(data)
+
+
+def test_escape_noconn(pgconn):
+ data = bytes(range(256))
+ esc = pq.Escaping()
+ escdata = esc.escape_bytea(data)
+ res = pgconn.exec_params(b"select '%s'::bytea" % escdata, [], result_format=1)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == data
+
+
+def test_escape_1char(pgconn):
+ esc = pq.Escaping(pgconn)
+ for c in range(256):
+ rv = esc.escape_bytea(bytes([c]))
+ exp = rb"\x%02x" % c
+ assert rv == exp
+
+
+@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"])
+def test_unescape_bytea(pgconn, data):
+ enc = rb"\x" + b"".join(b"%02x" % c for c in data)
+ esc = pq.Escaping(pgconn)
+ rv = esc.unescape_bytea(enc)
+ assert rv == data
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.unescape_bytea(data)
diff --git a/tests/pq/test_exec.py b/tests/pq/test_exec.py
new file mode 100644
index 0000000..86c30c0
--- /dev/null
+++ b/tests/pq/test_exec.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+
+import pytest
+
+import psycopg
+from psycopg import pq
+
+
+def test_exec_none(pgconn):
+ with pytest.raises(TypeError):
+ pgconn.exec_(None)
+
+
+def test_exec(pgconn):
+ res = pgconn.exec_(b"select 'hel' || 'lo'")
+ assert res.get_value(0, 0) == b"hello"
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.exec_(b"select 'hello'")
+
+
+def test_exec_params(pgconn):
+ res = pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"8"
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"])
+
+
+def test_exec_params_empty(pgconn):
+ res = pgconn.exec_params(b"select 8::int", [])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"8"
+
+
+def test_exec_params_types(pgconn):
+ res = pgconn.exec_params(b"select $1, $2", [b"8", b"8"], [1700, 23])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"8"
+ assert res.ftype(0) == 1700
+ assert res.get_value(0, 1) == b"8"
+ assert res.ftype(1) == 23
+
+ with pytest.raises(ValueError):
+ pgconn.exec_params(b"select $1, $2", [b"8", b"8"], [1700])
+
+
+def test_exec_params_nulls(pgconn):
+ res = pgconn.exec_params(b"select $1::text, $2::text, $3::text", [b"hi", b"", None])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"hi"
+ assert res.get_value(0, 1) == b""
+ assert res.get_value(0, 2) is None
+
+
+def test_exec_params_binary_in(pgconn):
+ val = b"foo\00bar"
+ res = pgconn.exec_params(
+ b"select length($1::bytea), length($2::bytea)",
+ [val, val],
+ param_formats=[0, 1],
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"3"
+ assert res.get_value(0, 1) == b"7"
+
+ with pytest.raises(ValueError):
+ pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
+
+
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
+def test_exec_params_binary_out(pgconn, fmt, out):
+ val = b"foo\00bar"
+ res = pgconn.exec_params(
+ b"select $1::bytea", [val], param_formats=[1], result_format=fmt
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == out
+
+
+def test_prepare(pgconn):
+ res = pgconn.prepare(b"prep", b"select $1::int + $2::int")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_prepared(b"prep", [b"3", b"5"])
+ assert res.get_value(0, 0) == b"8"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.prepare(b"prep", b"select $1::int + $2::int")
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.exec_prepared(b"prep", [b"3", b"5"])
+
+
+def test_prepare_types(pgconn):
+ res = pgconn.prepare(b"prep", b"select $1 + $2", [23, 23])
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_prepared(b"prep", [b"3", b"5"])
+ assert res.get_value(0, 0) == b"8"
+
+
+def test_exec_prepared_binary_in(pgconn):
+ val = b"foo\00bar"
+ res = pgconn.prepare(b"", b"select length($1::bytea), length($2::bytea)")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_prepared(b"", [val, val], param_formats=[0, 1])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"3"
+ assert res.get_value(0, 1) == b"7"
+
+ with pytest.raises(ValueError):
+ pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
+
+
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
+def test_exec_prepared_binary_out(pgconn, fmt, out):
+ val = b"foo\00bar"
+ res = pgconn.prepare(b"", b"select $1::bytea")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_prepared(b"", [val], param_formats=[1], result_format=fmt)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == out
+
+
+@pytest.mark.crdb_skip("server-side cursor")
+def test_describe_portal(pgconn):
+ res = pgconn.exec_(
+ b"""
+ begin;
+ declare cur cursor for select * from generate_series(1,10) foo;
+ """
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.describe_portal(b"cur")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ assert res.nfields == 1
+ assert res.fname(0) == b"foo"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.describe_portal(b"cur")
diff --git a/tests/pq/test_misc.py b/tests/pq/test_misc.py
new file mode 100644
index 0000000..599758f
--- /dev/null
+++ b/tests/pq/test_misc.py
@@ -0,0 +1,83 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+
+def test_error_message(pgconn):
+ res = pgconn.exec_(b"wat")
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ msg = pq.error_message(pgconn)
+ assert "wat" in msg
+ assert msg == pq.error_message(res)
+ primary = res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
+ assert primary.decode("ascii") in msg
+
+ with pytest.raises(TypeError):
+ pq.error_message(None) # type: ignore[arg-type]
+
+ res.clear()
+ assert pq.error_message(res) == "no details available"
+ pgconn.finish()
+ assert "NULL" in pq.error_message(pgconn)
+
+
+@pytest.mark.crdb_skip("encoding")
+def test_error_message_encoding(pgconn):
+ res = pgconn.exec_(b"set client_encoding to latin9")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+
+ res = pgconn.exec_('select 1 from "foo\u20acbar"'.encode("latin9"))
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+
+ msg = pq.error_message(pgconn)
+ assert "foo\u20acbar" in msg
+
+ msg = pq.error_message(res)
+ assert "foo\ufffdbar" in msg
+
+ msg = pq.error_message(res, encoding="latin9")
+ assert "foo\u20acbar" in msg
+
+ msg = pq.error_message(res, encoding="ascii")
+ assert "foo\ufffdbar" in msg
+
+
+def test_make_empty_result(pgconn):
+ pgconn.exec_(b"wat")
+ res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert b"wat" in res.error_message
+
+ pgconn.finish()
+ res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert res.error_message == b""
+
+
+def test_result_set_attrs(pgconn):
+ res = pgconn.make_empty_result(pq.ExecStatus.COPY_OUT)
+ assert res.status == pq.ExecStatus.COPY_OUT
+
+ attrs = [
+ pq.PGresAttDesc(b"an_int", 0, 0, 0, 23, 0, 0),
+ pq.PGresAttDesc(b"a_num", 0, 0, 0, 1700, 0, 0),
+ pq.PGresAttDesc(b"a_bin_text", 0, 0, 1, 25, 0, 0),
+ ]
+ res.set_attributes(attrs)
+ assert res.nfields == 3
+
+ assert res.fname(0) == b"an_int"
+ assert res.fname(1) == b"a_num"
+ assert res.fname(2) == b"a_bin_text"
+
+ assert res.fformat(0) == 0
+ assert res.fformat(1) == 0
+ assert res.fformat(2) == 1
+
+ assert res.ftype(0) == 23
+ assert res.ftype(1) == 1700
+ assert res.ftype(2) == 25
+
+ with pytest.raises(psycopg.OperationalError):
+ res.set_attributes(attrs)
diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py
new file mode 100644
index 0000000..0566151
--- /dev/null
+++ b/tests/pq/test_pgconn.py
@@ -0,0 +1,585 @@
+import os
+import sys
+import ctypes
+import logging
+import weakref
+from select import select
+
+import pytest
+
+import psycopg
+from psycopg import pq
+import psycopg.generators
+
+from ..utils import gc_collect
+
+
+def test_connectdb(dsn):
+ conn = pq.PGconn.connect(dsn.encode())
+ assert conn.status == pq.ConnStatus.OK, conn.error_message
+
+
+def test_connectdb_error():
+ conn = pq.PGconn.connect(b"dbname=psycopg_test_not_for_real")
+ assert conn.status == pq.ConnStatus.BAD
+
+
+@pytest.mark.parametrize("baddsn", [None, 42])
+def test_connectdb_badtype(baddsn):
+ with pytest.raises(TypeError):
+ pq.PGconn.connect(baddsn)
+
+
+def test_connect_async(dsn):
+ conn = pq.PGconn.connect_start(dsn.encode())
+ conn.nonblocking = 1
+ while True:
+ assert conn.status != pq.ConnStatus.BAD
+ rv = conn.connect_poll()
+ if rv == pq.PollingStatus.OK:
+ break
+ elif rv == pq.PollingStatus.READING:
+ select([conn.socket], [], [])
+ elif rv == pq.PollingStatus.WRITING:
+ select([], [conn.socket], [])
+ else:
+ assert False, rv
+
+ assert conn.status == pq.ConnStatus.OK
+
+ conn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ conn.connect_poll()
+
+
+@pytest.mark.crdb("skip", reason="connects to any db name")
+def test_connect_async_bad(dsn):
+ parsed_dsn = {e.keyword: e.val for e in pq.Conninfo.parse(dsn.encode()) if e.val}
+ parsed_dsn[b"dbname"] = b"psycopg_test_not_for_real"
+ dsn = b" ".join(b"%s='%s'" % item for item in parsed_dsn.items())
+ conn = pq.PGconn.connect_start(dsn)
+ while True:
+ assert conn.status != pq.ConnStatus.BAD, conn.error_message
+ rv = conn.connect_poll()
+ if rv == pq.PollingStatus.FAILED:
+ break
+ elif rv == pq.PollingStatus.READING:
+ select([conn.socket], [], [])
+ elif rv == pq.PollingStatus.WRITING:
+ select([], [conn.socket], [])
+ else:
+ assert False, rv
+
+ assert conn.status == pq.ConnStatus.BAD
+
+
+def test_finish(pgconn):
+ assert pgconn.status == pq.ConnStatus.OK
+ pgconn.finish()
+ assert pgconn.status == pq.ConnStatus.BAD
+ pgconn.finish()
+ assert pgconn.status == pq.ConnStatus.BAD
+
+
+@pytest.mark.slow
+def test_weakref(dsn):
+ conn = pq.PGconn.connect(dsn.encode())
+ w = weakref.ref(conn)
+ conn.finish()
+ del conn
+ gc_collect()
+ assert w() is None
+
+
+@pytest.mark.skipif(
+ sys.platform == "win32"
+ and os.environ.get("CI") == "true"
+ and pq.__impl__ != "python",
+ reason="can't figure out how to make ctypes run, don't care",
+)
+def test_pgconn_ptr(pgconn, libpq):
+ assert isinstance(pgconn.pgconn_ptr, int)
+
+ f = libpq.PQserverVersion
+ f.argtypes = [ctypes.c_void_p]
+ f.restype = ctypes.c_int
+ ver = f(pgconn.pgconn_ptr)
+ assert ver == pgconn.server_version
+
+ pgconn.finish()
+ assert pgconn.pgconn_ptr is None
+
+
+def test_info(dsn, pgconn):
+ info = pgconn.info
+ assert len(info) > 20
+ dbname = [d for d in info if d.keyword == b"dbname"][0]
+ assert dbname.envvar == b"PGDATABASE"
+ assert dbname.label == b"Database-Name"
+ assert dbname.dispchar == b""
+ assert dbname.dispsize == 20
+
+ parsed = pq.Conninfo.parse(dsn.encode())
+ # take the name and the user either from params or from env vars
+ name = [
+ o.val or os.environ.get(o.envvar.decode(), "").encode()
+ for o in parsed
+ if o.keyword == b"dbname" and o.envvar
+ ][0]
+ user = [
+ o.val or os.environ.get(o.envvar.decode(), "").encode()
+ for o in parsed
+ if o.keyword == b"user" and o.envvar
+ ][0]
+ assert dbname.val == (name or user)
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.info
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_reset(pgconn):
+ assert pgconn.status == pq.ConnStatus.OK
+ pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())")
+ assert pgconn.status == pq.ConnStatus.BAD
+ pgconn.reset()
+ assert pgconn.status == pq.ConnStatus.OK
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.reset()
+
+ assert pgconn.status == pq.ConnStatus.BAD
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_reset_async(pgconn):
+ assert pgconn.status == pq.ConnStatus.OK
+ pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())")
+ assert pgconn.status == pq.ConnStatus.BAD
+ pgconn.reset_start()
+ while True:
+ rv = pgconn.reset_poll()
+ if rv == pq.PollingStatus.READING:
+ select([pgconn.socket], [], [])
+ elif rv == pq.PollingStatus.WRITING:
+ select([], [pgconn.socket], [])
+ else:
+ break
+
+ assert rv == pq.PollingStatus.OK
+ assert pgconn.status == pq.ConnStatus.OK
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.reset_start()
+
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.reset_poll()
+
+
+def test_ping(dsn):
+ rv = pq.PGconn.ping(dsn.encode())
+ assert rv == pq.Ping.OK
+
+ rv = pq.PGconn.ping(b"port=9999")
+ assert rv == pq.Ping.NO_RESPONSE
+
+
+def test_db(pgconn):
+ name = [o.val for o in pgconn.info if o.keyword == b"dbname"][0]
+ assert pgconn.db == name
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.db
+
+
+def test_user(pgconn):
+ user = [o.val for o in pgconn.info if o.keyword == b"user"][0]
+ assert pgconn.user == user
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.user
+
+
+def test_password(pgconn):
+ # not in info
+ assert isinstance(pgconn.password, bytes)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.password
+
+
+def test_host(pgconn):
+ # might be not in info
+ assert isinstance(pgconn.host, bytes)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.host
+
+
+@pytest.mark.libpq(">= 12")
+def test_hostaddr(pgconn):
+ # not in info
+ assert isinstance(pgconn.hostaddr, bytes), pgconn.hostaddr
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.hostaddr
+
+
+@pytest.mark.libpq("< 12")
+def test_hostaddr_missing(pgconn):
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.hostaddr
+
+
+def test_port(pgconn):
+ port = [o.val for o in pgconn.info if o.keyword == b"port"][0]
+ assert pgconn.port == port
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.port
+
+
+@pytest.mark.libpq("< 14")
+def test_tty(pgconn):
+ tty = [o.val for o in pgconn.info if o.keyword == b"tty"][0]
+ assert pgconn.tty == tty
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.tty
+
+
+@pytest.mark.libpq(">= 14")
+def test_tty_noop(pgconn):
+ assert not any(o.val for o in pgconn.info if o.keyword == b"tty")
+ assert pgconn.tty == b""
+
+
+def test_transaction_status(pgconn):
+ assert pgconn.transaction_status == pq.TransactionStatus.IDLE
+ pgconn.exec_(b"begin")
+ assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
+ pgconn.send_query(b"select 1")
+ assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE
+ psycopg.waiting.wait(psycopg.generators.execute(pgconn), pgconn.socket)
+ assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
+ pgconn.finish()
+ assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN
+
+
+def test_parameter_status(dsn, monkeypatch):
+ monkeypatch.setenv("PGAPPNAME", "psycopg tests")
+ pgconn = pq.PGconn.connect(dsn.encode())
+ assert pgconn.parameter_status(b"application_name") == b"psycopg tests"
+ assert pgconn.parameter_status(b"wat") is None
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.parameter_status(b"application_name")
+
+
+@pytest.mark.crdb_skip("encoding")
+def test_encoding(pgconn):
+ res = pgconn.exec_(b"set client_encoding to latin1")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ assert pgconn.parameter_status(b"client_encoding") == b"LATIN1"
+
+ res = pgconn.exec_(b"set client_encoding to 'utf-8'")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ assert pgconn.parameter_status(b"client_encoding") == b"UTF8"
+
+ res = pgconn.exec_(b"set client_encoding to wat")
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert pgconn.parameter_status(b"client_encoding") == b"UTF8"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.parameter_status(b"client_encoding")
+
+
+def test_protocol_version(pgconn):
+ assert pgconn.protocol_version == 3
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.protocol_version
+
+
+def test_server_version(pgconn):
+ assert pgconn.server_version >= 90400
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.server_version
+
+
+def test_socket(pgconn):
+ socket = pgconn.socket
+ assert socket > 0
+ pgconn.exec_(f"select pg_terminate_backend({pgconn.backend_pid})".encode())
+ # TODO: on my box it raises OperationalError as it should. Not on Travis,
+ # so let's see if at least an ok value comes out of it.
+ try:
+ assert pgconn.socket == socket
+ except psycopg.OperationalError:
+ pass
+
+
+def test_error_message(pgconn):
+ assert pgconn.error_message == b""
+ res = pgconn.exec_(b"wat")
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ msg = pgconn.error_message
+ assert b"wat" in msg
+ pgconn.finish()
+ assert b"NULL" in pgconn.error_message # TODO: i10n?
+
+
+def test_backend_pid(pgconn):
+ assert isinstance(pgconn.backend_pid, int)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.backend_pid
+
+
+def test_needs_password(pgconn):
+ # assume connection worked so an eventually needed password wasn't missing
+ assert pgconn.needs_password is False
+ pgconn.finish()
+ pgconn.needs_password
+
+
+def test_used_password(pgconn, dsn, monkeypatch):
+ assert isinstance(pgconn.used_password, bool)
+
+ # Assume that if a password was passed then it was needed.
+ # Note that the server may still need a password passed via pgpass
+ # so it may be that has_password is false but still a password was
+ # requested by the server and passed by libpq.
+ info = pq.Conninfo.parse(dsn.encode())
+ has_password = (
+ "PGPASSWORD" in os.environ
+ or [i for i in info if i.keyword == b"password"][0].val is not None
+ )
+ if has_password:
+ assert pgconn.used_password
+
+ pgconn.finish()
+ pgconn.used_password
+
+
+def test_ssl_in_use(pgconn):
+ assert isinstance(pgconn.ssl_in_use, bool)
+
+ # If connecting via socket then ssl is not in use
+ if pgconn.host.startswith(b"/"):
+ assert not pgconn.ssl_in_use
+ else:
+ sslmode = [i.val for i in pgconn.info if i.keyword == b"sslmode"][0]
+ if sslmode not in (b"disable", b"allow", b"prefer"):
+ assert pgconn.ssl_in_use
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.ssl_in_use
+
+
+def test_set_single_row_mode(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.set_single_row_mode()
+
+ pgconn.send_query(b"select 1")
+ pgconn.set_single_row_mode()
+
+
+def test_cancel(pgconn):
+ cancel = pgconn.get_cancel()
+ cancel.cancel()
+ cancel.cancel()
+ pgconn.finish()
+ cancel.cancel()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.get_cancel()
+
+
+def test_cancel_free(pgconn):
+ cancel = pgconn.get_cancel()
+ cancel.free()
+ with pytest.raises(psycopg.OperationalError):
+ cancel.cancel()
+ cancel.free()
+
+
+@pytest.mark.crdb_skip("notify")
+def test_notify(pgconn):
+ assert pgconn.notifies() is None
+
+ pgconn.exec_(b"listen foo")
+ pgconn.exec_(b"listen bar")
+ pgconn.exec_(b"notify foo, '1'")
+ pgconn.exec_(b"notify bar, '2'")
+ pgconn.exec_(b"notify foo, '3'")
+
+ n = pgconn.notifies()
+ assert n.relname == b"foo"
+ assert n.be_pid == pgconn.backend_pid
+ assert n.extra == b"1"
+
+ n = pgconn.notifies()
+ assert n.relname == b"bar"
+ assert n.be_pid == pgconn.backend_pid
+ assert n.extra == b"2"
+
+ n = pgconn.notifies()
+ assert n.relname == b"foo"
+ assert n.be_pid == pgconn.backend_pid
+ assert n.extra == b"3"
+
+ assert pgconn.notifies() is None
+
+
+@pytest.mark.crdb_skip("do")
+def test_notice_nohandler(pgconn):
+ pgconn.exec_(b"set client_min_messages to notice")
+ res = pgconn.exec_(
+ b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK
+
+
+@pytest.mark.crdb_skip("do")
+def test_notice(pgconn):
+ msgs = []
+
+ def callback(res):
+ assert res.status == pq.ExecStatus.NONFATAL_ERROR
+ msgs.append(res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY))
+
+ pgconn.exec_(b"set client_min_messages to notice")
+ pgconn.notice_handler = callback
+ res = pgconn.exec_(
+ b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
+ )
+
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ assert msgs and msgs[0] == b"hello notice"
+
+
+@pytest.mark.crdb_skip("do")
+def test_notice_error(pgconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ def callback(res):
+ raise Exception("hello error")
+
+ pgconn.exec_(b"set client_min_messages to notice")
+ pgconn.notice_handler = callback
+ res = pgconn.exec_(
+ b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
+ )
+
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.ERROR
+ assert "hello error" in rec.message
+
+
+@pytest.mark.libpq("< 14")
+@pytest.mark.skipif("sys.platform != 'linux'")
+def test_trace_pre14(pgconn, tmp_path):
+ tracef = tmp_path / "trace"
+ with tracef.open("w") as f:
+ pgconn.trace(f.fileno())
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.set_trace_flags(0)
+ pgconn.exec_(b"select 1")
+ pgconn.untrace()
+ pgconn.exec_(b"select 2")
+ traces = tracef.read_text()
+ assert "select 1" in traces
+ assert "select 2" not in traces
+
+
+@pytest.mark.libpq(">= 14")
+@pytest.mark.skipif("sys.platform != 'linux'")
+def test_trace(pgconn, tmp_path):
+ tracef = tmp_path / "trace"
+ with tracef.open("w") as f:
+ pgconn.trace(f.fileno())
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ pgconn.exec_(b"select 1::int4 as foo")
+ pgconn.untrace()
+ pgconn.exec_(b"select 2::int4 as foo")
+ traces = [line.split("\t") for line in tracef.read_text().splitlines()]
+ assert traces == [
+ ["F", "26", "Query", ' "select 1::int4 as foo"'],
+ ["B", "28", "RowDescription", ' 1 "foo" NNNN 0 NNNN 4 -1 0'],
+ ["B", "11", "DataRow", " 1 1 '1'"],
+ ["B", "13", "CommandComplete", ' "SELECT 1"'],
+ ["B", "5", "ReadyForQuery", " I"],
+ ]
+
+
+@pytest.mark.skipif("sys.platform == 'linux'")
+def test_trace_nonlinux(pgconn):
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.trace(1)
+
+
+@pytest.mark.libpq(">= 10")
+def test_encrypt_password(pgconn):
+ enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5")
+ assert enc == b"md594839d658c28a357126f105b9cb14cfc"
+
+
+@pytest.mark.libpq(">= 10")
+def test_encrypt_password_scram(pgconn):
+ enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"scram-sha-256")
+ assert enc.startswith(b"SCRAM-SHA-256$")
+
+
+@pytest.mark.libpq(">= 10")
+def test_encrypt_password_badalgo(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ assert pgconn.encrypt_password(b"psycopg2", b"ashesh", b"wat")
+
+
+@pytest.mark.libpq(">= 10")
+@pytest.mark.crdb_skip("password_encryption")
+def test_encrypt_password_query(pgconn):
+ res = pgconn.exec_(b"set password_encryption to 'md5'")
+ assert res.status == pq.ExecStatus.COMMAND_OK, pgconn.error_message.decode()
+ enc = pgconn.encrypt_password(b"psycopg2", b"ashesh")
+ assert enc == b"md594839d658c28a357126f105b9cb14cfc"
+
+ res = pgconn.exec_(b"set password_encryption to 'scram-sha-256'")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ enc = pgconn.encrypt_password(b"psycopg2", b"ashesh")
+ assert enc.startswith(b"SCRAM-SHA-256$")
+
+
+@pytest.mark.libpq(">= 10")
+def test_encrypt_password_closed(pgconn):
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ assert pgconn.encrypt_password(b"psycopg2", b"ashesh")
+
+
+@pytest.mark.libpq("< 10")
+def test_encrypt_password_not_supported(pgconn):
+ # it might even be supported, but not worth the lifetime
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5")
+
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.encrypt_password(b"psycopg2", b"ashesh", b"scram-sha-256")
+
+
+def test_str(pgconn, dsn):
+ assert "[IDLE]" in str(pgconn)
+ pgconn.finish()
+ assert "[BAD]" in str(pgconn)
+
+ pgconn2 = pq.PGconn.connect_start(dsn.encode())
+ assert "[" in str(pgconn2)
+ assert "[IDLE]" not in str(pgconn2)
diff --git a/tests/pq/test_pgresult.py b/tests/pq/test_pgresult.py
new file mode 100644
index 0000000..3ad818d
--- /dev/null
+++ b/tests/pq/test_pgresult.py
@@ -0,0 +1,207 @@
+import ctypes
+import pytest
+
+from psycopg import pq
+
+
+@pytest.mark.parametrize(
+ "command, status",
+ [
+ (b"", "EMPTY_QUERY"),
+ (b"select 1", "TUPLES_OK"),
+ (b"set timezone to utc", "COMMAND_OK"),
+ (b"wat", "FATAL_ERROR"),
+ ],
+)
+def test_status(pgconn, command, status):
+ res = pgconn.exec_(command)
+ assert res.status == getattr(pq.ExecStatus, status)
+ assert status in repr(res)
+
+
+def test_clear(pgconn):
+ res = pgconn.exec_(b"select 1")
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ res.clear()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ res.clear()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+
+
+def test_pgresult_ptr(pgconn, libpq):
+ res = pgconn.exec_(b"select 1")
+ assert isinstance(res.pgresult_ptr, int)
+
+ f = libpq.PQcmdStatus
+ f.argtypes = [ctypes.c_void_p]
+ f.restype = ctypes.c_char_p
+ assert f(res.pgresult_ptr) == b"SELECT 1"
+
+ res.clear()
+ assert res.pgresult_ptr is None
+
+
+def test_error_message(pgconn):
+ res = pgconn.exec_(b"select 1")
+ assert res.error_message == b""
+ res = pgconn.exec_(b"select wat")
+ assert b"wat" in res.error_message
+ res.clear()
+ assert res.error_message == b""
+
+
+def test_error_field(pgconn):
+ res = pgconn.exec_(b"select wat")
+ # https://github.com/cockroachdb/cockroach/issues/81794
+ assert (
+ res.error_field(pq.DiagnosticField.SEVERITY_NONLOCALIZED)
+ or res.error_field(pq.DiagnosticField.SEVERITY)
+ ) == b"ERROR"
+ assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"42703"
+ assert b"wat" in res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
+ res.clear()
+ assert res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) is None
+
+
+@pytest.mark.parametrize("n", range(4))
+def test_ntuples(pgconn, n):
+ res = pgconn.exec_params(b"select generate_series(1, $1)", [str(n).encode("ascii")])
+ assert res.ntuples == n
+ res.clear()
+ assert res.ntuples == 0
+
+
+def test_nfields(pgconn):
+ res = pgconn.exec_(b"select wat")
+ assert res.nfields == 0
+ res = pgconn.exec_(b"select 1, 2, 3")
+ assert res.nfields == 3
+ res.clear()
+ assert res.nfields == 0
+
+
+def test_fname(pgconn):
+ res = pgconn.exec_(b'select 1 as foo, 2 as "BAR"')
+ assert res.fname(0) == b"foo"
+ assert res.fname(1) == b"BAR"
+ assert res.fname(2) is None
+ assert res.fname(-1) is None
+ res.clear()
+ assert res.fname(0) is None
+
+
+@pytest.mark.crdb("skip", reason="ftable")
+def test_ftable_and_col(pgconn):
+ res = pgconn.exec_(
+ b"""
+ drop table if exists t1, t2;
+ create table t1 as select 1 as f1;
+ create table t2 as select 2 as f2, 3 as f3;
+ """
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_(
+ b"select f1, f3, 't1'::regclass::oid, 't2'::regclass::oid from t1, t2"
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+
+ assert res.ftable(0) == int(res.get_value(0, 2).decode("ascii"))
+ assert res.ftable(1) == int(res.get_value(0, 3).decode("ascii"))
+ assert res.ftablecol(0) == 1
+ assert res.ftablecol(1) == 2
+ res.clear()
+ assert res.ftable(0) == 0
+ assert res.ftablecol(0) == 0
+
+
+@pytest.mark.parametrize("fmt", (0, 1))
+def test_fformat(pgconn, fmt):
+ res = pgconn.exec_params(b"select 1", [], result_format=fmt)
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.fformat(0) == fmt
+ assert res.binary_tuples == fmt
+ res.clear()
+ assert res.fformat(0) == 0
+ assert res.binary_tuples == 0
+
+
+def test_ftype(pgconn):
+ res = pgconn.exec_(b"select 1::int4, 1::numeric, 1::text")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.ftype(0) == 23
+ assert res.ftype(1) == 1700
+ assert res.ftype(2) == 25
+ res.clear()
+ assert res.ftype(0) == 0
+
+
+def test_fmod(pgconn):
+ res = pgconn.exec_(b"select 1::int, 1::numeric(10), 1::numeric(10,2)")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.fmod(0) == -1
+ assert res.fmod(1) == 0xA0004
+ assert res.fmod(2) == 0xA0006
+ res.clear()
+ assert res.fmod(0) == 0
+
+
+def test_fsize(pgconn):
+ res = pgconn.exec_(b"select 1::int4, 1::bigint, 1::text")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.fsize(0) == 4
+ assert res.fsize(1) == 8
+ assert res.fsize(2) == -1
+ res.clear()
+ assert res.fsize(0) == 0
+
+
+def test_get_value(pgconn):
+ res = pgconn.exec_(b"select 'a', '', NULL")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"a"
+ assert res.get_value(0, 1) == b""
+ assert res.get_value(0, 2) is None
+ res.clear()
+ assert res.get_value(0, 0) is None
+
+
+def test_nparams_types(pgconn):
+ res = pgconn.prepare(b"", b"select $1::int4, $2::text")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.describe_prepared(b"")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ assert res.nparams == 2
+ assert res.param_type(0) == 23
+ assert res.param_type(1) == 25
+
+ res.clear()
+ assert res.nparams == 0
+ assert res.param_type(0) == 0
+
+
+def test_command_status(pgconn):
+ res = pgconn.exec_(b"select 1")
+ assert res.command_status == b"SELECT 1"
+ res = pgconn.exec_(b"set timezone to utc")
+ assert res.command_status == b"SET"
+ res.clear()
+ assert res.command_status is None
+
+
+def test_command_tuples(pgconn):
+ res = pgconn.exec_(b"set timezone to utf8")
+ assert res.command_tuples is None
+ res = pgconn.exec_(b"select * from generate_series(1, 10)")
+ assert res.command_tuples == 10
+ res.clear()
+ assert res.command_tuples is None
+
+
+def test_oid_value(pgconn):
+ res = pgconn.exec_(b"select 1")
+ assert res.oid_value == 0
+ res.clear()
+ assert res.oid_value == 0
diff --git a/tests/pq/test_pipeline.py b/tests/pq/test_pipeline.py
new file mode 100644
index 0000000..00cd54a
--- /dev/null
+++ b/tests/pq/test_pipeline.py
@@ -0,0 +1,161 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+
+@pytest.mark.libpq("< 14")
+def test_old_libpq(pgconn):
+ assert pgconn.pipeline_status == 0
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.enter_pipeline_mode()
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.exit_pipeline_mode()
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.pipeline_sync()
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.send_flush_request()
+
+
+@pytest.mark.libpq(">= 14")
+def test_work_in_progress(pgconn):
+ assert not pgconn.nonblocking
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+ pgconn.enter_pipeline_mode()
+ pgconn.send_query_params(b"select $1", [b"1"])
+ with pytest.raises(psycopg.OperationalError, match="cannot exit pipeline mode"):
+ pgconn.exit_pipeline_mode()
+
+
+@pytest.mark.libpq(">= 14")
+def test_multi_pipelines(pgconn):
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+ pgconn.enter_pipeline_mode()
+ pgconn.send_query_params(b"select $1", [b"1"], param_types=[25])
+ pgconn.pipeline_sync()
+ pgconn.send_query_params(b"select $1", [b"2"], param_types=[25])
+ pgconn.pipeline_sync()
+
+ # result from first query
+ result1 = pgconn.get_result()
+ assert result1 is not None
+ assert result1.status == pq.ExecStatus.TUPLES_OK
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # first sync result
+ sync_result = pgconn.get_result()
+ assert sync_result is not None
+ assert sync_result.status == pq.ExecStatus.PIPELINE_SYNC
+
+ # result from second query
+ result2 = pgconn.get_result()
+ assert result2 is not None
+ assert result2.status == pq.ExecStatus.TUPLES_OK
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # second sync result
+ sync_result = pgconn.get_result()
+ assert sync_result is not None
+ assert sync_result.status == pq.ExecStatus.PIPELINE_SYNC
+
+ # pipeline still ON
+ assert pgconn.pipeline_status == pq.PipelineStatus.ON
+
+ pgconn.exit_pipeline_mode()
+
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+
+ assert result1.get_value(0, 0) == b"1"
+ assert result2.get_value(0, 0) == b"2"
+
+
+@pytest.mark.libpq(">= 14")
+def test_flush_request(pgconn):
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+ pgconn.enter_pipeline_mode()
+ pgconn.send_query_params(b"select $1", [b"1"], param_types=[25])
+ pgconn.send_flush_request()
+ r = pgconn.get_result()
+ assert r.status == pq.ExecStatus.TUPLES_OK
+ assert r.get_value(0, 0) == b"1"
+ pgconn.exit_pipeline_mode()
+
+
+@pytest.fixture
+def table(pgconn):
+ tablename = "pipeline"
+ pgconn.exec_(f"create table {tablename} (s text)".encode("ascii"))
+ yield tablename
+ pgconn.exec_(f"drop table if exists {tablename}".encode("ascii"))
+
+
+@pytest.mark.libpq(">= 14")
+def test_pipeline_abort(pgconn, table):
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+ pgconn.enter_pipeline_mode()
+ pgconn.send_query_params(b"insert into pipeline values ($1)", [b"1"])
+ pgconn.send_query_params(b"select no_such_function($1)", [b"1"])
+ pgconn.send_query_params(b"insert into pipeline values ($1)", [b"2"])
+ pgconn.pipeline_sync()
+ pgconn.send_query_params(b"insert into pipeline values ($1)", [b"3"])
+ pgconn.pipeline_sync()
+
+ # result from first INSERT
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.COMMAND_OK
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # error result from second query (SELECT)
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.FATAL_ERROR
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # pipeline should be aborted, due to previous error
+ assert pgconn.pipeline_status == pq.PipelineStatus.ABORTED
+
+ # result from second INSERT, aborted due to previous error
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.PIPELINE_ABORTED
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # pipeline is still aborted
+ assert pgconn.pipeline_status == pq.PipelineStatus.ABORTED
+
+ # sync result
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.PIPELINE_SYNC
+
+ # aborted flag is clear, pipeline is on again
+ assert pgconn.pipeline_status == pq.PipelineStatus.ON
+
+ # result from the third INSERT
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.COMMAND_OK
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # second sync result
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.PIPELINE_SYNC
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ pgconn.exit_pipeline_mode()
diff --git a/tests/pq/test_pq.py b/tests/pq/test_pq.py
new file mode 100644
index 0000000..076c3b6
--- /dev/null
+++ b/tests/pq/test_pq.py
@@ -0,0 +1,57 @@
+import os
+
+import pytest
+
+import psycopg
+from psycopg import pq
+
+from ..utils import check_libpq_version
+
+
+def test_version():
+ rv = pq.version()
+ assert rv > 90500
+ assert rv < 200000 # you are good for a while
+
+
+def test_build_version():
+ assert pq.__build_version__ and pq.__build_version__ >= 70400
+
+
+@pytest.mark.skipif("not os.environ.get('PSYCOPG_TEST_WANT_LIBPQ_BUILD')")
+def test_want_built_version():
+ want = os.environ["PSYCOPG_TEST_WANT_LIBPQ_BUILD"]
+ got = pq.__build_version__
+ assert not check_libpq_version(got, want)
+
+
+@pytest.mark.skipif("not os.environ.get('PSYCOPG_TEST_WANT_LIBPQ_IMPORT')")
+def test_want_import_version():
+ want = os.environ["PSYCOPG_TEST_WANT_LIBPQ_IMPORT"]
+ got = pq.version()
+ assert not check_libpq_version(got, want)
+
+
+# Note: These tests are here because test_pipeline.py tests are all skipped
+# when pipeline mode is not supported.
+
+
+@pytest.mark.libpq(">= 14")
+def test_pipeline_supported(conn):
+ assert psycopg.Pipeline.is_supported()
+ assert psycopg.AsyncPipeline.is_supported()
+
+ with conn.pipeline():
+ pass
+
+
+@pytest.mark.libpq("< 14")
+def test_pipeline_not_supported(conn):
+ assert not psycopg.Pipeline.is_supported()
+ assert not psycopg.AsyncPipeline.is_supported()
+
+ with pytest.raises(psycopg.NotSupportedError) as exc:
+ with conn.pipeline():
+ pass
+
+ assert "too old" in str(exc.value)
diff --git a/tests/scripts/bench-411.py b/tests/scripts/bench-411.py
new file mode 100644
index 0000000..82ea451
--- /dev/null
+++ b/tests/scripts/bench-411.py
@@ -0,0 +1,300 @@
+import os
+import sys
+import time
+import random
+import asyncio
+import logging
+from enum import Enum
+from typing import Any, Dict, List, Generator
+from argparse import ArgumentParser, Namespace
+from contextlib import contextmanager
+
+logger = logging.getLogger()
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(levelname)s %(message)s",
+)
+
+
+class Driver(str, Enum):
+ psycopg2 = "psycopg2"
+ psycopg = "psycopg"
+ psycopg_async = "psycopg_async"
+ asyncpg = "asyncpg"
+
+
+ids: List[int] = []
+data: List[Dict[str, Any]] = []
+
+
+def main() -> None:
+
+ args = parse_cmdline()
+
+ ids[:] = range(args.ntests)
+ data[:] = [
+ dict(
+ id=i,
+ name="c%d" % i,
+ description="c%d" % i,
+ q=i * 10,
+ p=i * 20,
+ x=i * 30,
+ y=i * 40,
+ )
+ for i in ids
+ ]
+
+ # Must be done just on end
+ drop_at_the_end = args.drop
+ args.drop = False
+
+ for i, name in enumerate(args.drivers):
+ if i == len(args.drivers) - 1:
+ args.drop = drop_at_the_end
+
+ if name == Driver.psycopg2:
+ import psycopg2 # type: ignore
+
+ run_psycopg2(psycopg2, args)
+
+ elif name == Driver.psycopg:
+ import psycopg
+
+ run_psycopg(psycopg, args)
+
+ elif name == Driver.psycopg_async:
+ import psycopg
+
+ if sys.platform == "win32":
+ if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
+ asyncio.set_event_loop_policy(
+ asyncio.WindowsSelectorEventLoopPolicy()
+ )
+
+ asyncio.run(run_psycopg_async(psycopg, args))
+
+ elif name == Driver.asyncpg:
+ import asyncpg # type: ignore
+
+ asyncio.run(run_asyncpg(asyncpg, args))
+
+ else:
+ raise AssertionError(f"unknown driver: {name!r}")
+
+ # Must be done just on start
+ args.create = False
+
+
+table = """
+CREATE TABLE customer (
+ id SERIAL NOT NULL,
+ name VARCHAR(255),
+ description VARCHAR(255),
+ q INTEGER,
+ p INTEGER,
+ x INTEGER,
+ y INTEGER,
+ z INTEGER,
+ PRIMARY KEY (id)
+)
+"""
+drop = "DROP TABLE IF EXISTS customer"
+
+insert = """
+INSERT INTO customer (id, name, description, q, p, x, y) VALUES
+(%(id)s, %(name)s, %(description)s, %(q)s, %(p)s, %(x)s, %(y)s)
+"""
+
+select = """
+SELECT customer.id, customer.name, customer.description, customer.q,
+ customer.p, customer.x, customer.y, customer.z
+FROM customer
+WHERE customer.id = %(id)s
+"""
+
+
+@contextmanager
+def time_log(message: str) -> Generator[None, None, None]:
+ start = time.monotonic()
+ yield
+ end = time.monotonic()
+ logger.info(f"Run {message} in {end-start} s")
+
+
+def run_psycopg2(psycopg2: Any, args: Namespace) -> None:
+ logger.info("Running psycopg2")
+
+ if args.create:
+ logger.info(f"inserting {args.ntests} test records")
+ with psycopg2.connect(args.dsn) as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(drop)
+ cursor.execute(table)
+ cursor.executemany(insert, data)
+ conn.commit()
+
+ logger.info(f"running {args.ntests} queries")
+ to_query = random.choices(ids, k=args.ntests)
+ with psycopg2.connect(args.dsn) as conn:
+ with time_log("psycopg2"):
+ for id_ in to_query:
+ with conn.cursor() as cursor:
+ cursor.execute(select, {"id": id_})
+ cursor.fetchall()
+ # conn.rollback()
+
+ if args.drop:
+ logger.info("dropping test records")
+ with psycopg2.connect(args.dsn) as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(drop)
+ conn.commit()
+
+
+def run_psycopg(psycopg: Any, args: Namespace) -> None:
+ logger.info("Running psycopg sync")
+
+ if args.create:
+ logger.info(f"inserting {args.ntests} test records")
+ with psycopg.connect(args.dsn) as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(drop)
+ cursor.execute(table)
+ cursor.executemany(insert, data)
+ conn.commit()
+
+ logger.info(f"running {args.ntests} queries")
+ to_query = random.choices(ids, k=args.ntests)
+ with psycopg.connect(args.dsn) as conn:
+ with time_log("psycopg"):
+ for id_ in to_query:
+ with conn.cursor() as cursor:
+ cursor.execute(select, {"id": id_})
+ cursor.fetchall()
+ # conn.rollback()
+
+ if args.drop:
+ logger.info("dropping test records")
+ with psycopg.connect(args.dsn) as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(drop)
+ conn.commit()
+
+
+async def run_psycopg_async(psycopg: Any, args: Namespace) -> None:
+ logger.info("Running psycopg async")
+
+ conn: Any
+
+ if args.create:
+ logger.info(f"inserting {args.ntests} test records")
+ async with await psycopg.AsyncConnection.connect(args.dsn) as conn:
+ async with conn.cursor() as cursor:
+ await cursor.execute(drop)
+ await cursor.execute(table)
+ await cursor.executemany(insert, data)
+ await conn.commit()
+
+ logger.info(f"running {args.ntests} queries")
+ to_query = random.choices(ids, k=args.ntests)
+ async with await psycopg.AsyncConnection.connect(args.dsn) as conn:
+ with time_log("psycopg_async"):
+ for id_ in to_query:
+ cursor = await conn.execute(select, {"id": id_})
+ await cursor.fetchall()
+ await cursor.close()
+ # await conn.rollback()
+
+ if args.drop:
+ logger.info("dropping test records")
+ async with await psycopg.AsyncConnection.connect(args.dsn) as conn:
+ async with conn.cursor() as cursor:
+ await cursor.execute(drop)
+ await conn.commit()
+
+
+async def run_asyncpg(asyncpg: Any, args: Namespace) -> None:
+ logger.info("Running asyncpg")
+
+ places = dict(id="$1", name="$2", description="$3", q="$4", p="$5", x="$6", y="$7")
+ a_insert = insert % places
+ a_select = select % {"id": "$1"}
+
+ conn: Any
+
+ if args.create:
+ logger.info(f"inserting {args.ntests} test records")
+ conn = await asyncpg.connect(args.dsn)
+ async with conn.transaction():
+ await conn.execute(drop)
+ await conn.execute(table)
+ await conn.executemany(a_insert, [tuple(d.values()) for d in data])
+ await conn.close()
+
+ logger.info(f"running {args.ntests} queries")
+ to_query = random.choices(ids, k=args.ntests)
+ conn = await asyncpg.connect(args.dsn)
+ with time_log("asyncpg"):
+ for id_ in to_query:
+ tr = conn.transaction()
+ await tr.start()
+ await conn.fetch(a_select, id_)
+ # await tr.rollback()
+ await conn.close()
+
+ if args.drop:
+ logger.info("dropping test records")
+ conn = await asyncpg.connect(args.dsn)
+ async with conn.transaction():
+ await conn.execute(drop)
+ await conn.close()
+
+
+def parse_cmdline() -> Namespace:
+ parser = ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "drivers",
+ nargs="+",
+ metavar="DRIVER",
+ type=Driver,
+ help=f"the drivers to test [choices: {', '.join(d.value for d in Driver)}]",
+ )
+
+ parser.add_argument(
+ "--ntests",
+ type=int,
+ default=10_000,
+ help="number of tests to perform [default: %(default)s]",
+ )
+
+ parser.add_argument(
+ "--dsn",
+ default=os.environ.get("PSYCOPG_TEST_DSN", ""),
+ help="database connection string"
+ " [default: %(default)r (from PSYCOPG_TEST_DSN env var)]",
+ )
+
+ parser.add_argument(
+ "--no-create",
+ dest="create",
+ action="store_false",
+ default="True",
+ help="skip data creation before tests (it must exist already)",
+ )
+
+ parser.add_argument(
+ "--no-drop",
+ dest="drop",
+ action="store_false",
+ default="True",
+ help="skip data drop after tests",
+ )
+
+ opt = parser.parse_args()
+
+ return opt
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/scripts/dectest.py b/tests/scripts/dectest.py
new file mode 100644
index 0000000..a49f116
--- /dev/null
+++ b/tests/scripts/dectest.py
@@ -0,0 +1,51 @@
+"""
+A quick and rough performance comparison of text vs. binary Decimal adaptation
+"""
+from random import randrange
+from decimal import Decimal
+import psycopg
+from psycopg import sql
+
+ncols = 10
+nrows = 500000
+format = psycopg.pq.Format.BINARY
+test = "copy"
+
+
+def main() -> None:
+ cnn = psycopg.connect()
+
+ cnn.execute(
+ sql.SQL("create table testdec ({})").format(
+ sql.SQL(", ").join(
+ [
+ sql.SQL("{} numeric(10,2)").format(sql.Identifier(f"t{i}"))
+ for i in range(ncols)
+ ]
+ )
+ )
+ )
+ cur = cnn.cursor()
+
+ if test == "copy":
+ with cur.copy(f"copy testdec from stdin (format {format.name})") as copy:
+ for j in range(nrows):
+ copy.write_row(
+ [Decimal(randrange(10000000000)) / 100 for i in range(ncols)]
+ )
+
+ elif test == "insert":
+ ph = ["%t", "%b"][format]
+ cur.executemany(
+ "insert into testdec values (%s)" % ", ".join([ph] * ncols),
+ (
+ [Decimal(randrange(10000000000)) / 100 for i in range(ncols)]
+ for j in range(nrows)
+ ),
+ )
+ else:
+ raise Exception(f"bad test: {test}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/scripts/pipeline-demo.py b/tests/scripts/pipeline-demo.py
new file mode 100644
index 0000000..ec95229
--- /dev/null
+++ b/tests/scripts/pipeline-demo.py
@@ -0,0 +1,340 @@
+"""Pipeline mode demo
+
+This reproduces libpq_pipeline::pipelined_insert PostgreSQL test at
+src/test/modules/libpq_pipeline/libpq_pipeline.c::test_pipelined_insert().
+
+We do not fetch results explicitly (using cursor.fetch*()), this is
+handled by execute() calls when pgconn socket is read-ready, which
+happens when the output buffer is full.
+"""
+import argparse
+import asyncio
+import logging
+from contextlib import contextmanager
+from functools import partial
+from typing import Any, Iterator, Optional, Sequence, Tuple
+
+from psycopg import AsyncConnection, Connection
+from psycopg import pq, waiting
+from psycopg import errors as e
+from psycopg.abc import PipelineCommand
+from psycopg.generators import pipeline_communicate
+from psycopg.pq import Format, DiagnosticField
+from psycopg._compat import Deque
+
+psycopg_logger = logging.getLogger("psycopg")
+pipeline_logger = logging.getLogger("pipeline")
+args: argparse.Namespace
+
+
+class LoggingPGconn:
+ """Wrapper for PGconn that logs fetched results."""
+
+ def __init__(self, pgconn: pq.abc.PGconn, logger: logging.Logger):
+ self._pgconn = pgconn
+ self._logger = logger
+
+ def log_notice(result: pq.abc.PGresult) -> None:
+ def get_field(field: DiagnosticField) -> Optional[str]:
+ value = result.error_field(field)
+ return value.decode("utf-8", "replace") if value else None
+
+ logger.info(
+ "notice %s %s",
+ get_field(DiagnosticField.SEVERITY),
+ get_field(DiagnosticField.MESSAGE_PRIMARY),
+ )
+
+ pgconn.notice_handler = log_notice
+
+ if args.trace:
+ self._trace_file = open(args.trace, "w")
+ pgconn.trace(self._trace_file.fileno())
+
+ def __del__(self) -> None:
+ if hasattr(self, "_trace_file"):
+ self._pgconn.untrace()
+ self._trace_file.close()
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._pgconn, name)
+
+ def send_query(self, command: bytes) -> None:
+ self._logger.warning("PQsendQuery broken in libpq 14.5")
+ self._pgconn.send_query(command)
+ self._logger.info("sent %s", command.decode())
+
+ def send_query_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ self._pgconn.send_query_params(
+ command, param_values, param_types, param_formats, result_format
+ )
+ self._logger.info("sent %s", command.decode())
+
+ def send_query_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ self._pgconn.send_query_prepared(
+ name, param_values, param_formats, result_format
+ )
+ self._logger.info("sent prepared '%s' with %s", name.decode(), param_values)
+
+ def send_prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ self._pgconn.send_prepare(name, command, param_types)
+ self._logger.info("prepare %s as '%s'", command.decode(), name.decode())
+
+ def get_result(self) -> Optional[pq.abc.PGresult]:
+ r = self._pgconn.get_result()
+ if r is not None:
+ self._logger.info("got %s result", pq.ExecStatus(r.status).name)
+ return r
+
+
+@contextmanager
+def prepare_pipeline_demo_pq(
+ pgconn: LoggingPGconn, rows_to_send: int, logger: logging.Logger
+) -> Iterator[Tuple[Deque[PipelineCommand], Deque[str]]]:
+ """Set up pipeline demo with initial queries and yield commands and
+ results queue for pipeline_communicate().
+ """
+ logger.debug("enter pipeline")
+ pgconn.enter_pipeline_mode()
+
+ setup_queries = [
+ ("begin", "BEGIN TRANSACTION"),
+ ("drop table", "DROP TABLE IF EXISTS pq_pipeline_demo"),
+ (
+ "create table",
+ (
+ "CREATE UNLOGGED TABLE pq_pipeline_demo("
+ " id serial primary key,"
+ " itemno integer,"
+ " int8filler int8"
+ ")"
+ ),
+ ),
+ (
+ "prepare",
+ ("INSERT INTO pq_pipeline_demo(itemno, int8filler)" " VALUES ($1, $2)"),
+ ),
+ ]
+
+ commands = Deque[PipelineCommand]()
+ results_queue = Deque[str]()
+
+ for qname, query in setup_queries:
+ if qname == "prepare":
+ pgconn.send_prepare(qname.encode(), query.encode())
+ else:
+ pgconn.send_query_params(query.encode(), None)
+ results_queue.append(qname)
+
+ committed = False
+ synced = False
+
+ while True:
+ if rows_to_send:
+ params = [f"{rows_to_send}".encode(), f"{1 << 62}".encode()]
+ commands.append(partial(pgconn.send_query_prepared, b"prepare", params))
+ results_queue.append(f"row {rows_to_send}")
+ rows_to_send -= 1
+
+ elif not committed:
+ committed = True
+ commands.append(partial(pgconn.send_query_params, b"COMMIT", None))
+ results_queue.append("commit")
+
+ elif not synced:
+
+ def sync() -> None:
+ pgconn.pipeline_sync()
+ logger.info("pipeline sync sent")
+
+ synced = True
+ commands.append(sync)
+ results_queue.append("sync")
+
+ else:
+ break
+
+ try:
+ yield commands, results_queue
+ finally:
+ logger.debug("exit pipeline")
+ pgconn.exit_pipeline_mode()
+
+
+def pipeline_demo_pq(rows_to_send: int, logger: logging.Logger) -> None:
+ pgconn = LoggingPGconn(Connection.connect().pgconn, logger)
+ with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as (
+ commands,
+ results_queue,
+ ):
+ while results_queue:
+ fetched = waiting.wait(
+ pipeline_communicate(
+ pgconn, # type: ignore[arg-type]
+ commands,
+ ),
+ pgconn.socket,
+ )
+ assert not commands, commands
+ for results in fetched:
+ results_queue.popleft()
+ for r in results:
+ if r.status in (
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ ):
+ raise e.error_from_result(r)
+
+
+async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> None:
+ pgconn = LoggingPGconn((await AsyncConnection.connect()).pgconn, logger)
+
+ with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as (
+ commands,
+ results_queue,
+ ):
+ while results_queue:
+ fetched = await waiting.wait_async(
+ pipeline_communicate(
+ pgconn, # type: ignore[arg-type]
+ commands,
+ ),
+ pgconn.socket,
+ )
+ assert not commands, commands
+ for results in fetched:
+ results_queue.popleft()
+ for r in results:
+ if r.status in (
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ ):
+ raise e.error_from_result(r)
+
+
+def pipeline_demo(rows_to_send: int, many: bool, logger: logging.Logger) -> None:
+ """Pipeline demo using sync API."""
+ conn = Connection.connect()
+ conn.autocommit = True
+ conn.pgconn = LoggingPGconn(conn.pgconn, logger) # type: ignore[assignment]
+ with conn.pipeline():
+ with conn.transaction():
+ conn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+ conn.execute(
+ "CREATE UNLOGGED TABLE pq_pipeline_demo("
+ " id serial primary key,"
+ " itemno integer,"
+ " int8filler int8"
+ ")"
+ )
+ query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)"
+ params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1))
+ if many:
+ cur = conn.cursor()
+ cur.executemany(query, list(params))
+ else:
+ for p in params:
+ conn.execute(query, p)
+
+
+async def pipeline_demo_async(
+ rows_to_send: int, many: bool, logger: logging.Logger
+) -> None:
+ """Pipeline demo using async API."""
+ aconn = await AsyncConnection.connect()
+ await aconn.set_autocommit(True)
+ aconn.pgconn = LoggingPGconn(aconn.pgconn, logger) # type: ignore[assignment]
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ await aconn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+ await aconn.execute(
+ "CREATE UNLOGGED TABLE pq_pipeline_demo("
+ " id serial primary key,"
+ " itemno integer,"
+ " int8filler int8"
+ ")"
+ )
+ query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)"
+ params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1))
+ if many:
+ cur = aconn.cursor()
+ await cur.executemany(query, list(params))
+ else:
+ for p in params:
+ await aconn.execute(query, p)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-n",
+ dest="nrows",
+ metavar="ROWS",
+ default=10_000,
+ type=int,
+ help="number of rows to insert",
+ )
+ parser.add_argument(
+ "--pq", action="store_true", help="use low-level psycopg.pq API"
+ )
+ parser.add_argument(
+ "--async", dest="async_", action="store_true", help="use async API"
+ )
+ parser.add_argument(
+ "--many",
+ action="store_true",
+ help="use executemany() (not applicable for --pq)",
+ )
+ parser.add_argument("--trace", help="write trace info into TRACE file")
+ parser.add_argument("-l", "--log", help="log file (stderr by default)")
+
+ global args
+ args = parser.parse_args()
+
+ psycopg_logger.setLevel(logging.DEBUG)
+ pipeline_logger.setLevel(logging.DEBUG)
+ if args.log:
+ psycopg_logger.addHandler(logging.FileHandler(args.log))
+ pipeline_logger.addHandler(logging.FileHandler(args.log))
+ else:
+ psycopg_logger.addHandler(logging.StreamHandler())
+ pipeline_logger.addHandler(logging.StreamHandler())
+
+ if args.pq:
+ if args.many:
+ parser.error("--many cannot be used with --pq")
+ if args.async_:
+ asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger))
+ else:
+ pipeline_demo_pq(args.nrows, pipeline_logger)
+ else:
+ if pq.__impl__ != "python":
+ parser.error(
+ "only supported for Python implementation (set PSYCOPG_IMPL=python)"
+ )
+ if args.async_:
+ asyncio.run(pipeline_demo_async(args.nrows, args.many, pipeline_logger))
+ else:
+ pipeline_demo(args.nrows, args.many, pipeline_logger)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/scripts/spiketest.py b/tests/scripts/spiketest.py
new file mode 100644
index 0000000..2c9cc16
--- /dev/null
+++ b/tests/scripts/spiketest.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python
+"""
+Run a connection pool spike test.
+
+The test is inspired to the `spike analysis`__ illustrated by HikariCP
+
+.. __: https://github.com/brettwooldridge/HikariCP/blob/dev/documents/
+ Welcome-To-The-Jungle.md
+
+"""
+# mypy: allow-untyped-defs
+# mypy: allow-untyped-calls
+
+import time
+import threading
+
+import psycopg
+import psycopg_pool
+from psycopg.rows import Row
+
+import logging
+
+
+def main() -> None:
+ opt = parse_cmdline()
+ if opt.loglevel:
+ loglevel = getattr(logging, opt.loglevel.upper())
+ logging.basicConfig(
+ level=loglevel, format="%(asctime)s %(levelname)s %(message)s"
+ )
+
+ logging.getLogger("psycopg2.pool").setLevel(loglevel)
+
+ with psycopg_pool.ConnectionPool(
+ opt.dsn,
+ min_size=opt.min_size,
+ max_size=opt.max_size,
+ connection_class=DelayedConnection,
+ kwargs={"conn_delay": 0.150},
+ ) as pool:
+ pool.wait()
+ measurer = Measurer(pool)
+
+ # Create and start all the thread: they will get stuck on the event
+ ev = threading.Event()
+ threads = [
+ threading.Thread(target=worker, args=(pool, 0.002, ev), daemon=True)
+ for i in range(opt.num_clients)
+ ]
+ for t in threads:
+ t.start()
+ time.sleep(0.2)
+
+ # Release the threads!
+ measurer.start(0.00025)
+ t0 = time.time()
+ ev.set()
+
+ # Wait for the threads to finish
+ for t in threads:
+ t.join()
+ t1 = time.time()
+ measurer.stop()
+
+ print(f"time: {(t1 - t0) * 1000} msec")
+ print("active,idle,total,waiting")
+ recs = [
+ f'{m["pool_size"] - m["pool_available"]}'
+ f',{m["pool_available"]}'
+ f',{m["pool_size"]}'
+ f',{m["requests_waiting"]}'
+ for m in measurer.measures
+ ]
+ print("\n".join(recs))
+
+
+def worker(p, t, ev):
+ ev.wait()
+ with p.connection():
+ time.sleep(t)
+
+
+class Measurer:
+ def __init__(self, pool):
+ self.pool = pool
+ self.worker = None
+ self.stopped = False
+ self.measures = []
+
+ def start(self, interval):
+ self.worker = threading.Thread(target=self._run, args=(interval,), daemon=True)
+ self.worker.start()
+
+ def stop(self):
+ self.stopped = True
+ if self.worker:
+ self.worker.join()
+ self.worker = None
+
+ def _run(self, interval):
+ while not self.stopped:
+ self.measures.append(self.pool.get_stats())
+ time.sleep(interval)
+
+
+class DelayedConnection(psycopg.Connection[Row]):
+ """A connection adding a delay to the connection time."""
+
+ @classmethod
+ def connect(cls, conninfo, conn_delay=0, **kwargs):
+ t0 = time.time()
+ conn = super().connect(conninfo, **kwargs)
+ t1 = time.time()
+ wait = max(0.0, conn_delay - (t1 - t0))
+ if wait:
+ time.sleep(wait)
+ return conn
+
+
+def parse_cmdline():
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser(description=__doc__)
+ parser.add_argument("--dsn", default="", help="connection string to the database")
+ parser.add_argument(
+ "--min_size",
+ default=5,
+ type=int,
+ help="minimum number of connections in the pool",
+ )
+ parser.add_argument(
+ "--max_size",
+ default=20,
+ type=int,
+ help="maximum number of connections in the pool",
+ )
+ parser.add_argument(
+ "--num-clients",
+ default=50,
+ type=int,
+ help="number of threads making a request",
+ )
+ parser.add_argument(
+ "--loglevel",
+ default=None,
+ choices=("DEBUG", "INFO", "WARNING", "ERROR"),
+ help="level to log at [default: no log]",
+ )
+
+ opt = parser.parse_args()
+
+ return opt
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/test_adapt.py b/tests/test_adapt.py
new file mode 100644
index 0000000..2190a84
--- /dev/null
+++ b/tests/test_adapt.py
@@ -0,0 +1,530 @@
+import datetime as dt
+from types import ModuleType
+from typing import Any, List
+
+import pytest
+
+import psycopg
+from psycopg import pq, sql, postgres
+from psycopg import errors as e
+from psycopg.adapt import Transformer, PyFormat, Dumper, Loader
+from psycopg._cmodule import _psycopg
+from psycopg.postgres import types as builtins, TEXT_OID
+from psycopg.types.array import ListDumper, ListBinaryDumper
+
+
+@pytest.mark.parametrize(
+ "data, format, result, type",
+ [
+ (1, PyFormat.TEXT, b"1", "int2"),
+ ("hello", PyFormat.TEXT, b"hello", "text"),
+ ("hello", PyFormat.BINARY, b"hello", "text"),
+ ],
+)
+def test_dump(data, format, result, type):
+ t = Transformer()
+ dumper = t.get_dumper(data, format)
+ assert dumper.dump(data) == result
+ if type == "text" and format != PyFormat.BINARY:
+ assert dumper.oid == 0
+ else:
+ assert dumper.oid == builtins[type].oid
+
+
+@pytest.mark.parametrize(
+ "data, result",
+ [
+ (1, b"1"),
+ ("hello", b"'hello'"),
+ ("he'llo", b"'he''llo'"),
+ (True, b"true"),
+ (None, b"NULL"),
+ ],
+)
+def test_quote(data, result):
+ t = Transformer()
+ dumper = t.get_dumper(data, PyFormat.TEXT)
+ assert dumper.quote(data) == result
+
+
+def test_register_dumper_by_class(conn):
+ dumper = make_dumper("x")
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper
+ conn.adapters.register_dumper(MyStr, dumper)
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
+
+
+def test_register_dumper_by_class_name(conn):
+ dumper = make_dumper("x")
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper
+ conn.adapters.register_dumper(f"{MyStr.__module__}.{MyStr.__qualname__}", dumper)
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
+
+
+@pytest.mark.crdb("skip", reason="global adapters don't affect crdb")
+def test_dump_global_ctx(conn_cls, dsn, global_adapters, pgconn):
+ psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+ psycopg.adapters.register_dumper(MyStr, make_dumper("gt"))
+ with conn_cls.connect(dsn) as conn:
+ cur = conn.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogb",)
+ cur = conn.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+
+
+def test_dump_connection_ctx(conn):
+ conn.adapters.register_dumper(MyStr, make_bin_dumper("b"))
+ conn.adapters.register_dumper(MyStr, make_dumper("t"))
+
+ cur = conn.cursor()
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellob",)
+
+
+def test_dump_cursor_ctx(conn):
+ conn.adapters.register_dumper(str, make_bin_dumper("b"))
+ conn.adapters.register_dumper(str, make_dumper("t"))
+
+ cur = conn.cursor()
+ cur.adapters.register_dumper(str, make_bin_dumper("bc"))
+ cur.adapters.register_dumper(str, make_dumper("tc"))
+
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellotc",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellotc",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellobc",)
+
+ cur = conn.cursor()
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellob",)
+
+
+def test_dump_subclass(conn):
+ class MyString(str):
+ pass
+
+ cur = conn.cursor()
+ cur.execute("select %s::text, %b::text", [MyString("hello"), MyString("world")])
+ assert cur.fetchone() == ("hello", "world")
+
+
+def test_subclass_dumper(conn):
+ # This might be a C fast object: make sure that the Python code is called
+ from psycopg.types.string import StrDumper
+
+ class MyStrDumper(StrDumper):
+ def dump(self, obj):
+ return (obj * 2).encode()
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+ assert conn.execute("select %t", ["hello"]).fetchone()[0] == "hellohello"
+
+
+def test_dumper_protocol(conn):
+
+ # This class doesn't inherit from adapt.Dumper but passes a mypy check
+ from .adapters_example import MyStrDumper
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+ cur = conn.execute("select %s", ["hello"])
+ assert cur.fetchone()[0] == "hellohello"
+ cur = conn.execute("select %s", [["hi", "ha"]])
+ assert cur.fetchone()[0] == ["hihi", "haha"]
+ assert sql.Literal("hello").as_string(conn) == "'qelloqello'"
+
+
+def test_loader_protocol(conn):
+
+ # This class doesn't inherit from adapt.Loader but passes a mypy check
+ from .adapters_example import MyTextLoader
+
+ conn.adapters.register_loader("text", MyTextLoader)
+ cur = conn.execute("select 'hello'::text")
+ assert cur.fetchone()[0] == "hellohello"
+ cur = conn.execute("select '{hi,ha}'::text[]")
+ assert cur.fetchone()[0] == ["hihi", "haha"]
+
+
+def test_subclass_loader(conn):
+ # This might be a C fast object: make sure that the Python code is called
+ from psycopg.types.string import TextLoader
+
+ class MyTextLoader(TextLoader):
+ def load(self, data):
+ return (bytes(data) * 2).decode()
+
+ conn.adapters.register_loader("text", MyTextLoader)
+ assert conn.execute("select 'hello'::text").fetchone()[0] == "hellohello"
+
+
+@pytest.mark.parametrize(
+ "data, format, type, result",
+ [
+ (b"1", pq.Format.TEXT, "int4", 1),
+ (b"hello", pq.Format.TEXT, "text", "hello"),
+ (b"hello", pq.Format.BINARY, "text", "hello"),
+ ],
+)
+def test_cast(data, format, type, result):
+ t = Transformer()
+ rv = t.get_loader(builtins[type].oid, format).load(data)
+ assert rv == result
+
+
+def test_register_loader_by_oid(conn):
+ assert TEXT_OID == 25
+ loader = make_loader("x")
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader
+ conn.adapters.register_loader(TEXT_OID, loader)
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader
+
+
+def test_register_loader_by_type_name(conn):
+ loader = make_loader("x")
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader
+ conn.adapters.register_loader("text", loader)
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader
+
+
+@pytest.mark.crdb("skip", reason="global adapters don't affect crdb")
+def test_load_global_ctx(conn_cls, dsn, global_adapters):
+ psycopg.adapters.register_loader("text", make_loader("gt"))
+ psycopg.adapters.register_loader("text", make_bin_loader("gb"))
+ with conn_cls.connect(dsn) as conn:
+ cur = conn.cursor(binary=False).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.cursor(binary=True).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogb",)
+
+
+def test_load_connection_ctx(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+ conn.adapters.register_loader("text", make_bin_loader("b"))
+
+ r = conn.cursor(binary=False).execute("select 'hello'::text").fetchone()
+ assert r == ("hellot",)
+ r = conn.cursor(binary=True).execute("select 'hello'::text").fetchone()
+ assert r == ("hellob",)
+
+
+def test_load_cursor_ctx(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+ conn.adapters.register_loader("text", make_bin_loader("b"))
+
+ cur = conn.cursor()
+ cur.adapters.register_loader("text", make_loader("tc"))
+ cur.adapters.register_loader("text", make_bin_loader("bc"))
+
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellotc",)
+ cur.format = pq.Format.BINARY
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellobc",)
+
+ cur = conn.cursor()
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellot",)
+ cur.format = pq.Format.BINARY
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellob",)
+
+
+def test_cow_dumpers(conn):
+ conn.adapters.register_dumper(str, make_dumper("t"))
+
+ cur1 = conn.cursor()
+ cur2 = conn.cursor()
+ cur2.adapters.register_dumper(str, make_dumper("c2"))
+
+ r = cur1.execute("select %s::text -- 1", ["hello"]).fetchone()
+ assert r == ("hellot",)
+ r = cur2.execute("select %s::text -- 1", ["hello"]).fetchone()
+ assert r == ("helloc2",)
+
+ conn.adapters.register_dumper(str, make_dumper("t1"))
+ r = cur1.execute("select %s::text -- 2", ["hello"]).fetchone()
+ assert r == ("hellot",)
+ r = cur2.execute("select %s::text -- 2", ["hello"]).fetchone()
+ assert r == ("helloc2",)
+
+
+def test_cow_loaders(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+
+ cur1 = conn.cursor()
+ cur2 = conn.cursor()
+ cur2.adapters.register_loader("text", make_loader("c2"))
+
+ assert cur1.execute("select 'hello'::text").fetchone() == ("hellot",)
+ assert cur2.execute("select 'hello'::text").fetchone() == ("helloc2",)
+
+ conn.adapters.register_loader("text", make_loader("t1"))
+ assert cur1.execute("select 'hello2'::text").fetchone() == ("hello2t",)
+ assert cur2.execute("select 'hello2'::text").fetchone() == ("hello2c2",)
+
+
+@pytest.mark.parametrize(
+ "sql, obj",
+ [("'{hello}'::text[]", ["helloc"]), ("row('hello'::text)", ("helloc",))],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out):
+ cur = conn.cursor(binary=fmt_out == pq.Format.BINARY)
+ if fmt_out == pq.Format.TEXT:
+ cur.adapters.register_loader("text", make_loader("c"))
+ else:
+ cur.adapters.register_loader("text", make_bin_loader("c"))
+
+ cur.execute(f"select {sql}")
+ res = cur.fetchone()[0]
+ assert res == obj
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_list_dumper(conn, fmt_out):
+ t = Transformer(conn)
+ fmt_in = PyFormat.from_pq(fmt_out)
+ dint = t.get_dumper([0], fmt_in)
+ assert isinstance(dint, (ListDumper, ListBinaryDumper))
+ assert dint.oid == builtins["int2"].array_oid
+ assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid
+
+ dstr = t.get_dumper([""], fmt_in)
+ assert dstr is not dint
+
+ assert t.get_dumper([1], fmt_in) is dint
+ assert t.get_dumper([None, [1]], fmt_in) is dint
+
+ dempty = t.get_dumper([], fmt_in)
+ assert t.get_dumper([None, [None]], fmt_in) is dempty
+ assert dempty.oid == 0
+ assert dempty.dump([]) == b"{}"
+
+ L: List[List[Any]] = []
+ L.append(L)
+ with pytest.raises(psycopg.DataError):
+ assert t.get_dumper(L, fmt_in)
+
+
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
+def test_str_list_dumper_text(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.TEXT)
+ assert isinstance(dstr, ListDumper)
+ assert dstr.oid == 0
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == 0
+
+
+def test_str_list_dumper_binary(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.BINARY)
+ assert isinstance(dstr, ListBinaryDumper)
+ assert dstr.oid == builtins["text"].array_oid
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
+def test_last_dumper_registered_ctx(conn):
+ cur = conn.cursor()
+
+ bd = make_bin_dumper("b")
+ cur.adapters.register_dumper(str, bd)
+ td = make_dumper("t")
+ cur.adapters.register_dumper(str, td)
+
+ assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellot"
+ assert cur.execute("select %t", ["hello"]).fetchone()[0] == "hellot"
+ assert cur.execute("select %b", ["hello"]).fetchone()[0] == "hellob"
+
+ cur.adapters.register_dumper(str, bd)
+ assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellob"
+
+
+@pytest.mark.parametrize("fmt_in", [PyFormat.TEXT, PyFormat.BINARY])
+def test_none_type_argument(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table none_args (id serial primary key, num integer)")
+ cur.execute("insert into none_args (num) values (%s) returning id", (None,))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_return_untyped(conn, fmt_in):
+ # Analyze and check for changes using strings in untyped/typed contexts
+ cur = conn.cursor()
+ # Currently string are passed as unknown oid to libpq. This is because
+ # unknown is more easily cast by postgres to different types (see jsonb
+ # later).
+ cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10])
+ assert cur.fetchone() == ("hello", 10)
+
+ cur.execute("create table testjson(data jsonb)")
+ if fmt_in != PyFormat.BINARY:
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+ assert cur.execute("select data from testjson").fetchone() == ({},)
+ else:
+ # Binary types cannot be passed as unknown oids.
+ with pytest.raises(e.DatatypeMismatch):
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_no_cast_needed(conn, fmt_in):
+ # Verify that there is no need of cast in certain common scenario
+ cur = conn.execute(f"select '2021-01-01'::date + %{fmt_in.value}", [3])
+ assert cur.fetchone()[0] == dt.date(2021, 1, 4)
+
+ cur = conn.execute(f"select '[10, 20, 30]'::jsonb -> %{fmt_in.value}", [1])
+ assert cur.fetchone()[0] == 20
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(_psycopg is None, reason="C module test")
+def test_optimised_adapters():
+
+ # All the optimised adapters available
+ c_adapters = {}
+ for n in dir(_psycopg):
+ if n.startswith("_") or n in ("CDumper", "CLoader"):
+ continue
+ obj = getattr(_psycopg, n)
+ if not isinstance(obj, type):
+ continue
+ if not issubclass(
+ obj,
+ (_psycopg.CDumper, _psycopg.CLoader), # type: ignore[attr-defined]
+ ):
+ continue
+ c_adapters[n] = obj
+
+ # All the registered adapters
+ reg_adapters = set()
+ adapters = list(postgres.adapters._dumpers.values()) + postgres.adapters._loaders
+ assert len(adapters) == 5
+ for m in adapters:
+ reg_adapters |= set(m.values())
+
+ # Check that the registered adapters are the optimised one
+ i = 0
+ for cls in reg_adapters:
+ if cls.__name__ in c_adapters:
+ assert cls is c_adapters[cls.__name__]
+ i += 1
+
+ assert i >= 10
+
+ # Check that every optimised adapter is the optimised version of a Py one
+ for n in dir(psycopg.types):
+ mod = getattr(psycopg.types, n)
+ if not isinstance(mod, ModuleType):
+ continue
+ for n1 in dir(mod):
+ obj = getattr(mod, n1)
+ if not isinstance(obj, type):
+ continue
+ if not issubclass(obj, (Dumper, Loader)):
+ continue
+ c_adapters.pop(obj.__name__, None)
+
+ assert not c_adapters
+
+
+def test_dumper_init_error(conn):
+ class BadDumper(Dumper):
+ def __init__(self, cls, context):
+ super().__init__(cls, context)
+ 1 / 0
+
+ def dump(self, obj):
+ return obj.encode()
+
+ cur = conn.cursor()
+ cur.adapters.register_dumper(str, BadDumper)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select %s::text", ["hi"])
+
+
+def test_loader_init_error(conn):
+ class BadLoader(Loader):
+ def __init__(self, oid, context):
+ super().__init__(oid, context)
+ 1 / 0
+
+ def load(self, data):
+ return data.decode()
+
+ cur = conn.cursor()
+ cur.adapters.register_loader("text", BadLoader)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select 'hi'::text")
+ assert cur.fetchone() == ("hi",)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", PyFormat)
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_random(conn, faker, fmt, fmt_out):
+ faker.format = fmt
+ faker.choose_schema(ncols=20)
+ faker.make_records(50)
+
+ with conn.cursor(binary=fmt_out) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+
+class MyStr(str):
+ pass
+
+
+def make_dumper(suffix):
+ """Create a test dumper appending a suffix to the bytes representation."""
+
+ class TestDumper(Dumper):
+ oid = TEXT_OID
+ format = pq.Format.TEXT
+
+ def dump(self, s):
+ return (s + suffix).encode("ascii")
+
+ return TestDumper
+
+
+def make_bin_dumper(suffix):
+ cls = make_dumper(suffix)
+ cls.format = pq.Format.BINARY
+ return cls
+
+
+def make_loader(suffix):
+ """Create a test loader appending a suffix to the data returned."""
+
+ class TestLoader(Loader):
+ format = pq.Format.TEXT
+
+ def load(self, b):
+ return bytes(b).decode("ascii") + suffix
+
+ return TestLoader
+
+
+def make_bin_loader(suffix):
+ cls = make_loader(suffix)
+ cls.format = pq.Format.BINARY
+ return cls
diff --git a/tests/test_client_cursor.py b/tests/test_client_cursor.py
new file mode 100644
index 0000000..b355604
--- /dev/null
+++ b/tests/test_client_cursor.py
@@ -0,0 +1,855 @@
+import pickle
+import weakref
+import datetime as dt
+from typing import List
+
+import pytest
+
+import psycopg
+from psycopg import sql, rows
+from psycopg.adapt import PyFormat
+from psycopg.postgres import types as builtins
+
+from .utils import gc_collect, gc_count
+from .test_cursor import my_row_factory
+from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision
+
+
+@pytest.fixture
+def conn(conn):
+ conn.cursor_factory = psycopg.ClientCursor
+ return conn
+
+
+def test_init(conn):
+ cur = psycopg.ClientCursor(conn)
+ cur.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ conn.row_factory = rows.dict_row
+ cur = psycopg.ClientCursor(conn)
+ cur.execute("select 1 as a")
+ assert cur.fetchone() == {"a": 1}
+
+
+def test_init_factory(conn):
+ cur = psycopg.ClientCursor(conn, row_factory=rows.dict_row)
+ cur.execute("select 1 as a")
+ assert cur.fetchone() == {"a": 1}
+
+
+def test_from_cursor_factory(conn_cls, dsn):
+ with conn_cls.connect(dsn, cursor_factory=psycopg.ClientCursor) as conn:
+ cur = conn.cursor()
+ assert type(cur) is psycopg.ClientCursor
+
+ cur.execute("select %s", (1,))
+ assert cur.fetchone() == (1,)
+ assert cur._query
+ assert cur._query.query == b"select 1"
+
+
+def test_close(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.execute("select 'foo'")
+
+ cur.close()
+ assert cur.closed
+
+
+def test_cursor_close_fetchone(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ for _ in range(5):
+ cur.fetchone()
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchone()
+
+
+def test_cursor_close_fetchmany(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchmany(2)) == 2
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchmany(2)
+
+
+def test_cursor_close_fetchall(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchall()) == 10
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchall()
+
+
+def test_context(conn):
+ with conn.cursor() as cur:
+ assert not cur.closed
+
+ assert cur.closed
+
+
+@pytest.mark.slow
+def test_weakref(conn):
+ cur = conn.cursor()
+ w = weakref.ref(cur)
+ cur.close()
+ del cur
+ gc_collect()
+ assert w() is None
+
+
+def test_pgresult(conn):
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert cur.pgresult
+ cur.close()
+ assert not cur.pgresult
+
+
+def test_statusmessage(conn):
+ cur = conn.cursor()
+ assert cur.statusmessage is None
+
+ cur.execute("select generate_series(1, 10)")
+ assert cur.statusmessage == "SELECT 10"
+
+ cur.execute("create table statusmessage ()")
+ assert cur.statusmessage == "CREATE TABLE"
+
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute("wat")
+ assert cur.statusmessage is None
+
+
+def test_execute_sql(conn):
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {value}").format(value="hello"))
+ assert cur.fetchone() == ("hello",)
+
+
+def test_execute_many_results(conn):
+ cur = conn.cursor()
+ assert cur.nextset() is None
+
+ rv = cur.execute("select %s; select generate_series(1,%s)", ("foo", 3))
+ assert rv is cur
+ assert cur.fetchall() == [("foo",)]
+ assert cur.rowcount == 1
+ assert cur.nextset()
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.nextset() is None
+
+ cur.close()
+ assert cur.nextset() is None
+
+
+def test_execute_sequence(conn):
+ cur = conn.cursor()
+ rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.pgresult.get_value(0, 1) == b"foo"
+ assert cur.pgresult.get_value(0, 2) is None
+ assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+def test_execute_empty_query(conn, query):
+ cur = conn.cursor()
+ cur.execute(query)
+ assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+
+def test_execute_type_change(conn):
+ # issue #112
+ conn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = conn.cursor()
+ cur.execute(sql, (1,))
+ cur.execute(sql, (100_000,))
+ cur.execute("select num from bug_112 order by num")
+ assert cur.fetchall() == [(1,), (100_000,)]
+
+
+def test_executemany_type_change(conn):
+ conn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = conn.cursor()
+ cur.executemany(sql, [(1,), (100_000,)])
+ cur.execute("select num from bug_112 order by num")
+ assert cur.fetchall() == [(1,), (100_000,)]
+
+
+@pytest.mark.parametrize(
+ "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
+)
+def test_execute_copy(conn, query):
+ cur = conn.cursor()
+ cur.execute("create table testcopy (id int)")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute(query)
+
+
+def test_fetchone(conn):
+ cur = conn.cursor()
+ cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert cur.pgresult.fformat(0) == 0
+
+ row = cur.fetchone()
+ assert row == (1, "foo", None)
+ row = cur.fetchone()
+ assert row is None
+
+
+def test_binary_cursor_execute(conn):
+ with pytest.raises(psycopg.NotSupportedError):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s, %s", [1, None])
+
+
+def test_execute_binary(conn):
+ with pytest.raises(psycopg.NotSupportedError):
+ cur = conn.cursor()
+ cur.execute("select %s, %s", [1, None], binary=True)
+
+
+def test_binary_cursor_text_override(conn):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s, %s", [1, None], binary=False)
+ assert cur.fetchone() == (1, None)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+def test_query_encode(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ (res,) = cur.execute("select '\u20ac'").fetchone()
+ assert res == "\u20ac"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+def test_query_badenc(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ cur.execute("select '\u20ac'")
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop table if exists execmany;
+ create table execmany (id serial primary key, num integer, data text)
+ """
+ )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+ cur = svcconn.cursor()
+ cur.execute("truncate table execmany")
+
+
+def test_executemany(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(10, "hello"), (20, "world")]
+
+
+def test_executemany_name(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(11, "hello"), (21, "world")]
+
+
+def test_executemany_no_data(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
+ assert cur.rowcount == 0
+
+
+def test_executemany_rowcount(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+def test_executemany_returning(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.fetchone() == (10,)
+ assert cur.nextset()
+ assert cur.fetchone() == (20,)
+ assert cur.nextset() is None
+
+
+def test_executemany_returning_discard(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+ assert cur.nextset() is None
+
+
+def test_executemany_no_result(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.statusmessage.startswith("INSERT")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+ pgresult = cur.pgresult
+ assert cur.nextset()
+ assert cur.statusmessage.startswith("INSERT")
+ assert pgresult is not cur.pgresult
+ assert cur.nextset() is None
+
+
+def test_executemany_rowcount_no_hit(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+ cur.executemany("delete from execmany where id = %s", [])
+ assert cur.rowcount == 0
+ cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ # This fails, but only because we try to copy in pipeline mode,
+ # crashing the connection. Which would be even fine, but with
+ # the async cursor it's worse... See test_client_cursor_async.py.
+ # "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+def test_executemany_badquery(conn, query):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_executemany_null_first(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table testmany (a bigint, b bigint)")
+ cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, None], [3, 4]],
+ )
+ with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
+ cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, ""], [3, 4]],
+ )
+
+
+def test_rowcount(conn):
+ cur = conn.cursor()
+
+ cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+ cur.execute("create table test_rowcount_notuples (id int primary key)")
+ assert cur.rowcount == -1
+
+ cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+
+def test_rownumber(conn):
+ cur = conn.cursor()
+ assert cur.rownumber is None
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ cur.fetchone()
+ assert cur.rownumber == 1
+ cur.fetchone()
+ assert cur.rownumber == 2
+ cur.fetchmany(10)
+ assert cur.rownumber == 12
+ rns: List[int] = []
+ for i in cur:
+ assert cur.rownumber
+ rns.append(cur.rownumber)
+ if len(rns) >= 3:
+ break
+ assert rns == [13, 14, 15]
+ assert len(cur.fetchall()) == 42 - rns[-1]
+ assert cur.rownumber == 42
+
+
+def test_iter(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ assert list(cur) == [(1,), (2,), (3,)]
+
+
+def test_iter_stop(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ for rec in cur:
+ assert rec == (1,)
+ break
+
+ for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert cur.fetchone() == (3,)
+ assert list(cur) == []
+
+
+def test_row_factory(conn):
+ cur = conn.cursor(row_factory=my_row_factory)
+
+ cur.execute("reset search_path")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+ cur.execute("select 'foo' as bar")
+ (r,) = cur.fetchone()
+ assert r == "FOObar"
+
+ cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert cur.fetchall() == [["Yy", "Zz"]]
+
+ cur.scroll(-1)
+ cur.row_factory = rows.dict_row
+ assert cur.fetchone() == {"y": "y", "z": "z"}
+
+
+def test_row_factory_none(conn):
+ cur = conn.cursor(row_factory=None)
+ assert cur.row_factory is rows.tuple_row
+ r = cur.execute("select 1 as a, 2 as b").fetchone()
+ assert type(r) is tuple
+ assert r == (1, 2)
+
+
+def test_bad_row_factory(conn):
+ def broken_factory(cur):
+ 1 / 0
+
+ cur = conn.cursor(row_factory=broken_factory)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select 1")
+
+ def broken_maker(cur):
+ def make_row(seq):
+ 1 / 0
+
+ return make_row
+
+ cur = conn.cursor(row_factory=broken_maker)
+ cur.execute("select 1")
+ with pytest.raises(ZeroDivisionError):
+ cur.fetchone()
+
+
+def test_scroll(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.scroll(0)
+
+ cur.execute("select generate_series(0,9)")
+ cur.scroll(2)
+ assert cur.fetchone() == (2,)
+ cur.scroll(2)
+ assert cur.fetchone() == (5,)
+ cur.scroll(2, mode="relative")
+ assert cur.fetchone() == (8,)
+ cur.scroll(-1)
+ assert cur.fetchone() == (8,)
+ cur.scroll(-2)
+ assert cur.fetchone() == (7,)
+ cur.scroll(2, mode="absolute")
+ assert cur.fetchone() == (2,)
+
+ # on the boundary
+ cur.scroll(0, mode="absolute")
+ assert cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ cur.scroll(-1, mode="absolute")
+
+ cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ cur.scroll(-1)
+
+ cur.scroll(9, mode="absolute")
+ assert cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ cur.scroll(10, mode="absolute")
+
+ cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ cur.scroll(1, "wat")
+
+
+def test_query_params_execute(conn):
+ cur = conn.cursor()
+ assert cur._query is None
+
+ cur.execute("select %t, %s::text", [1, None])
+ assert cur._query is not None
+ assert cur._query.query == b"select 1, NULL::text"
+ assert cur._query.params == (b"1", b"NULL")
+
+ cur.execute("select 1")
+ assert cur._query.query == b"select 1"
+ assert not cur._query.params
+
+ with pytest.raises(psycopg.DataError):
+ cur.execute("select %t::int", ["wat"])
+
+ assert cur._query.query == b"select 'wat'::int"
+ assert cur._query.params == (b"'wat'",)
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ ("select %(x)s", {"x": 1}, (1,)),
+ ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)),
+ ("select %(x)s, %(x)s", {"x": 1}, (1, 1)),
+ ],
+)
+def test_query_params_named(conn, query, params, want):
+ cur = conn.cursor()
+ cur.execute(query, params)
+ rec = cur.fetchone()
+ assert rec == want
+
+
+def test_query_params_executemany(conn):
+ cur = conn.cursor()
+
+ cur.executemany("select %t, %t", [[1, 2], [3, 4]])
+ assert cur._query.query == b"select 3, 4"
+ assert cur._query.params == (b"3", b"4")
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+ cur = conn.cursor()
+ with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_stream(conn):
+ cur = conn.cursor()
+ recs = []
+ for rec in cur.stream(
+ "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+ [2],
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+class TestColumn:
+ def test_description_attribs(self, conn):
+ curs = conn.cursor()
+ curs.execute(
+ """select
+ 3.14::decimal(10,2) as pi,
+ 'hello'::text as hi,
+ '2010-02-18'::date as now
+ """
+ )
+ assert len(curs.description) == 3
+ for c in curs.description:
+ len(c) == 7 # DBAPI happy
+ for i, a in enumerate(
+ """
+ name type_code display_size internal_size precision scale null_ok
+ """.split()
+ ):
+ assert c[i] == getattr(c, a)
+
+ # Won't fill them up
+ assert c.null_ok is None
+
+ c = curs.description[0]
+ assert c.name == "pi"
+ assert c.type_code == builtins["numeric"].oid
+ assert c.display_size is None
+ assert c.internal_size is None
+ assert c.precision == 10
+ assert c.scale == 2
+
+ c = curs.description[1]
+ assert c.name == "hi"
+ assert c.type_code == builtins["text"].oid
+ assert c.display_size is None
+ assert c.internal_size is None
+ assert c.precision is None
+ assert c.scale is None
+
+ c = curs.description[2]
+ assert c.name == "now"
+ assert c.type_code == builtins["date"].oid
+ assert c.display_size is None
+ if is_crdb(conn):
+ assert c.internal_size == 16
+ else:
+ assert c.internal_size == 4
+ assert c.precision is None
+ assert c.scale is None
+
+ def test_description_slice(self, conn):
+ curs = conn.cursor()
+ curs.execute("select 1::int as a")
+ curs.description[0][0:2] == ("a", 23)
+
+ @pytest.mark.parametrize(
+ "type, precision, scale, dsize, isize",
+ [
+ ("text", None, None, None, None),
+ ("varchar", None, None, None, None),
+ ("varchar(42)", None, None, 42, None),
+ ("int4", None, None, None, 4),
+ ("numeric", None, None, None, None),
+ ("numeric(10)", 10, 0, None, None),
+ ("numeric(10, 3)", 10, 3, None, None),
+ ("time", None, None, None, 8),
+ crdb_time_precision("time(4)", 4, None, None, 8),
+ crdb_time_precision("time(10)", 6, None, None, 8),
+ ],
+ )
+ def test_details(self, conn, type, precision, scale, dsize, isize):
+ cur = conn.cursor()
+ cur.execute(f"select null::{type}")
+ col = cur.description[0]
+ repr(col)
+ assert col.precision == precision
+ assert col.scale == scale
+ assert col.display_size == dsize
+ assert col.internal_size == isize
+
+ def test_pickle(self, conn):
+ curs = conn.cursor()
+ curs.execute(
+ """select
+ 3.14::decimal(10,2) as pi,
+ 'hello'::text as hi,
+ '2010-02-18'::date as now
+ """
+ )
+ description = curs.description
+ pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
+ unpickled = pickle.loads(pickled)
+ assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]
+
+ @pytest.mark.crdb_skip("no col query")
+ def test_no_col_query(self, conn):
+ cur = conn.execute("select")
+ assert cur.description == []
+ assert cur.fetchall() == [()]
+
+ def test_description_closed_connection(self, conn):
+ # If we have reasons to break this test we will (e.g. we really need
+ # the connection). In #172 it fails just by accident.
+ cur = conn.execute("select 1::int4 as foo")
+ conn.close()
+ assert len(cur.description) == 1
+ col = cur.description[0]
+ assert col.name == "foo"
+ assert col.type_code == 23
+
+ def test_name_not_a_name(self, conn):
+ cur = conn.cursor()
+ (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone()
+ assert res == "x"
+ assert cur.description[0].name == "foo-bar"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_name_encode(self, conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone()
+ assert res == "x"
+ assert cur.description[0].name == "\u20ac"
+
+
+def test_str(conn):
+ cur = conn.cursor()
+ assert "psycopg.ClientCursor" in str(cur)
+ assert "[IDLE]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" in str(cur)
+ cur.execute("select 1")
+ assert "[INTRANS]" in str(cur)
+ assert "[TUPLES_OK]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" not in str(cur)
+ cur.close()
+ assert "[closed]" in str(cur)
+ assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_leak(conn_cls, dsn, faker, fetch, row_factory):
+ faker.choose_schema(ncols=5)
+ faker.make_records(10)
+ row_factory = getattr(rows, row_factory)
+
+ def work():
+ with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True):
+ with psycopg.ClientCursor(conn, row_factory=row_factory) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ cur.execute(faker.select_stmt)
+
+ if fetch == "one":
+ while True:
+ tmp = cur.fetchone()
+ if tmp is None:
+ break
+ elif fetch == "many":
+ while True:
+ tmp = cur.fetchmany(3)
+ if not tmp:
+ break
+ elif fetch == "all":
+ cur.fetchall()
+ elif fetch == "iter":
+ for rec in cur:
+ pass
+
+ n = []
+ gc_collect()
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ ("select 'hello'", (), "select 'hello'"),
+ ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"),
+ ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"),
+ ("select %%", (), "select %%"),
+ ("select %%, %s", (["a"],), "select %, 'a'"),
+ ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"),
+ ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"),
+ ],
+)
+def test_mogrify(conn, query, params, want):
+ cur = conn.cursor()
+ got = cur.mogrify(query, *params)
+ assert got == want
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+def test_mogrify_encoding(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ q = conn.cursor().mogrify("select %(s)s", {"s": "\u20ac"})
+ assert q == "select '\u20ac'"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+def test_mogrify_badenc(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ with pytest.raises(UnicodeEncodeError):
+ conn.cursor().mogrify("select %(s)s", {"s": "\u20ac"})
+
+
+@pytest.mark.pipeline
+def test_message_0x33(conn):
+ # https://github.com/psycopg/psycopg/issues/314
+ notices = []
+ conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ conn.autocommit = True
+ with conn.pipeline():
+ cur = conn.execute("select 'test'")
+ assert cur.fetchone() == ("test",)
+
+ assert not notices
diff --git a/tests/test_client_cursor_async.py b/tests/test_client_cursor_async.py
new file mode 100644
index 0000000..0cf8ec6
--- /dev/null
+++ b/tests/test_client_cursor_async.py
@@ -0,0 +1,727 @@
+import pytest
+import weakref
+import datetime as dt
+from typing import List
+
+import psycopg
+from psycopg import sql, rows
+from psycopg.adapt import PyFormat
+
+from .utils import alist, gc_collect, gc_count
+from .test_cursor import my_row_factory
+from .test_cursor import execmany, _execmany # noqa: F401
+from .fix_crdb import crdb_encoding
+
+execmany = execmany # avoid F811 underneath
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+async def aconn(aconn):
+ aconn.cursor_factory = psycopg.AsyncClientCursor
+ return aconn
+
+
+async def test_init(aconn):
+ cur = psycopg.AsyncClientCursor(aconn)
+ await cur.execute("select 1")
+ assert (await cur.fetchone()) == (1,)
+
+ aconn.row_factory = rows.dict_row
+ cur = psycopg.AsyncClientCursor(aconn)
+ await cur.execute("select 1 as a")
+ assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_init_factory(aconn):
+ cur = psycopg.AsyncClientCursor(aconn, row_factory=rows.dict_row)
+ await cur.execute("select 1 as a")
+ assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_from_cursor_factory(aconn_cls, dsn):
+ async with await aconn_cls.connect(
+ dsn, cursor_factory=psycopg.AsyncClientCursor
+ ) as aconn:
+ cur = aconn.cursor()
+ assert type(cur) is psycopg.AsyncClientCursor
+
+ await cur.execute("select %s", (1,))
+ assert await cur.fetchone() == (1,)
+ assert cur._query
+ assert cur._query.query == b"select 1"
+
+
+async def test_close(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.execute("select 'foo'")
+
+ await cur.close()
+ assert cur.closed
+
+
+async def test_cursor_close_fetchone(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ for _ in range(5):
+ await cur.fetchone()
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchone()
+
+
+async def test_cursor_close_fetchmany(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchmany(2)) == 2
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchmany(2)
+
+
+async def test_cursor_close_fetchall(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchall()) == 10
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchall()
+
+
+async def test_context(aconn):
+ async with aconn.cursor() as cur:
+ assert not cur.closed
+
+ assert cur.closed
+
+
+@pytest.mark.slow
+async def test_weakref(aconn):
+ cur = aconn.cursor()
+ w = weakref.ref(cur)
+ await cur.close()
+ del cur
+ gc_collect()
+ assert w() is None
+
+
+async def test_pgresult(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert cur.pgresult
+ await cur.close()
+ assert not cur.pgresult
+
+
+async def test_statusmessage(aconn):
+ cur = aconn.cursor()
+ assert cur.statusmessage is None
+
+ await cur.execute("select generate_series(1, 10)")
+ assert cur.statusmessage == "SELECT 10"
+
+ await cur.execute("create table statusmessage ()")
+ assert cur.statusmessage == "CREATE TABLE"
+
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.execute("wat")
+ assert cur.statusmessage is None
+
+
+async def test_execute_sql(aconn):
+ cur = aconn.cursor()
+ await cur.execute(sql.SQL("select {value}").format(value="hello"))
+ assert await cur.fetchone() == ("hello",)
+
+
+async def test_execute_many_results(aconn):
+ cur = aconn.cursor()
+ assert cur.nextset() is None
+
+ rv = await cur.execute("select %s; select generate_series(1,%s)", ("foo", 3))
+ assert rv is cur
+ assert (await cur.fetchall()) == [("foo",)]
+ assert cur.rowcount == 1
+ assert cur.nextset()
+ assert (await cur.fetchall()) == [(1,), (2,), (3,)]
+ assert cur.rowcount == 3
+ assert cur.nextset() is None
+
+ await cur.close()
+ assert cur.nextset() is None
+
+
+async def test_execute_sequence(aconn):
+ cur = aconn.cursor()
+ rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.pgresult.get_value(0, 1) == b"foo"
+ assert cur.pgresult.get_value(0, 2) is None
+ assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+async def test_execute_empty_query(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute(query)
+ assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+
+
+async def test_execute_type_change(aconn):
+ # issue #112
+ await aconn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = aconn.cursor()
+ await cur.execute(sql, (1,))
+ await cur.execute(sql, (100_000,))
+ await cur.execute("select num from bug_112 order by num")
+ assert (await cur.fetchall()) == [(1,), (100_000,)]
+
+
+async def test_executemany_type_change(aconn):
+ await aconn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = aconn.cursor()
+ await cur.executemany(sql, [(1,), (100_000,)])
+ await cur.execute("select num from bug_112 order by num")
+ assert (await cur.fetchall()) == [(1,), (100_000,)]
+
+
+@pytest.mark.parametrize(
+ "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
+)
+async def test_execute_copy(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute("create table testcopy (id int)")
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.execute(query)
+
+
+async def test_fetchone(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert cur.pgresult.fformat(0) == 0
+
+ row = await cur.fetchone()
+ assert row == (1, "foo", None)
+ row = await cur.fetchone()
+ assert row is None
+
+
+async def test_binary_cursor_execute(aconn):
+ with pytest.raises(psycopg.NotSupportedError):
+ cur = aconn.cursor(binary=True)
+ await cur.execute("select %s, %s", [1, None])
+
+
+async def test_execute_binary(aconn):
+ with pytest.raises(psycopg.NotSupportedError):
+ cur = aconn.cursor()
+ await cur.execute("select %s, %s", [1, None], binary=True)
+
+
+async def test_binary_cursor_text_override(aconn):
+ cur = aconn.cursor(binary=True)
+ await cur.execute("select %s, %s", [1, None], binary=False)
+ assert (await cur.fetchone()) == (1, None)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+async def test_query_encode(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ cur = aconn.cursor()
+ await cur.execute("select '\u20ac'")
+ (res,) = await cur.fetchone()
+ assert res == "\u20ac"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+async def test_query_badenc(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ cur = aconn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ await cur.execute("select '\u20ac'")
+
+
+async def test_executemany(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ await cur.execute("select num, data from execmany order by 1")
+ rv = await cur.fetchall()
+ assert rv == [(10, "hello"), (20, "world")]
+
+
+async def test_executemany_name(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ await cur.execute("select num, data from execmany order by 1")
+ rv = await cur.fetchall()
+ assert rv == [(11, "hello"), (21, "world")]
+
+
+async def test_executemany_no_data(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
+ assert cur.rowcount == 0
+
+
+async def test_executemany_rowcount(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+async def test_executemany_returning(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert (await cur.fetchone()) == (10,)
+ assert cur.nextset()
+ assert (await cur.fetchone()) == (20,)
+ assert cur.nextset() is None
+
+
+async def test_executemany_returning_discard(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+ assert cur.nextset() is None
+
+
+async def test_executemany_no_result(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.statusmessage.startswith("INSERT")
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+ pgresult = cur.pgresult
+ assert cur.nextset()
+ assert cur.statusmessage.startswith("INSERT")
+ assert pgresult is not cur.pgresult
+ assert cur.nextset() is None
+
+
+async def test_executemany_rowcount_no_hit(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+ await cur.executemany("delete from execmany where id = %s", [])
+ assert cur.rowcount == 0
+ await cur.executemany(
+ "delete from execmany where id = %s returning num", [(-1,), (-2,)]
+ )
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ # This fails because we end up trying to copy in pipeline mode.
+ # However, sometimes (and pretty regularly if we enable pgconn.trace())
+ # something goes in a loop and only terminates by OOM. Strace shows
+ # an allocation loop. I think it's in the libpq.
+ # "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+async def test_executemany_badquery(aconn, query):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+async def test_executemany_null_first(aconn, fmt_in):
+ cur = aconn.cursor()
+ await cur.execute("create table testmany (a bigint, b bigint)")
+ await cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, None], [3, 4]],
+ )
+ with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
+ await cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, ""], [3, 4]],
+ )
+
+
+async def test_rowcount(aconn):
+ cur = aconn.cursor()
+
+ await cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+ await cur.execute("create table test_rowcount_notuples (id int primary key)")
+ assert cur.rowcount == -1
+
+ await cur.execute(
+ "insert into test_rowcount_notuples select generate_series(1, 42)"
+ )
+ assert cur.rowcount == 42
+
+
+async def test_rownumber(aconn):
+ cur = aconn.cursor()
+ assert cur.rownumber is None
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ await cur.fetchone()
+ assert cur.rownumber == 1
+ await cur.fetchone()
+ assert cur.rownumber == 2
+ await cur.fetchmany(10)
+ assert cur.rownumber == 12
+ rns: List[int] = []
+ async for i in cur:
+ assert cur.rownumber
+ rns.append(cur.rownumber)
+ if len(rns) >= 3:
+ break
+ assert rns == [13, 14, 15]
+ assert len(await cur.fetchall()) == 42 - rns[-1]
+ assert cur.rownumber == 42
+
+
+async def test_iter(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ res = []
+ async for rec in cur:
+ res.append(rec)
+ assert res == [(1,), (2,), (3,)]
+
+
+async def test_iter_stop(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ async for rec in cur:
+ assert rec == (1,)
+ break
+
+ async for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert (await cur.fetchone()) == (3,)
+ async for rec in cur:
+ assert False
+
+
+async def test_row_factory(aconn):
+ cur = aconn.cursor(row_factory=my_row_factory)
+ await cur.execute("select 'foo' as bar")
+ (r,) = await cur.fetchone()
+ assert r == "FOObar"
+
+ await cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert await cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert await cur.fetchall() == [["Yy", "Zz"]]
+
+ await cur.scroll(-1)
+ cur.row_factory = rows.dict_row
+ assert await cur.fetchone() == {"y": "y", "z": "z"}
+
+
+async def test_row_factory_none(aconn):
+ cur = aconn.cursor(row_factory=None)
+ assert cur.row_factory is rows.tuple_row
+ await cur.execute("select 1 as a, 2 as b")
+ r = await cur.fetchone()
+ assert type(r) is tuple
+ assert r == (1, 2)
+
+
+async def test_bad_row_factory(aconn):
+ def broken_factory(cur):
+ 1 / 0
+
+ cur = aconn.cursor(row_factory=broken_factory)
+ with pytest.raises(ZeroDivisionError):
+ await cur.execute("select 1")
+
+ def broken_maker(cur):
+ def make_row(seq):
+ 1 / 0
+
+ return make_row
+
+ cur = aconn.cursor(row_factory=broken_maker)
+ await cur.execute("select 1")
+ with pytest.raises(ZeroDivisionError):
+ await cur.fetchone()
+
+
+async def test_scroll(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.scroll(0)
+
+ await cur.execute("select generate_series(0,9)")
+ await cur.scroll(2)
+ assert await cur.fetchone() == (2,)
+ await cur.scroll(2)
+ assert await cur.fetchone() == (5,)
+ await cur.scroll(2, mode="relative")
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-1)
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-2)
+ assert await cur.fetchone() == (7,)
+ await cur.scroll(2, mode="absolute")
+ assert await cur.fetchone() == (2,)
+
+ # on the boundary
+ await cur.scroll(0, mode="absolute")
+ assert await cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ await cur.scroll(-1, mode="absolute")
+
+ await cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(-1)
+
+ await cur.scroll(9, mode="absolute")
+ assert await cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ await cur.scroll(10, mode="absolute")
+
+ await cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ await cur.scroll(1, "wat")
+
+
+async def test_query_params_execute(aconn):
+ cur = aconn.cursor()
+ assert cur._query is None
+
+ await cur.execute("select %t, %s::text", [1, None])
+ assert cur._query is not None
+ assert cur._query.query == b"select 1, NULL::text"
+ assert cur._query.params == (b"1", b"NULL")
+
+ await cur.execute("select 1")
+ assert cur._query.query == b"select 1"
+ assert not cur._query.params
+
+ with pytest.raises(psycopg.DataError):
+ await cur.execute("select %t::int", ["wat"])
+
+ assert cur._query.query == b"select 'wat'::int"
+ assert cur._query.params == (b"'wat'",)
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ ("select %(x)s", {"x": 1}, (1,)),
+ ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)),
+ ("select %(x)s, %(x)s", {"x": 1}, (1, 1)),
+ ],
+)
+async def test_query_params_named(aconn, query, params, want):
+ cur = aconn.cursor()
+ await cur.execute(query, params)
+ rec = await cur.fetchone()
+ assert rec == want
+
+
+async def test_query_params_executemany(aconn):
+ cur = aconn.cursor()
+
+ await cur.executemany("select %t, %t", [[1, 2], [3, 4]])
+ assert cur._query.query == b"select 3, 4"
+ assert cur._query.params == (b"3", b"4")
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+async def test_copy_out_param(aconn, ph, params):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert await alist(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_stream(aconn):
+ cur = aconn.cursor()
+ recs = []
+ async for rec in cur.stream(
+ "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+ [2],
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+async def test_str(aconn):
+ cur = aconn.cursor()
+ assert "psycopg.AsyncClientCursor" in str(cur)
+ assert "[IDLE]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" in str(cur)
+ await cur.execute("select 1")
+ assert "[INTRANS]" in str(cur)
+ assert "[TUPLES_OK]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" not in str(cur)
+ await cur.close()
+ assert "[closed]" in str(cur)
+ assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_leak(aconn_cls, dsn, faker, fetch, row_factory):
+ faker.choose_schema(ncols=5)
+ faker.make_records(10)
+ row_factory = getattr(rows, row_factory)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn, conn.transaction(
+ force_rollback=True
+ ):
+ async with psycopg.AsyncClientCursor(conn, row_factory=row_factory) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+ async with faker.find_insert_problem_async(conn):
+ await cur.executemany(faker.insert_stmt, faker.records)
+ await cur.execute(faker.select_stmt)
+
+ if fetch == "one":
+ while True:
+ tmp = await cur.fetchone()
+ if tmp is None:
+ break
+ elif fetch == "many":
+ while True:
+ tmp = await cur.fetchmany(3)
+ if not tmp:
+ break
+ elif fetch == "all":
+ await cur.fetchall()
+ elif fetch == "iter":
+ async for rec in cur:
+ pass
+
+ n = []
+ gc_collect()
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ ("select 'hello'", (), "select 'hello'"),
+ ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"),
+ ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"),
+ ("select %%", (), "select %%"),
+ ("select %%, %s", (["a"],), "select %, 'a'"),
+ ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"),
+ ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"),
+ ],
+)
+async def test_mogrify(aconn, query, params, want):
+ cur = aconn.cursor()
+ got = cur.mogrify(query, *params)
+ assert got == want
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+async def test_mogrify_encoding(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ q = aconn.cursor().mogrify("select %(s)s", {"s": "\u20ac"})
+ assert q == "select '\u20ac'"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+async def test_mogrify_badenc(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ with pytest.raises(UnicodeEncodeError):
+ aconn.cursor().mogrify("select %(s)s", {"s": "\u20ac"})
+
+
+@pytest.mark.pipeline
+async def test_message_0x33(aconn):
+ # https://github.com/psycopg/psycopg/issues/314
+ notices = []
+ aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ cur = await aconn.execute("select 'test'")
+ assert (await cur.fetchone()) == ("test",)
+
+ assert not notices
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
new file mode 100644
index 0000000..eec24f1
--- /dev/null
+++ b/tests/test_concurrency.py
@@ -0,0 +1,327 @@
+"""
+Tests dealing with concurrency issues.
+"""
+
+import os
+import sys
+import time
+import queue
+import signal
+import threading
+import multiprocessing
+import subprocess as sp
+from typing import List
+
+import pytest
+
+import psycopg
+from psycopg import errors as e
+
+
+@pytest.mark.slow
+def test_concurrent_execution(conn_cls, dsn):
+ def worker():
+ cnn = conn_cls.connect(dsn)
+ cur = cnn.cursor()
+ cur.execute("select pg_sleep(0.5)")
+ cur.close()
+ cnn.close()
+
+ t1 = threading.Thread(target=worker)
+ t2 = threading.Thread(target=worker)
+ t0 = time.time()
+ t1.start()
+ t2.start()
+ t1.join()
+ t2.join()
+ assert time.time() - t0 < 0.8, "something broken in concurrency"
+
+
+@pytest.mark.slow
+def test_commit_concurrency(conn):
+ # Check the condition reported in psycopg2#103
+ # Because of bad status check, we commit even when a commit is already on
+ # its way. We can detect this condition by the warnings.
+ notices = queue.Queue() # type: ignore[var-annotated]
+ conn.add_notice_handler(lambda diag: notices.put(diag.message_primary))
+ stop = False
+
+ def committer():
+ nonlocal stop
+ while not stop:
+ conn.commit()
+
+ cur = conn.cursor()
+ t1 = threading.Thread(target=committer)
+ t1.start()
+ for i in range(1000):
+ cur.execute("select %s;", (i,))
+ conn.commit()
+
+ # Stop the committer thread
+ stop = True
+ t1.join()
+
+ assert notices.empty(), "%d notices raised" % notices.qsize()
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+def test_multiprocess_close(dsn, tmpdir):
+ # Check the problem reported in psycopg2#829
+ # Subprocess gcs the copy of the fd after fork so it closes connection.
+ module = f"""\
+import time
+import psycopg
+
+def thread():
+ conn = psycopg.connect({dsn!r})
+ curs = conn.cursor()
+ for i in range(10):
+ curs.execute("select 1")
+ time.sleep(0.1)
+
+def process():
+ time.sleep(0.2)
+"""
+
+ script = """\
+import time
+import threading
+import multiprocessing
+import mptest
+
+t = threading.Thread(target=mptest.thread, name='mythread')
+t.start()
+time.sleep(0.2)
+multiprocessing.Process(target=mptest.process, name='myprocess').start()
+t.join()
+"""
+
+ with (tmpdir / "mptest.py").open("w") as f:
+ f.write(module)
+ env = dict(os.environ)
+ env["PYTHONPATH"] = str(tmpdir + os.pathsep + env.get("PYTHONPATH", ""))
+ out = sp.check_output(
+ [sys.executable, "-c", script], stderr=sp.STDOUT, env=env
+ ).decode("utf8", "replace")
+ assert out == "", out.strip().splitlines()[-1]
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("notify")
+def test_notifies(conn_cls, conn, dsn):
+ nconn = conn_cls.connect(dsn, autocommit=True)
+ npid = nconn.pgconn.backend_pid
+
+ def notifier():
+ time.sleep(0.25)
+ nconn.cursor().execute("notify foo, '1'")
+ time.sleep(0.25)
+ nconn.cursor().execute("notify foo, '2'")
+ nconn.close()
+
+ conn.autocommit = True
+ conn.cursor().execute("listen foo")
+
+ t0 = time.time()
+ t = threading.Thread(target=notifier)
+ t.start()
+
+ ns = []
+ gen = conn.notifies()
+ for n in gen:
+ ns.append((n, time.time()))
+ if len(ns) >= 2:
+ gen.close()
+
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert isinstance(n, psycopg.Notify)
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+ t.join()
+
+
+def canceller(conn, errors):
+ try:
+ time.sleep(0.5)
+ conn.cancel()
+ except Exception as exc:
+ errors.append(exc)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+def test_cancel(conn):
+ errors: List[Exception] = []
+
+ cur = conn.cursor()
+ t = threading.Thread(target=canceller, args=(conn, errors))
+ t0 = time.time()
+ t.start()
+
+ with pytest.raises(e.QueryCanceled):
+ cur.execute("select pg_sleep(2)")
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ conn.rollback()
+ assert cur.execute("select 1").fetchone()[0] == 1
+
+ t.join()
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+def test_cancel_stream(conn):
+ errors: List[Exception] = []
+
+ cur = conn.cursor()
+ t = threading.Thread(target=canceller, args=(conn, errors))
+ t0 = time.time()
+ t.start()
+
+ with pytest.raises(e.QueryCanceled):
+ for row in cur.stream("select pg_sleep(2)"):
+ pass
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ conn.rollback()
+ assert cur.execute("select 1").fetchone()[0] == 1
+
+ t.join()
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+@pytest.mark.slow
+def test_identify_closure(conn_cls, dsn):
+ def closer():
+ time.sleep(0.2)
+ conn2.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid])
+
+ conn = conn_cls.connect(dsn)
+ conn2 = conn_cls.connect(dsn)
+ try:
+ t = threading.Thread(target=closer)
+ t.start()
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("select pg_sleep(1.0)")
+ t1 = time.time()
+ assert 0.2 < t1 - t0 < 0.4
+ finally:
+ t.join()
+ finally:
+ conn.close()
+ conn2.close()
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+ sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
+)
+@pytest.mark.crdb_skip("cancel")
+def test_ctrl_c(dsn):
+ if sys.platform == "win32":
+ sig = int(signal.CTRL_C_EVENT)
+ # Or pytest will receive the Ctrl-C too
+ creationflags = sp.CREATE_NEW_PROCESS_GROUP
+ else:
+ sig = int(signal.SIGINT)
+ creationflags = 0
+
+ script = f"""\
+import os
+import time
+import psycopg
+from threading import Thread
+
+def tired_of_life():
+ time.sleep(1)
+ os.kill(os.getpid(), {sig!r})
+
+t = Thread(target=tired_of_life, daemon=True)
+t.start()
+
+with psycopg.connect({dsn!r}) as conn:
+ cur = conn.cursor()
+ ctrl_c = False
+ try:
+ cur.execute("select pg_sleep(2)")
+ except KeyboardInterrupt:
+ ctrl_c = True
+
+ assert ctrl_c, "ctrl-c not received"
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ conn.rollback()
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ cur.execute("select 1")
+ assert cur.fetchone() == (1,)
+"""
+ t0 = time.time()
+ proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags)
+ proc.communicate()
+ t = time.time() - t0
+ assert proc.returncode == 0
+ assert 1 < t < 2
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+ multiprocessing.get_all_start_methods()[0] != "fork",
+ reason="problematic behavior only exhibited via fork",
+)
+def test_segfault_on_fork_close(dsn):
+ # https://github.com/psycopg/psycopg/issues/300
+ script = f"""\
+import gc
+import psycopg
+from multiprocessing import Pool
+
+def test(arg):
+ conn1 = psycopg.connect({dsn!r})
+ conn1.close()
+ conn1 = None
+ gc.collect()
+ return 1
+
+if __name__ == '__main__':
+ conn = psycopg.connect({dsn!r})
+ with Pool(2) as p:
+ pool_result = p.map_async(test, [1, 2])
+ pool_result.wait(timeout=5)
+ if pool_result.ready():
+ print(pool_result.get(timeout=1))
+"""
+ env = dict(os.environ)
+ env["PYTHONFAULTHANDLER"] = "1"
+ out = sp.check_output([sys.executable, "-s", "-c", script], env=env)
+ assert out.decode().rstrip() == "[1, 1]"
diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py
new file mode 100644
index 0000000..29b08cf
--- /dev/null
+++ b/tests/test_concurrency_async.py
@@ -0,0 +1,242 @@
+import sys
+import time
+import signal
+import asyncio
+import subprocess as sp
+from asyncio.queues import Queue
+from typing import List, Tuple
+
+import pytest
+
+import psycopg
+from psycopg import errors as e
+from psycopg._compat import create_task
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.slow
+async def test_commit_concurrency(aconn):
+ # Check the condition reported in psycopg2#103
+ # Because of bad status check, we commit even when a commit is already on
+ # its way. We can detect this condition by the warnings.
+ notices = Queue() # type: ignore[var-annotated]
+ aconn.add_notice_handler(lambda diag: notices.put_nowait(diag.message_primary))
+ stop = False
+
+ async def committer():
+ nonlocal stop
+ while not stop:
+ await aconn.commit()
+ await asyncio.sleep(0) # Allow the other worker to work
+
+ async def runner():
+ nonlocal stop
+ cur = aconn.cursor()
+ for i in range(1000):
+ await cur.execute("select %s;", (i,))
+ await aconn.commit()
+
+ # Stop the committer thread
+ stop = True
+
+ await asyncio.gather(committer(), runner())
+ assert notices.empty(), "%d notices raised" % notices.qsize()
+
+
+@pytest.mark.slow
+async def test_concurrent_execution(aconn_cls, dsn):
+ async def worker():
+ cnn = await aconn_cls.connect(dsn)
+ cur = cnn.cursor()
+ await cur.execute("select pg_sleep(0.5)")
+ await cur.close()
+ await cnn.close()
+
+ workers = [worker(), worker()]
+ t0 = time.time()
+ await asyncio.gather(*workers)
+ assert time.time() - t0 < 0.8, "something broken in concurrency"
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("notify")
+async def test_notifies(aconn_cls, aconn, dsn):
+ nconn = await aconn_cls.connect(dsn, autocommit=True)
+ npid = nconn.pgconn.backend_pid
+
+ async def notifier():
+ cur = nconn.cursor()
+ await asyncio.sleep(0.25)
+ await cur.execute("notify foo, '1'")
+ await asyncio.sleep(0.25)
+ await cur.execute("notify foo, '2'")
+ await nconn.close()
+
+ async def receiver():
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("listen foo")
+ gen = aconn.notifies()
+ async for n in gen:
+ ns.append((n, time.time()))
+ if len(ns) >= 2:
+ await gen.aclose()
+
+ ns: List[Tuple[psycopg.Notify, float]] = []
+ t0 = time.time()
+ workers = [notifier(), receiver()]
+ await asyncio.gather(*workers)
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+async def canceller(aconn, errors):
+ try:
+ await asyncio.sleep(0.5)
+ aconn.cancel()
+ except Exception as exc:
+ errors.append(exc)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+async def test_cancel(aconn):
+ async def worker():
+ cur = aconn.cursor()
+ with pytest.raises(e.QueryCanceled):
+ await cur.execute("select pg_sleep(2)")
+
+ errors: List[Exception] = []
+ workers = [worker(), canceller(aconn, errors)]
+
+ t0 = time.time()
+ await asyncio.gather(*workers)
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ await aconn.rollback()
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+async def test_cancel_stream(aconn):
+ async def worker():
+ cur = aconn.cursor()
+ with pytest.raises(e.QueryCanceled):
+ async for row in cur.stream("select pg_sleep(2)"):
+ pass
+
+ errors: List[Exception] = []
+ workers = [worker(), canceller(aconn, errors)]
+
+ t0 = time.time()
+ await asyncio.gather(*workers)
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ await aconn.rollback()
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_identify_closure(aconn_cls, dsn):
+ async def closer():
+ await asyncio.sleep(0.2)
+ await conn2.execute(
+ "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid]
+ )
+
+ aconn = await aconn_cls.connect(dsn)
+ conn2 = await aconn_cls.connect(dsn)
+ try:
+ t = create_task(closer())
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.execute("select pg_sleep(1.0)")
+ t1 = time.time()
+ assert 0.2 < t1 - t0 < 0.4
+ finally:
+ await asyncio.gather(t)
+ finally:
+ await aconn.close()
+ await conn2.close()
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+ sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
+)
+@pytest.mark.crdb_skip("cancel")
+async def test_ctrl_c(dsn):
+ script = f"""\
+import signal
+import asyncio
+import psycopg
+
+async def main():
+ ctrl_c = False
+ loop = asyncio.get_event_loop()
+ async with await psycopg.AsyncConnection.connect({dsn!r}) as conn:
+ loop.add_signal_handler(signal.SIGINT, conn.cancel)
+ cur = conn.cursor()
+ try:
+ await cur.execute("select pg_sleep(2)")
+ except psycopg.errors.QueryCanceled:
+ ctrl_c = True
+
+ assert ctrl_c, "ctrl-c not received"
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ await conn.rollback()
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ await cur.execute("select 1")
+ assert (await cur.fetchone()) == (1,)
+
+asyncio.run(main())
+"""
+ if sys.platform == "win32":
+ creationflags = sp.CREATE_NEW_PROCESS_GROUP
+ sig = signal.CTRL_C_EVENT
+ else:
+ creationflags = 0
+ sig = signal.SIGINT
+
+ proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags)
+ with pytest.raises(sp.TimeoutExpired):
+ outs, errs = proc.communicate(timeout=1)
+
+ proc.send_signal(sig)
+ proc.communicate()
+ assert proc.returncode == 0
diff --git a/tests/test_connection.py b/tests/test_connection.py
new file mode 100644
index 0000000..57c6c78
--- /dev/null
+++ b/tests/test_connection.py
@@ -0,0 +1,790 @@
+import time
+import pytest
+import logging
+import weakref
+from typing import Any, List
+from dataclasses import dataclass
+
+import psycopg
+from psycopg import Notify, errors as e
+from psycopg.rows import tuple_row
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
+
+from .utils import gc_collect
+from .test_cursor import my_row_factory
+from .test_adapt import make_bin_dumper, make_dumper
+
+
+def test_connect(conn_cls, dsn):
+ conn = conn_cls.connect(dsn)
+ assert not conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.OK
+ conn.close()
+
+
+def test_connect_str_subclass(conn_cls, dsn):
+ class MyString(str):
+ pass
+
+ conn = conn_cls.connect(MyString(dsn))
+ assert not conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.OK
+ conn.close()
+
+
+def test_connect_bad(conn_cls):
+ with pytest.raises(psycopg.OperationalError):
+ conn_cls.connect("dbname=nosuchdb")
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_connect_timeout(conn_cls, deaf_port):
+ t0 = time.time()
+ with pytest.raises(psycopg.OperationalError, match="timeout expired"):
+ conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+ elapsed = time.time() - t0
+ assert elapsed == pytest.approx(1.0, abs=0.05)
+
+
+def test_close(conn):
+ assert not conn.closed
+ assert not conn.broken
+
+ cur = conn.cursor()
+
+ conn.close()
+ assert conn.closed
+ assert not conn.broken
+ assert conn.pgconn.status == conn.ConnStatus.BAD
+
+ conn.close()
+ assert conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.BAD
+
+ with pytest.raises(psycopg.OperationalError):
+ cur.execute("select 1")
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_broken(conn):
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid])
+ assert conn.closed
+ assert conn.broken
+ conn.close()
+ assert conn.closed
+ assert conn.broken
+
+
+def test_cursor_closed(conn):
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ with conn.cursor("foo"):
+ pass
+ with pytest.raises(psycopg.OperationalError):
+ conn.cursor()
+
+
+def test_connection_warn_close(conn_cls, dsn, recwarn):
+ conn = conn_cls.connect(dsn)
+ conn.close()
+ del conn
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+ conn = conn_cls.connect(dsn)
+ del conn
+ assert "IDLE" in str(recwarn.pop(ResourceWarning).message)
+
+ conn = conn_cls.connect(dsn)
+ conn.execute("select 1")
+ del conn
+ assert "INTRANS" in str(recwarn.pop(ResourceWarning).message)
+
+ conn = conn_cls.connect(dsn)
+ try:
+ conn.execute("select wat")
+ except Exception:
+ pass
+ del conn
+ assert "INERROR" in str(recwarn.pop(ResourceWarning).message)
+
+ with conn_cls.connect(dsn) as conn:
+ pass
+ del conn
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+@pytest.fixture
+def testctx(svcconn):
+ svcconn.execute("create table if not exists testctx (id int primary key)")
+ svcconn.execute("delete from testctx")
+ return None
+
+
+def test_context_commit(conn_cls, testctx, conn, dsn):
+ with conn:
+ with conn.cursor() as cur:
+ cur.execute("insert into testctx values (42)")
+
+ assert conn.closed
+ assert not conn.broken
+
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor() as cur:
+ cur.execute("select * from testctx")
+ assert cur.fetchall() == [(42,)]
+
+
+def test_context_rollback(conn_cls, testctx, conn, dsn):
+ with pytest.raises(ZeroDivisionError):
+ with conn:
+ with conn.cursor() as cur:
+ cur.execute("insert into testctx values (42)")
+ 1 / 0
+
+ assert conn.closed
+ assert not conn.broken
+
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor() as cur:
+ cur.execute("select * from testctx")
+ assert cur.fetchall() == []
+
+
+def test_context_close(conn):
+ with conn:
+ conn.execute("select 1")
+ conn.close()
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_context_inerror_rollback_no_clobber(conn_cls, conn, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ with conn_cls.connect(dsn) as conn2:
+ conn2.execute("select 1")
+ conn.execute(
+ "select pg_terminate_backend(%s::int)",
+ [conn2.pgconn.backend_pid],
+ )
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.crdb_skip("copy")
+def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ with conn_cls.connect(dsn) as conn:
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
+ assert not conn.pgconn.error_message
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.slow
+def test_weakref(conn_cls, dsn):
+ conn = conn_cls.connect(dsn)
+ w = weakref.ref(conn)
+ conn.close()
+ del conn
+ gc_collect()
+ assert w() is None
+
+
+def test_commit(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+ conn.pgconn.exec_(b"begin")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+ conn.pgconn.exec_(b"insert into foo values (1)")
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ res = conn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.get_value(0, 0) == b"1"
+
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.commit()
+
+
+@pytest.mark.crdb_skip("deferrable")
+def test_commit_error(conn):
+ conn.execute(
+ """
+ drop table if exists selfref;
+ create table selfref (
+ x serial primary key,
+ y int references selfref (x) deferrable initially deferred)
+ """
+ )
+ conn.commit()
+
+ conn.execute("insert into selfref (y) values (-1)")
+ with pytest.raises(e.ForeignKeyViolation):
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+
+def test_rollback(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+ conn.pgconn.exec_(b"begin")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+ conn.pgconn.exec_(b"insert into foo values (1)")
+ conn.rollback()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ res = conn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.ntuples == 0
+
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.rollback()
+
+
+def test_auto_transaction(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = conn.cursor()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+ cur.execute("insert into foo values (1)")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ assert cur.execute("select * from foo").fetchone() == (1,)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_auto_transaction_fail(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = conn.cursor()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+ cur.execute("insert into foo values (1)")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("meh")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+ with pytest.raises(psycopg.errors.InFailedSqlTransaction):
+ cur.execute("select 1")
+
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ assert cur.execute("select * from foo").fetchone() is None
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_autocommit(conn):
+ assert conn.autocommit is False
+ conn.autocommit = True
+ assert conn.autocommit
+ cur = conn.cursor()
+ assert cur.execute("select 1").fetchone() == (1,)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+ conn.autocommit = ""
+ assert conn.autocommit is False # type: ignore[comparison-overlap]
+ conn.autocommit = "yeah"
+ assert conn.autocommit is True
+
+
+def test_autocommit_connect(conn_cls, dsn):
+ conn = conn_cls.connect(dsn, autocommit=True)
+ assert conn.autocommit
+ conn.close()
+
+
+def test_autocommit_intrans(conn):
+ cur = conn.cursor()
+ assert cur.execute("select 1").fetchone() == (1,)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.autocommit = True
+ assert not conn.autocommit
+
+
+def test_autocommit_inerror(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("meh")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.autocommit = True
+ assert not conn.autocommit
+
+
+def test_autocommit_unknown(conn):
+ conn.close()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.UNKNOWN
+ with pytest.raises(psycopg.OperationalError):
+ conn.autocommit = True
+ assert not conn.autocommit
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, want",
+ [
+ ((), {}, ""),
+ (("",), {}, ""),
+ (("host=foo user=bar",), {}, "host=foo user=bar"),
+ (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+ (
+ ("host=foo port=5432",),
+ {"host": "qux", "user": "joe"},
+ "host=qux user=joe port=5432",
+ ),
+ (("host=foo",), {"user": None}, "host=foo"),
+ ],
+)
+def test_connect_args(conn_cls, monkeypatch, pgconn, args, kwargs, want):
+ the_conninfo: str
+
+ def fake_connect(conninfo):
+ nonlocal the_conninfo
+ the_conninfo = conninfo
+ return pgconn
+ yield
+
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ conn = conn_cls.connect(*args, **kwargs)
+ assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ conn.close()
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, exctype",
+ [
+ (("host=foo", "host=bar"), {}, TypeError),
+ (("", ""), {}, TypeError),
+ ((), {"nosuchparam": 42}, psycopg.ProgrammingError),
+ ],
+)
+def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype):
+ def fake_connect(conninfo):
+ return pgconn
+ yield
+
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ with pytest.raises(exctype):
+ conn_cls.connect(*args, **kwargs)
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_broken_connection(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("select pg_terminate_backend(pg_backend_pid())")
+ assert conn.closed
+
+
+@pytest.mark.crdb_skip("do")
+def test_notice_handlers(conn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ messages = []
+ severities = []
+
+ def cb1(diag):
+ messages.append(diag.message_primary)
+
+ def cb2(res):
+ raise Exception("hello from cb2")
+
+ conn.add_notice_handler(cb1)
+ conn.add_notice_handler(cb2)
+ conn.add_notice_handler("the wrong thing")
+ conn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized))
+
+ conn.pgconn.exec_(b"set client_min_messages to notice")
+ cur = conn.cursor()
+ cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql")
+ assert messages == ["hello notice"]
+ assert severities == ["NOTICE"]
+
+ assert len(caplog.records) == 2
+ rec = caplog.records[0]
+ assert rec.levelno == logging.ERROR
+ assert "hello from cb2" in rec.message
+ rec = caplog.records[1]
+ assert rec.levelno == logging.ERROR
+ assert "the wrong thing" in rec.message
+
+ conn.remove_notice_handler(cb1)
+ conn.remove_notice_handler("the wrong thing")
+ cur.execute("do $$begin raise warning 'hello warning'; end$$ language plpgsql")
+ assert len(caplog.records) == 3
+ assert messages == ["hello notice"]
+ assert severities == ["NOTICE", "WARNING"]
+
+ with pytest.raises(ValueError):
+ conn.remove_notice_handler(cb1)
+
+
+@pytest.mark.crdb_skip("notify")
+def test_notify_handlers(conn):
+ nots1 = []
+ nots2 = []
+
+ def cb1(n):
+ nots1.append(n)
+
+ conn.add_notify_handler(cb1)
+ conn.add_notify_handler(lambda n: nots2.append(n))
+
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("listen foo")
+ cur.execute("notify foo, 'n1'")
+
+ assert len(nots1) == 1
+ n = nots1[0]
+ assert n.channel == "foo"
+ assert n.payload == "n1"
+ assert n.pid == conn.pgconn.backend_pid
+
+ assert len(nots2) == 1
+ assert nots2[0] == nots1[0]
+
+ conn.remove_notify_handler(cb1)
+ cur.execute("notify foo, 'n2'")
+
+ assert len(nots1) == 1
+ assert len(nots2) == 2
+ n = nots2[1]
+ assert isinstance(n, Notify)
+ assert n.channel == "foo"
+ assert n.payload == "n2"
+ assert n.pid == conn.pgconn.backend_pid
+ assert hash(n)
+
+ with pytest.raises(ValueError):
+ conn.remove_notify_handler(cb1)
+
+
+def test_execute(conn):
+ cur = conn.execute("select %s, %s", [10, 20])
+ assert cur.fetchone() == (10, 20)
+ assert cur.format == 0
+ assert cur.pgresult.fformat(0) == 0
+
+ cur = conn.execute("select %(a)s, %(b)s", {"a": 11, "b": 21})
+ assert cur.fetchone() == (11, 21)
+
+ cur = conn.execute("select 12, 22")
+ assert cur.fetchone() == (12, 22)
+
+
+def test_execute_binary(conn):
+ cur = conn.execute("select %s, %s", [10, 20], binary=True)
+ assert cur.fetchone() == (10, 20)
+ assert cur.format == 1
+ assert cur.pgresult.fformat(0) == 1
+
+
+def test_row_factory(conn_cls, dsn):
+ defaultconn = conn_cls.connect(dsn)
+ assert defaultconn.row_factory is tuple_row
+ defaultconn.close()
+
+ conn = conn_cls.connect(dsn, row_factory=my_row_factory)
+ assert conn.row_factory is my_row_factory
+
+ cur = conn.execute("select 'a' as ve")
+ assert cur.fetchone() == ["Ave"]
+
+ with conn.cursor(row_factory=lambda c: lambda t: set(t)) as cur1:
+ cur1.execute("select 1, 1, 2")
+ assert cur1.fetchall() == [{1, 2}]
+
+ with conn.cursor(row_factory=tuple_row) as cur2:
+ cur2.execute("select 1, 1, 2")
+ assert cur2.fetchall() == [(1, 1, 2)]
+
+ # TODO: maybe fix something to get rid of 'type: ignore' below.
+ conn.row_factory = tuple_row
+ cur3 = conn.execute("select 'vale'")
+ r = cur3.fetchone()
+ assert r and r == ("vale",)
+ conn.close()
+
+
+def test_str(conn):
+ assert "[IDLE]" in str(conn)
+ conn.close()
+ assert "[BAD]" in str(conn)
+
+
+def test_fileno(conn):
+ assert conn.fileno() == conn.pgconn.socket
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.fileno()
+
+
+def test_cursor_factory(conn):
+ assert conn.cursor_factory is psycopg.Cursor
+
+ class MyCursor(psycopg.Cursor[psycopg.rows.Row]):
+ pass
+
+ conn.cursor_factory = MyCursor
+ with conn.cursor() as cur:
+ assert isinstance(cur, MyCursor)
+
+ with conn.execute("select 1") as cur:
+ assert isinstance(cur, MyCursor)
+
+
+def test_cursor_factory_connect(conn_cls, dsn):
+ class MyCursor(psycopg.Cursor[psycopg.rows.Row]):
+ pass
+
+ with conn_cls.connect(dsn, cursor_factory=MyCursor) as conn:
+ assert conn.cursor_factory is MyCursor
+ cur = conn.cursor()
+ assert type(cur) is MyCursor
+
+
+def test_server_cursor_factory(conn):
+ assert conn.server_cursor_factory is psycopg.ServerCursor
+
+ class MyServerCursor(psycopg.ServerCursor[psycopg.rows.Row]):
+ pass
+
+ conn.server_cursor_factory = MyServerCursor
+ with conn.cursor(name="n") as cur:
+ assert isinstance(cur, MyServerCursor)
+
+
+@dataclass
+class ParamDef:
+ name: str
+ guc: str
+ values: List[Any]
+
+
+param_isolation = ParamDef(
+ name="isolation_level",
+ guc="isolation",
+ values=list(psycopg.IsolationLevel),
+)
+param_read_only = ParamDef(
+ name="read_only",
+ guc="read_only",
+ values=[True, False],
+)
+param_deferrable = ParamDef(
+ name="deferrable",
+ guc="deferrable",
+ values=[True, False],
+)
+
+# Map Python values to Postgres values for the tx_params possible values
+tx_values_map = {
+ v.name.lower().replace("_", " "): v.value for v in psycopg.IsolationLevel
+}
+tx_values_map["on"] = True
+tx_values_map["off"] = False
+
+
+tx_params = [
+ param_isolation,
+ param_read_only,
+ pytest.param(param_deferrable, marks=pytest.mark.crdb_skip("deferrable")),
+]
+tx_params_isolation = [
+ pytest.param(
+ param_isolation,
+ id="isolation_level",
+ marks=pytest.mark.crdb("skip", reason="transaction isolation"),
+ ),
+ pytest.param(
+ param_read_only, id="read_only", marks=pytest.mark.crdb_skip("begin_read_only")
+ ),
+ pytest.param(
+ param_deferrable, id="deferrable", marks=pytest.mark.crdb_skip("deferrable")
+ ),
+]
+
+
+@pytest.mark.parametrize("param", tx_params)
+def test_transaction_param_default(conn, param):
+ assert getattr(conn, param.name) is None
+ current, default = conn.execute(
+ "select current_setting(%s), current_setting(%s)",
+ [f"transaction_{param.guc}", f"default_transaction_{param.guc}"],
+ ).fetchone()
+ assert current == default
+
+
+@pytest.mark.parametrize("autocommit", [True, False])
+@pytest.mark.parametrize("param", tx_params_isolation)
+def test_set_transaction_param_implicit(conn, param, autocommit):
+ conn.autocommit = autocommit
+ for value in param.values:
+ setattr(conn, param.name, value)
+ pgval, default = conn.execute(
+ "select current_setting(%s), current_setting(%s)",
+ [f"transaction_{param.guc}", f"default_transaction_{param.guc}"],
+ ).fetchone()
+ if autocommit:
+ assert pgval == default
+ else:
+ assert tx_values_map[pgval] == value
+ conn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [True, False])
+@pytest.mark.parametrize("param", tx_params_isolation)
+def test_set_transaction_param_block(conn, param, autocommit):
+ conn.autocommit = autocommit
+ for value in param.values:
+ setattr(conn, param.name, value)
+ with conn.transaction():
+ pgval = conn.execute(
+ "select current_setting(%s)", [f"transaction_{param.guc}"]
+ ).fetchone()[0]
+ assert tx_values_map[pgval] == value
+
+
+@pytest.mark.parametrize("param", tx_params)
+def test_set_transaction_param_not_intrans_implicit(conn, param):
+ conn.execute("select 1")
+ with pytest.raises(psycopg.ProgrammingError):
+ setattr(conn, param.name, param.values[0])
+
+
+@pytest.mark.parametrize("param", tx_params)
+def test_set_transaction_param_not_intrans_block(conn, param):
+ with conn.transaction():
+ with pytest.raises(psycopg.ProgrammingError):
+ setattr(conn, param.name, param.values[0])
+
+
+@pytest.mark.parametrize("param", tx_params)
+def test_set_transaction_param_not_intrans_external(conn, param):
+ conn.autocommit = True
+ conn.execute("begin")
+ with pytest.raises(psycopg.ProgrammingError):
+ setattr(conn, param.name, param.values[0])
+
+
+@pytest.mark.crdb("skip", reason="transaction isolation")
+def test_set_transaction_param_all(conn):
+ params: List[Any] = tx_params[:]
+ params[2] = params[2].values[0]
+
+ for param in params:
+ value = param.values[0]
+ setattr(conn, param.name, value)
+
+ for param in params:
+ pgval = conn.execute(
+ "select current_setting(%s)", [f"transaction_{param.guc}"]
+ ).fetchone()[0]
+ assert tx_values_map[pgval] == value
+
+
+def test_set_transaction_param_strange(conn):
+ for val in ("asdf", 0, 5):
+ with pytest.raises(ValueError):
+ conn.isolation_level = val
+
+ conn.isolation_level = psycopg.IsolationLevel.SERIALIZABLE.value
+ assert conn.isolation_level is psycopg.IsolationLevel.SERIALIZABLE
+
+ conn.read_only = 1
+ assert conn.read_only is True
+
+ conn.deferrable = 0
+ assert conn.deferrable is False
+
+
+conninfo_params_timeout = [
+ (
+ "",
+ {"dbname": "mydb", "connect_timeout": None},
+ ({"dbname": "mydb"}, None),
+ ),
+ (
+ "",
+ {"dbname": "mydb", "connect_timeout": 1},
+ ({"dbname": "mydb", "connect_timeout": "1"}, 1),
+ ),
+ (
+ "dbname=postgres",
+ {},
+ ({"dbname": "postgres"}, None),
+ ),
+ (
+ "dbname=postgres connect_timeout=2",
+ {},
+ ({"dbname": "postgres", "connect_timeout": "2"}, 2),
+ ),
+ (
+ "postgresql:///postgres?connect_timeout=2",
+ {"connect_timeout": 10},
+ ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+ ),
+]
+
+
+@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
+def test_get_connection_params(conn_cls, dsn, kwargs, exp):
+ params = conn_cls._get_connection_params(dsn, **kwargs)
+ conninfo = make_conninfo(**params)
+ assert conninfo_to_dict(conninfo) == exp[0]
+ assert params.get("connect_timeout") == exp[1]
+
+
+def test_connect_context(conn_cls, dsn):
+ ctx = psycopg.adapt.AdaptersMap(psycopg.adapters)
+ ctx.register_dumper(str, make_bin_dumper("b"))
+ ctx.register_dumper(str, make_dumper("t"))
+
+ conn = conn_cls.connect(dsn, context=ctx)
+
+ cur = conn.execute("select %s", ["hello"])
+ assert cur.fetchone()[0] == "hellot"
+ cur = conn.execute("select %b", ["hello"])
+ assert cur.fetchone()[0] == "hellob"
+ conn.close()
+
+
+def test_connect_context_copy(conn_cls, dsn, conn):
+ conn.adapters.register_dumper(str, make_bin_dumper("b"))
+ conn.adapters.register_dumper(str, make_dumper("t"))
+
+ conn2 = conn_cls.connect(dsn, context=conn)
+
+ cur = conn2.execute("select %s", ["hello"])
+ assert cur.fetchone()[0] == "hellot"
+ cur = conn2.execute("select %b", ["hello"])
+ assert cur.fetchone()[0] == "hellob"
+ conn2.close()
+
+
+def test_cancel_closed(conn):
+ conn.close()
+ conn.cancel()
diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py
new file mode 100644
index 0000000..1288a6c
--- /dev/null
+++ b/tests/test_connection_async.py
@@ -0,0 +1,751 @@
+import time
+import pytest
+import logging
+import weakref
+from typing import List, Any
+
+import psycopg
+from psycopg import Notify, errors as e
+from psycopg.rows import tuple_row
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
+
+from .utils import gc_collect
+from .test_cursor import my_row_factory
+from .test_connection import tx_params, tx_params_isolation, tx_values_map
+from .test_connection import conninfo_params_timeout
+from .test_connection import testctx # noqa: F401 # fixture
+from .test_adapt import make_bin_dumper, make_dumper
+from .test_conninfo import fake_resolve # noqa: F401
+
+pytestmark = pytest.mark.asyncio
+
+
+async def test_connect(aconn_cls, dsn):
+ conn = await aconn_cls.connect(dsn)
+ assert not conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.OK
+ await conn.close()
+
+
+async def test_connect_bad(aconn_cls):
+ with pytest.raises(psycopg.OperationalError):
+ await aconn_cls.connect("dbname=nosuchdb")
+
+
+async def test_connect_str_subclass(aconn_cls, dsn):
+ class MyString(str):
+ pass
+
+ conn = await aconn_cls.connect(MyString(dsn))
+ assert not conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.OK
+ await conn.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_connect_timeout(aconn_cls, deaf_port):
+ t0 = time.time()
+ with pytest.raises(psycopg.OperationalError, match="timeout expired"):
+ await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+ elapsed = time.time() - t0
+ assert elapsed == pytest.approx(1.0, abs=0.05)
+
+
+async def test_close(aconn):
+ assert not aconn.closed
+ assert not aconn.broken
+
+ cur = aconn.cursor()
+
+ await aconn.close()
+ assert aconn.closed
+ assert not aconn.broken
+ assert aconn.pgconn.status == aconn.ConnStatus.BAD
+
+ await aconn.close()
+ assert aconn.closed
+ assert aconn.pgconn.status == aconn.ConnStatus.BAD
+
+ with pytest.raises(psycopg.OperationalError):
+ await cur.execute("select 1")
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_broken(aconn):
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.execute(
+ "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid]
+ )
+ assert aconn.closed
+ assert aconn.broken
+ await aconn.close()
+ assert aconn.closed
+ assert aconn.broken
+
+
+async def test_cursor_closed(aconn):
+ await aconn.close()
+ with pytest.raises(psycopg.OperationalError):
+ async with aconn.cursor("foo"):
+ pass
+ aconn.cursor("foo")
+ with pytest.raises(psycopg.OperationalError):
+ aconn.cursor()
+
+
+async def test_connection_warn_close(aconn_cls, dsn, recwarn):
+ conn = await aconn_cls.connect(dsn)
+ await conn.close()
+ del conn
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+ conn = await aconn_cls.connect(dsn)
+ del conn
+ assert "IDLE" in str(recwarn.pop(ResourceWarning).message)
+
+ conn = await aconn_cls.connect(dsn)
+ await conn.execute("select 1")
+ del conn
+ assert "INTRANS" in str(recwarn.pop(ResourceWarning).message)
+
+ conn = await aconn_cls.connect(dsn)
+ try:
+ await conn.execute("select wat")
+ except Exception:
+ pass
+ del conn
+ assert "INERROR" in str(recwarn.pop(ResourceWarning).message)
+
+ async with await aconn_cls.connect(dsn) as conn:
+ pass
+ del conn
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+@pytest.mark.usefixtures("testctx")
+async def test_context_commit(aconn_cls, aconn, dsn):
+ async with aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute("insert into testctx values (42)")
+
+ assert aconn.closed
+ assert not aconn.broken
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute("select * from testctx")
+ assert await cur.fetchall() == [(42,)]
+
+
+@pytest.mark.usefixtures("testctx")
+async def test_context_rollback(aconn_cls, aconn, dsn):
+ with pytest.raises(ZeroDivisionError):
+ async with aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute("insert into testctx values (42)")
+ 1 / 0
+
+ assert aconn.closed
+ assert not aconn.broken
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute("select * from testctx")
+ assert await cur.fetchall() == []
+
+
+async def test_context_close(aconn):
+ async with aconn:
+ await aconn.execute("select 1")
+ await aconn.close()
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_context_inerror_rollback_no_clobber(aconn_cls, conn, dsn, caplog):
+ with pytest.raises(ZeroDivisionError):
+ async with await aconn_cls.connect(dsn) as conn2:
+ await conn2.execute("select 1")
+ conn.execute(
+ "select pg_terminate_backend(%s::int)",
+ [conn2.pgconn.backend_pid],
+ )
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.crdb_skip("copy")
+async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ async with await aconn_cls.connect(dsn) as conn:
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
+ assert not conn.pgconn.error_message
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.slow
+async def test_weakref(aconn_cls, dsn):
+ conn = await aconn_cls.connect(dsn)
+ w = weakref.ref(conn)
+ await conn.close()
+ del conn
+ gc_collect()
+ assert w() is None
+
+
+async def test_commit(aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+ aconn.pgconn.exec_(b"begin")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+ aconn.pgconn.exec_(b"insert into foo values (1)")
+ await aconn.commit()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ res = aconn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.get_value(0, 0) == b"1"
+
+ await aconn.close()
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.commit()
+
+
+@pytest.mark.crdb_skip("deferrable")
+async def test_commit_error(aconn):
+ await aconn.execute(
+ """
+ drop table if exists selfref;
+ create table selfref (
+ x serial primary key,
+ y int references selfref (x) deferrable initially deferred)
+ """
+ )
+ await aconn.commit()
+
+ await aconn.execute("insert into selfref (y) values (-1)")
+ with pytest.raises(e.ForeignKeyViolation):
+ await aconn.commit()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ cur = await aconn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+
+async def test_rollback(aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+ aconn.pgconn.exec_(b"begin")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+ aconn.pgconn.exec_(b"insert into foo values (1)")
+ await aconn.rollback()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ res = aconn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.ntuples == 0
+
+ await aconn.close()
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.rollback()
+
+
+async def test_auto_transaction(aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = aconn.cursor()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+ await cur.execute("insert into foo values (1)")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+ await aconn.commit()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ await cur.execute("select * from foo")
+ assert await cur.fetchone() == (1,)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_auto_transaction_fail(aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = aconn.cursor()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+ await cur.execute("insert into foo values (1)")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("meh")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+ await aconn.commit()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ await cur.execute("select * from foo")
+ assert await cur.fetchone() is None
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_autocommit(aconn):
+ assert aconn.autocommit is False
+ with pytest.raises(AttributeError):
+ aconn.autocommit = True
+ assert not aconn.autocommit
+
+ await aconn.set_autocommit(True)
+ assert aconn.autocommit
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+ await aconn.set_autocommit("")
+ assert aconn.autocommit is False
+ await aconn.set_autocommit("yeah")
+ assert aconn.autocommit is True
+
+
+async def test_autocommit_connect(aconn_cls, dsn):
+ aconn = await aconn_cls.connect(dsn, autocommit=True)
+ assert aconn.autocommit
+ await aconn.close()
+
+
+async def test_autocommit_intrans(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+ with pytest.raises(psycopg.ProgrammingError):
+ await aconn.set_autocommit(True)
+ assert not aconn.autocommit
+
+
+async def test_autocommit_inerror(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("meh")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+ with pytest.raises(psycopg.ProgrammingError):
+ await aconn.set_autocommit(True)
+ assert not aconn.autocommit
+
+
+async def test_autocommit_unknown(aconn):
+ await aconn.close()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.UNKNOWN
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.set_autocommit(True)
+ assert not aconn.autocommit
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, want",
+ [
+ ((), {}, ""),
+ (("",), {}, ""),
+ (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
+ (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
+ (
+ ("dbname=foo port=5432",),
+ {"dbname": "qux", "user": "joe"},
+ "dbname=qux user=joe port=5432",
+ ),
+ (("dbname=foo",), {"user": None}, "dbname=foo"),
+ ],
+)
+async def test_connect_args(
+ aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
+):
+ the_conninfo: str
+
+ def fake_connect(conninfo):
+ nonlocal the_conninfo
+ the_conninfo = conninfo
+ return pgconn
+ yield
+
+ setpgenv({})
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ conn = await aconn_cls.connect(*args, **kwargs)
+ assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ await conn.close()
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, exctype",
+ [
+ (("host=foo", "host=bar"), {}, TypeError),
+ (("", ""), {}, TypeError),
+ ((), {"nosuchparam": 42}, psycopg.ProgrammingError),
+ ],
+)
+async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exctype):
+ def fake_connect(conninfo):
+ return pgconn
+ yield
+
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ with pytest.raises(exctype):
+ await aconn_cls.connect(*args, **kwargs)
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_broken_connection(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("select pg_terminate_backend(pg_backend_pid())")
+ assert aconn.closed
+
+
+@pytest.mark.crdb_skip("do")
+async def test_notice_handlers(aconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ messages = []
+ severities = []
+
+ def cb1(diag):
+ messages.append(diag.message_primary)
+
+ def cb2(res):
+ raise Exception("hello from cb2")
+
+ aconn.add_notice_handler(cb1)
+ aconn.add_notice_handler(cb2)
+ aconn.add_notice_handler("the wrong thing")
+ aconn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized))
+
+ aconn.pgconn.exec_(b"set client_min_messages to notice")
+ cur = aconn.cursor()
+ await cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql")
+ assert messages == ["hello notice"]
+ assert severities == ["NOTICE"]
+
+ assert len(caplog.records) == 2
+ rec = caplog.records[0]
+ assert rec.levelno == logging.ERROR
+ assert "hello from cb2" in rec.message
+ rec = caplog.records[1]
+ assert rec.levelno == logging.ERROR
+ assert "the wrong thing" in rec.message
+
+ aconn.remove_notice_handler(cb1)
+ aconn.remove_notice_handler("the wrong thing")
+ await cur.execute(
+ "do $$begin raise warning 'hello warning'; end$$ language plpgsql"
+ )
+ assert len(caplog.records) == 3
+ assert messages == ["hello notice"]
+ assert severities == ["NOTICE", "WARNING"]
+
+ with pytest.raises(ValueError):
+ aconn.remove_notice_handler(cb1)
+
+
+@pytest.mark.crdb_skip("notify")
+async def test_notify_handlers(aconn):
+ nots1 = []
+ nots2 = []
+
+ def cb1(n):
+ nots1.append(n)
+
+ aconn.add_notify_handler(cb1)
+ aconn.add_notify_handler(lambda n: nots2.append(n))
+
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("listen foo")
+ await cur.execute("notify foo, 'n1'")
+
+ assert len(nots1) == 1
+ n = nots1[0]
+ assert n.channel == "foo"
+ assert n.payload == "n1"
+ assert n.pid == aconn.pgconn.backend_pid
+
+ assert len(nots2) == 1
+ assert nots2[0] == nots1[0]
+
+ aconn.remove_notify_handler(cb1)
+ await cur.execute("notify foo, 'n2'")
+
+ assert len(nots1) == 1
+ assert len(nots2) == 2
+ n = nots2[1]
+ assert isinstance(n, Notify)
+ assert n.channel == "foo"
+ assert n.payload == "n2"
+ assert n.pid == aconn.pgconn.backend_pid
+
+ with pytest.raises(ValueError):
+ aconn.remove_notify_handler(cb1)
+
+
+async def test_execute(aconn):
+ cur = await aconn.execute("select %s, %s", [10, 20])
+ assert await cur.fetchone() == (10, 20)
+ assert cur.format == 0
+ assert cur.pgresult.fformat(0) == 0
+
+ cur = await aconn.execute("select %(a)s, %(b)s", {"a": 11, "b": 21})
+ assert await cur.fetchone() == (11, 21)
+
+ cur = await aconn.execute("select 12, 22")
+ assert await cur.fetchone() == (12, 22)
+
+
+async def test_execute_binary(aconn):
+ cur = await aconn.execute("select %s, %s", [10, 20], binary=True)
+ assert await cur.fetchone() == (10, 20)
+ assert cur.format == 1
+ assert cur.pgresult.fformat(0) == 1
+
+
+async def test_row_factory(aconn_cls, dsn):
+ defaultconn = await aconn_cls.connect(dsn)
+ assert defaultconn.row_factory is tuple_row
+ await defaultconn.close()
+
+ conn = await aconn_cls.connect(dsn, row_factory=my_row_factory)
+ assert conn.row_factory is my_row_factory
+
+ cur = await conn.execute("select 'a' as ve")
+ assert await cur.fetchone() == ["Ave"]
+
+ async with conn.cursor(row_factory=lambda c: lambda t: set(t)) as cur1:
+ await cur1.execute("select 1, 1, 2")
+ assert await cur1.fetchall() == [{1, 2}]
+
+ async with conn.cursor(row_factory=tuple_row) as cur2:
+ await cur2.execute("select 1, 1, 2")
+ assert await cur2.fetchall() == [(1, 1, 2)]
+
+ # TODO: maybe fix something to get rid of 'type: ignore' below.
+ conn.row_factory = tuple_row
+ cur3 = await conn.execute("select 'vale'")
+ r = await cur3.fetchone()
+ assert r and r == ("vale",)
+ await conn.close()
+
+
+async def test_str(aconn):
+ assert "[IDLE]" in str(aconn)
+ await aconn.close()
+ assert "[BAD]" in str(aconn)
+
+
+async def test_fileno(aconn):
+ assert aconn.fileno() == aconn.pgconn.socket
+ await aconn.close()
+ with pytest.raises(psycopg.OperationalError):
+ aconn.fileno()
+
+
+async def test_cursor_factory(aconn):
+ assert aconn.cursor_factory is psycopg.AsyncCursor
+
+ class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]):
+ pass
+
+ aconn.cursor_factory = MyCursor
+ async with aconn.cursor() as cur:
+ assert isinstance(cur, MyCursor)
+
+ async with (await aconn.execute("select 1")) as cur:
+ assert isinstance(cur, MyCursor)
+
+
+async def test_cursor_factory_connect(aconn_cls, dsn):
+ class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]):
+ pass
+
+ async with await aconn_cls.connect(dsn, cursor_factory=MyCursor) as conn:
+ assert conn.cursor_factory is MyCursor
+ cur = conn.cursor()
+ assert type(cur) is MyCursor
+
+
+async def test_server_cursor_factory(aconn):
+ assert aconn.server_cursor_factory is psycopg.AsyncServerCursor
+
+ class MyServerCursor(psycopg.AsyncServerCursor[psycopg.rows.Row]):
+ pass
+
+ aconn.server_cursor_factory = MyServerCursor
+ async with aconn.cursor(name="n") as cur:
+ assert isinstance(cur, MyServerCursor)
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_transaction_param_default(aconn, param):
+ assert getattr(aconn, param.name) is None
+ cur = await aconn.execute(
+ "select current_setting(%s), current_setting(%s)",
+ [f"transaction_{param.guc}", f"default_transaction_{param.guc}"],
+ )
+ current, default = await cur.fetchone()
+ assert current == default
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_transaction_param_readonly_property(aconn, param):
+ with pytest.raises(AttributeError):
+ setattr(aconn, param.name, None)
+
+
+@pytest.mark.parametrize("autocommit", [True, False])
+@pytest.mark.parametrize("param", tx_params_isolation)
+async def test_set_transaction_param_implicit(aconn, param, autocommit):
+ await aconn.set_autocommit(autocommit)
+ for value in param.values:
+ await getattr(aconn, f"set_{param.name}")(value)
+ cur = await aconn.execute(
+ "select current_setting(%s), current_setting(%s)",
+ [f"transaction_{param.guc}", f"default_transaction_{param.guc}"],
+ )
+ pgval, default = await cur.fetchone()
+ if autocommit:
+ assert pgval == default
+ else:
+ assert tx_values_map[pgval] == value
+ await aconn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [True, False])
+@pytest.mark.parametrize("param", tx_params_isolation)
+async def test_set_transaction_param_block(aconn, param, autocommit):
+ await aconn.set_autocommit(autocommit)
+ for value in param.values:
+ await getattr(aconn, f"set_{param.name}")(value)
+ async with aconn.transaction():
+ cur = await aconn.execute(
+ "select current_setting(%s)", [f"transaction_{param.guc}"]
+ )
+ pgval = (await cur.fetchone())[0]
+ assert tx_values_map[pgval] == value
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_set_transaction_param_not_intrans_implicit(aconn, param):
+ await aconn.execute("select 1")
+ value = param.values[0]
+ with pytest.raises(psycopg.ProgrammingError):
+ await getattr(aconn, f"set_{param.name}")(value)
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_set_transaction_param_not_intrans_block(aconn, param):
+ value = param.values[0]
+ async with aconn.transaction():
+ with pytest.raises(psycopg.ProgrammingError):
+ await getattr(aconn, f"set_{param.name}")(value)
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_set_transaction_param_not_intrans_external(aconn, param):
+ value = param.values[0]
+ await aconn.set_autocommit(True)
+ await aconn.execute("begin")
+ with pytest.raises(psycopg.ProgrammingError):
+ await getattr(aconn, f"set_{param.name}")(value)
+
+
+@pytest.mark.crdb("skip", reason="transaction isolation")
+async def test_set_transaction_param_all(aconn):
+ params: List[Any] = tx_params[:]
+ params[2] = params[2].values[0]
+
+ for param in params:
+ value = param.values[0]
+ await getattr(aconn, f"set_{param.name}")(value)
+
+ for param in params:
+ cur = await aconn.execute(
+ "select current_setting(%s)", [f"transaction_{param.guc}"]
+ )
+ pgval = (await cur.fetchone())[0]
+ assert tx_values_map[pgval] == value
+
+
+async def test_set_transaction_param_strange(aconn):
+ for val in ("asdf", 0, 5):
+ with pytest.raises(ValueError):
+ await aconn.set_isolation_level(val)
+
+ await aconn.set_isolation_level(psycopg.IsolationLevel.SERIALIZABLE.value)
+ assert aconn.isolation_level is psycopg.IsolationLevel.SERIALIZABLE
+
+ await aconn.set_read_only(1)
+ assert aconn.read_only is True
+
+ await aconn.set_deferrable(0)
+ assert aconn.deferrable is False
+
+
+@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
+async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv):
+ setpgenv({})
+ params = await aconn_cls._get_connection_params(dsn, **kwargs)
+ conninfo = make_conninfo(**params)
+ assert conninfo_to_dict(conninfo) == exp[0]
+ assert params["connect_timeout"] == exp[1]
+
+
+async def test_connect_context_adapters(aconn_cls, dsn):
+ ctx = psycopg.adapt.AdaptersMap(psycopg.adapters)
+ ctx.register_dumper(str, make_bin_dumper("b"))
+ ctx.register_dumper(str, make_dumper("t"))
+
+ conn = await aconn_cls.connect(dsn, context=ctx)
+
+ cur = await conn.execute("select %s", ["hello"])
+ assert (await cur.fetchone())[0] == "hellot"
+ cur = await conn.execute("select %b", ["hello"])
+ assert (await cur.fetchone())[0] == "hellob"
+ await conn.close()
+
+
+async def test_connect_context_copy(aconn_cls, dsn, aconn):
+ aconn.adapters.register_dumper(str, make_bin_dumper("b"))
+ aconn.adapters.register_dumper(str, make_dumper("t"))
+
+ aconn2 = await aconn_cls.connect(dsn, context=aconn)
+
+ cur = await aconn2.execute("select %s", ["hello"])
+ assert (await cur.fetchone())[0] == "hellot"
+ cur = await aconn2.execute("select %b", ["hello"])
+ assert (await cur.fetchone())[0] == "hellob"
+ await aconn2.close()
+
+
+async def test_cancel_closed(aconn):
+ await aconn.close()
+ aconn.cancel()
+
+
+async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): # noqa: F811
+ got = []
+
+ def fake_connect_gen(conninfo, **kwargs):
+ got.append(conninfo)
+ 1 / 0
+
+ monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
+
+ with pytest.raises(ZeroDivisionError):
+ await psycopg.AsyncConnection.connect("host=foo.com")
+
+ assert len(got) == 1
+ want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+ assert conninfo_to_dict(got[0]) == want
diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py
new file mode 100644
index 0000000..e2c2c01
--- /dev/null
+++ b/tests/test_conninfo.py
@@ -0,0 +1,450 @@
+import socket
+import asyncio
+import datetime as dt
+
+import pytest
+
+import psycopg
+from psycopg import ProgrammingError
+from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from psycopg.conninfo import resolve_hostaddr_async
+from psycopg._encodings import pg2pyenc
+
+from .fix_crdb import crdb_encoding
+
+snowman = "\u2603"
+
+
+class MyString(str):
+ pass
+
+
+@pytest.mark.parametrize(
+ "conninfo, kwargs, exp",
+ [
+ ("", {}, ""),
+ ("dbname=foo", {}, "dbname=foo"),
+ ("dbname=foo", {"user": "bar"}, "dbname=foo user=bar"),
+ ("dbname=sony", {"password": ""}, "dbname=sony password="),
+ ("dbname=foo", {"dbname": "bar"}, "dbname=bar"),
+ ("user=bar", {"dbname": "foo bar"}, "dbname='foo bar' user=bar"),
+ ("", {"dbname": "foo"}, "dbname=foo"),
+ ("", {"dbname": "foo", "user": None}, "dbname=foo"),
+ ("", {"dbname": "foo", "port": 15432}, "dbname=foo port=15432"),
+ ("", {"dbname": "a'b"}, r"dbname='a\'b'"),
+ (f"dbname={snowman}", {}, f"dbname={snowman}"),
+ ("", {"dbname": snowman}, f"dbname={snowman}"),
+ (
+ "postgresql://host1/test",
+ {"host": "host2"},
+ "dbname=test host=host2",
+ ),
+ (MyString(""), {}, ""),
+ ],
+)
+def test_make_conninfo(conninfo, kwargs, exp):
+ out = make_conninfo(conninfo, **kwargs)
+ assert conninfo_to_dict(out) == conninfo_to_dict(exp)
+
+
+@pytest.mark.parametrize(
+ "conninfo, kwargs",
+ [
+ ("hello", {}),
+ ("dbname=foo bar", {}),
+ ("foo=bar", {}),
+ ("dbname=foo", {"bar": "baz"}),
+ ("postgresql://tester:secret@/test?port=5433=x", {}),
+ (f"{snowman}={snowman}", {}),
+ ],
+)
+def test_make_conninfo_bad(conninfo, kwargs):
+ with pytest.raises(ProgrammingError):
+ make_conninfo(conninfo, **kwargs)
+
+
+@pytest.mark.parametrize(
+ "conninfo, exp",
+ [
+ ("", {}),
+ ("dbname=foo user=bar", {"dbname": "foo", "user": "bar"}),
+ ("dbname=sony password=", {"dbname": "sony", "password": ""}),
+ ("dbname='foo bar'", {"dbname": "foo bar"}),
+ ("dbname='a\"b'", {"dbname": 'a"b'}),
+ (r"dbname='a\'b'", {"dbname": "a'b"}),
+ (r"dbname='a\\b'", {"dbname": r"a\b"}),
+ (f"dbname={snowman}", {"dbname": snowman}),
+ (
+ "postgresql://tester:secret@/test?port=5433",
+ {
+ "user": "tester",
+ "password": "secret",
+ "dbname": "test",
+ "port": "5433",
+ },
+ ),
+ ],
+)
+def test_conninfo_to_dict(conninfo, exp):
+ assert conninfo_to_dict(conninfo) == exp
+
+
+def test_no_munging():
+ dsnin = "dbname=a host=b user=c password=d"
+ dsnout = make_conninfo(dsnin)
+ assert dsnin == dsnout
+
+
+class TestConnectionInfo:
+ @pytest.mark.parametrize(
+ "attr",
+ [("dbname", "db"), "host", "hostaddr", "user", "password", "options"],
+ )
+ def test_attrs(self, conn, attr):
+ if isinstance(attr, tuple):
+ info_attr, pgconn_attr = attr
+ else:
+ info_attr = pgconn_attr = attr
+
+ if info_attr == "hostaddr" and psycopg.pq.version() < 120000:
+ pytest.skip("hostaddr not supported on libpq < 12")
+
+ info_val = getattr(conn.info, info_attr)
+ pgconn_val = getattr(conn.pgconn, pgconn_attr).decode()
+ assert info_val == pgconn_val
+
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ getattr(conn.info, info_attr)
+
+ @pytest.mark.libpq("< 12")
+ def test_hostaddr_not_supported(self, conn):
+ with pytest.raises(psycopg.NotSupportedError):
+ conn.info.hostaddr
+
+ def test_port(self, conn):
+ assert conn.info.port == int(conn.pgconn.port.decode())
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.info.port
+
+ def test_get_params(self, conn, dsn):
+ info = conn.info.get_parameters()
+ for k, v in conninfo_to_dict(dsn).items():
+ if k != "password":
+ assert info.get(k) == v
+ else:
+ assert k not in info
+
+ def test_dsn(self, conn, dsn):
+ dsn = conn.info.dsn
+ assert "password" not in dsn
+ for k, v in conninfo_to_dict(dsn).items():
+ if k != "password":
+ assert f"{k}=" in dsn
+
+ def test_get_params_env(self, conn_cls, dsn, monkeypatch):
+ dsn = conninfo_to_dict(dsn)
+ dsn.pop("application_name", None)
+
+ monkeypatch.delenv("PGAPPNAME", raising=False)
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name" not in conn.info.get_parameters()
+
+ monkeypatch.setenv("PGAPPNAME", "hello test")
+ with conn_cls.connect(**dsn) as conn:
+ assert conn.info.get_parameters()["application_name"] == "hello test"
+
+ def test_dsn_env(self, conn_cls, dsn, monkeypatch):
+ dsn = conninfo_to_dict(dsn)
+ dsn.pop("application_name", None)
+
+ monkeypatch.delenv("PGAPPNAME", raising=False)
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name=" not in conn.info.dsn
+
+ monkeypatch.setenv("PGAPPNAME", "hello test")
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name='hello test'" in conn.info.dsn
+
+ def test_status(self, conn):
+ assert conn.info.status.name == "OK"
+ conn.close()
+ assert conn.info.status.name == "BAD"
+
+ def test_transaction_status(self, conn):
+ assert conn.info.transaction_status.name == "IDLE"
+ conn.close()
+ assert conn.info.transaction_status.name == "UNKNOWN"
+
+ @pytest.mark.pipeline
+ def test_pipeline_status(self, conn):
+ assert not conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "OFF"
+ with conn.pipeline():
+ assert conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "ON"
+
+ @pytest.mark.libpq("< 14")
+ def test_pipeline_status_no_pipeline(self, conn):
+ assert not conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "OFF"
+
+ def test_no_password(self, dsn):
+ dsn2 = make_conninfo(dsn, password="the-pass-word")
+ pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode())
+ info = ConnectionInfo(pgconn)
+ assert info.password == "the-pass-word"
+ assert "password" not in info.get_parameters()
+ assert info.get_parameters()["dbname"] == info.dbname
+
+ def test_dsn_no_password(self, dsn):
+ dsn2 = make_conninfo(dsn, password="the-pass-word")
+ pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode())
+ info = ConnectionInfo(pgconn)
+ assert info.password == "the-pass-word"
+ assert "password" not in info.dsn
+ assert f"dbname={info.dbname}" in info.dsn
+
+ def test_parameter_status(self, conn):
+ assert conn.info.parameter_status("nosuchparam") is None
+ tz = conn.info.parameter_status("TimeZone")
+ assert tz and isinstance(tz, str)
+ assert tz == conn.execute("show timezone").fetchone()[0]
+
+ @pytest.mark.crdb("skip")
+ def test_server_version(self, conn):
+ assert conn.info.server_version == conn.pgconn.server_version
+
+ def test_error_message(self, conn):
+ assert conn.info.error_message == ""
+ with pytest.raises(psycopg.ProgrammingError) as ex:
+ conn.execute("wat")
+
+ assert conn.info.error_message
+ assert str(ex.value) in conn.info.error_message
+ assert ex.value.diag.severity in conn.info.error_message
+
+ conn.close()
+ assert "NULL" in conn.info.error_message
+
+ @pytest.mark.crdb_skip("backend pid")
+ def test_backend_pid(self, conn):
+ assert conn.info.backend_pid
+ assert conn.info.backend_pid == conn.pgconn.backend_pid
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.info.backend_pid
+
+ def test_timezone(self, conn):
+ conn.execute("set timezone to 'Europe/Rome'")
+ tz = conn.info.timezone
+ assert isinstance(tz, dt.tzinfo)
+ offset = tz.utcoffset(dt.datetime(2000, 1, 1))
+ assert offset and offset.total_seconds() == 3600
+ offset = tz.utcoffset(dt.datetime(2000, 7, 1))
+ assert offset and offset.total_seconds() == 7200
+
+ @pytest.mark.crdb("skip", reason="crdb doesn't allow invalid timezones")
+ def test_timezone_warn(self, conn, caplog):
+ conn.execute("set timezone to 'FOOBAR0'")
+ assert len(caplog.records) == 0
+ tz = conn.info.timezone
+ assert tz == dt.timezone.utc
+ assert len(caplog.records) == 1
+ assert "FOOBAR0" in caplog.records[0].message
+
+ conn.info.timezone
+ assert len(caplog.records) == 1
+
+ conn.execute("set timezone to 'FOOBAAR0'")
+ assert len(caplog.records) == 1
+ conn.info.timezone
+ assert len(caplog.records) == 2
+ assert "FOOBAAR0" in caplog.records[1].message
+
+ def test_encoding(self, conn):
+ enc = conn.execute("show client_encoding").fetchone()[0]
+ assert conn.info.encoding == pg2pyenc(enc.encode())
+
+ @pytest.mark.crdb("skip", reason="encoding not normalized")
+ @pytest.mark.parametrize(
+ "enc, out, codec",
+ [
+ ("utf8", "UTF8", "utf-8"),
+ ("utf-8", "UTF8", "utf-8"),
+ ("utf_8", "UTF8", "utf-8"),
+ ("eucjp", "EUC_JP", "euc_jp"),
+ ("euc-jp", "EUC_JP", "euc_jp"),
+ ("latin9", "LATIN9", "iso8859-15"),
+ ],
+ )
+ def test_normalize_encoding(self, conn, enc, out, codec):
+ conn.execute("select set_config('client_encoding', %s, false)", [enc])
+ assert conn.info.parameter_status("client_encoding") == out
+ assert conn.info.encoding == codec
+
+ @pytest.mark.parametrize(
+ "enc, out, codec",
+ [
+ ("utf8", "UTF8", "utf-8"),
+ ("utf-8", "UTF8", "utf-8"),
+ ("utf_8", "UTF8", "utf-8"),
+ crdb_encoding("eucjp", "EUC_JP", "euc_jp"),
+ crdb_encoding("euc-jp", "EUC_JP", "euc_jp"),
+ ],
+ )
+ def test_encoding_env_var(self, conn_cls, dsn, monkeypatch, enc, out, codec):
+ monkeypatch.setenv("PGCLIENTENCODING", enc)
+ with conn_cls.connect(dsn) as conn:
+ clienc = conn.info.parameter_status("client_encoding")
+ assert clienc
+ if conn.info.vendor == "PostgreSQL":
+ assert clienc == out
+ else:
+ assert clienc.replace("-", "").replace("_", "").upper() == out
+ assert conn.info.encoding == codec
+
+ @pytest.mark.crdb_skip("encoding")
+ def test_set_encoding_unsupported(self, conn):
+ cur = conn.cursor()
+ cur.execute("set client_encoding to EUC_TW")
+ with pytest.raises(psycopg.NotSupportedError):
+ cur.execute("select 'x'")
+
+ def test_vendor(self, conn):
+ assert conn.info.vendor
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ ("", "", None),
+ ("host='' user=bar", "host='' user=bar", None),
+ (
+ "host=127.0.0.1 user=bar",
+ "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 user=bar",
+ "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ {"PGHOST": "1.1.1.1,2.2.2.2"},
+ ),
+ (
+ "host=foo.com port=5432",
+ "host=foo.com port=5432",
+ {"PGHOSTADDR": "1.2.3.4"},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_no_resolve(
+ setpgenv, conninfo, want, env, fail_resolve
+):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ params = await resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5432,5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+ None,
+ ),
+ (
+ "host=foo.com,nosuchhost.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com, port=5432,5433",
+ "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+ None,
+ ),
+ (
+ "host=nosuchhost.com,foo.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ {},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
+ params = conninfo_to_dict(conninfo)
+ params = await resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, env",
+ [
+ ("host=bad1.com,bad2.com", None),
+ ("host=foo.com port=1,2", None),
+ ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+ ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.Error):
+ await resolve_hostaddr_async(params)
+
+
+@pytest.fixture
+async def fake_resolve(monkeypatch):
+ fake_hosts = {
+ "localhost": "127.0.0.1",
+ "foo.com": "1.1.1.1",
+ "qux.com": "2.2.2.2",
+ }
+
+ async def fake_getaddrinfo(host, port, **kwargs):
+ assert isinstance(port, int) or (isinstance(port, str) and port.isdigit())
+ try:
+ addr = fake_hosts[host]
+ except KeyError:
+ raise OSError(f"unknown test host: {host}")
+ else:
+ return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))]
+
+ monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
+
+
+@pytest.fixture
+async def fail_resolve(monkeypatch):
+ async def fail_getaddrinfo(host, port, **kwargs):
+ pytest.fail(f"shouldn't try to resolve {host}")
+
+ monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)
diff --git a/tests/test_copy.py b/tests/test_copy.py
new file mode 100644
index 0000000..17cf2fc
--- /dev/null
+++ b/tests/test_copy.py
@@ -0,0 +1,889 @@
+import string
+import struct
+import hashlib
+from io import BytesIO, StringIO
+from random import choice, randrange
+from itertools import cycle
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg import errors as e
+from psycopg.pq import Format
+from psycopg.copy import Copy, LibpqWriter, QueuedLibpqDriver, FileWriter
+from psycopg.adapt import PyFormat
+from psycopg.types import TypeInfo
+from psycopg.types.hstore import register_hstore
+from psycopg.types.numeric import Int4
+
+from .utils import eur, gc_collect, gc_count
+
+pytestmark = pytest.mark.crdb_skip("copy")
+
+sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
+sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
+sample_tabledef = "col1 serial primary key, col2 int, data text"
+
+sample_text = b"""\
+40010\t40020\thello
+40040\t\\N\tworld
+"""
+
+sample_binary_str = """
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
+"""
+
+sample_binary_rows = [
+ bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
+]
+sample_binary = b"".join(sample_binary_rows)
+
+special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_read(conn, format):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ for row in want:
+ got = copy.read()
+ assert got == row
+ assert conn.info.transaction_status == conn.TransactionStatus.ACTIVE
+
+ assert copy.read() == b""
+ assert copy.read() == b""
+
+ assert copy.read() == b""
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_copy_out_iter(conn, format, row_factory):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ rf = getattr(psycopg.rows, row_factory)
+ cur = conn.cursor(row_factory=rf)
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ assert list(copy) == want
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_copy_out_no_result(conn, format, row_factory):
+ rf = getattr(psycopg.rows, row_factory)
+ cur = conn.cursor(row_factory=rf)
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"):
+ with pytest.raises(e.ProgrammingError):
+ cur.fetchone()
+
+
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+ cur = conn.cursor()
+ with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("typetype", ["names", "oids"])
+def test_read_rows(conn, format, typetype):
+ cur = conn.cursor()
+ with cur.copy(
+ f"""copy (
+ select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
+ ) to stdout (format {format.name})"""
+ ) as copy:
+ copy.set_types(["int4", "text", "float8[]"])
+ row = copy.read_row()
+ assert copy.read_row() is None
+
+ assert row == (10, "hello", [0.0, 1.0])
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+def test_rows(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ rows = list(copy.rows())
+
+ assert rows == sample_records
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_set_custom_type(conn, hstore):
+ command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout"""
+ cur = conn.cursor()
+
+ with cur.copy(command) as copy:
+ rows = list(copy.rows())
+
+ assert rows == [('"a"=>"1", "b"=>"2"',)]
+
+ register_hstore(TypeInfo.fetch(conn, "hstore"), cur)
+ with cur.copy(command) as copy:
+ copy.set_types(["hstore"])
+ rows = list(copy.rows())
+
+ assert rows == [({"a": "1", "b": "2"},)]
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_allchars(conn, format):
+ cur = conn.cursor()
+ chars = list(map(chr, range(1, 256))) + [eur]
+ conn.execute("set client_encoding to utf8")
+ rows = []
+ query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+ chars, sql.SQL(format.name)
+ )
+ with cur.copy(query) as copy:
+ copy.set_types(["text"])
+ while True:
+ row = copy.read_row()
+ if not row:
+ break
+ assert len(row) == 1
+ rows.append(row[0])
+
+ assert rows == chars
+
+
+@pytest.mark.parametrize("format", Format)
+def test_read_row_notypes(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ rows = []
+ while True:
+ row = copy.read_row()
+ if not row:
+ break
+ rows.append(row)
+
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("format", Format)
+def test_rows_notypes(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ rows = list(copy.rows())
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_badntypes(conn, format, err):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ copy.set_types([0] * (len(sample_records[0]) + err))
+ with pytest.raises(e.ProgrammingError):
+ copy.read_row()
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_copy_in_buffers(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_buffers_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_bad_result(conn):
+ conn.autocommit = True
+
+ cur = conn.cursor()
+
+ with pytest.raises(e.SyntaxError):
+ with cur.copy("wat"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("select 1"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("reset timezone"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("copy (select 1) to stdout; select 1") as copy:
+ list(copy)
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("select 1; copy (select 1) to stdout"):
+ pass
+
+
+def test_copy_in_str(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text.decode())
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ with cur.copy("copy copy_in from stdin (format binary)") as copy:
+ copy.write(sample_text.decode())
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+def test_copy_big_size_record(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write_row([data])
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+def test_copy_big_size_block(conn, pytype):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write(copy_data)
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.parametrize("format", Format)
+def test_subclass_adapter(conn, format):
+ if format == Format.TEXT:
+ from psycopg.types.string import StrDumper as BaseDumper
+ else:
+ from psycopg.types.string import ( # type: ignore[no-redef]
+ StrBinaryDumper as BaseDumper,
+ )
+
+ class MyStrDumper(BaseDumper):
+ def dump(self, obj):
+ return super().dump(obj) * 2
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy:
+ copy.write_row(("hello",))
+
+ rec = cur.execute("select data from copy_in").fetchone()
+ assert rec[0] == "hellohello"
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_error_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ raise Exception("mannaggiamiseria")
+
+ assert "mannaggiamiseria" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_buffers_with_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_buffers_with_py_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_out_error_with_copy_finished(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
+ copy.read_row()
+ 1 / 0
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_copy_out_error_with_copy_not_finished(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy (select generate_series(1, 1000000)) to stdout") as copy:
+ copy.read_row()
+ 1 / 0
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_out_server_error(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DivisionByZero):
+ with cur.copy(
+ "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+ ) as copy:
+ for block in copy:
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_set_types(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_binary(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+ with cur.copy(
+ f"copy copy_in (col2, data) from stdin (format {format.name})"
+ ) as copy:
+ for row in sample_records:
+ copy.write_row((None, row[2]))
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+def test_copy_in_allchars(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ conn.execute("set client_encoding to utf8")
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ for i in range(1, 256):
+ copy.write_row((i, None, chr(i)))
+ copy.write_row((ord(eur), None, eur))
+
+ data = cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ ).fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+def test_copy_in_format(conn):
+ file = BytesIO()
+ conn.execute("set client_encoding to utf8")
+ cur = conn.cursor()
+ with Copy(cur, writer=FileWriter(file)) as copy:
+ for i in range(1, 256):
+ copy.write_row((i, chr(i)))
+
+ file.seek(0)
+ rows = file.read().split(b"\n")
+ assert not rows[-1]
+ del rows[-1]
+
+ for i, row in enumerate(rows, start=1):
+ fields = row.split(b"\t")
+ assert len(fields) == 2
+ assert int(fields[0].decode()) == i
+ if i in special_chars:
+ assert fields[1].decode() == f"\\{special_chars[i]}"
+ else:
+ assert fields[1].decode() == chr(i)
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_file_writer(conn, format, buffer):
+ file = BytesIO()
+ conn.execute("set client_encoding to utf8")
+ cur = conn.cursor()
+ with Copy(cur, binary=format, writer=FileWriter(file)) as copy:
+ for record in sample_records:
+ copy.write_row(record)
+
+ file.seek(0)
+ want = globals()[buffer]
+ got = file.read()
+ assert got == want
+
+
+@pytest.mark.slow
+def test_copy_from_to(conn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+
+ gen.assert_data()
+
+ f = BytesIO()
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+def test_copy_from_to_bytes(conn, pytype):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(pytype(block.encode()))
+
+ gen.assert_data()
+
+ f = BytesIO()
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+def test_copy_from_insane_size(conn):
+ # Trying to trigger a "would block" error
+ gen = DataGenerator(
+ conn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+ )
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+
+ gen.assert_data()
+
+
+def test_copy_rowcount(conn):
+ gen = DataGenerator(conn, nrecs=3, srec=10)
+ gen.ensure_table()
+
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+ assert cur.rowcount == 3
+
+ gen = DataGenerator(conn, nrecs=2, srec=10, offset=3)
+ with cur.copy("copy copy_in from stdin") as copy:
+ for rec in gen.records():
+ copy.write_row(rec)
+ assert cur.rowcount == 2
+
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ pass
+ assert cur.rowcount == 5
+
+ with pytest.raises(e.BadCopyFileFormat):
+ with cur.copy("copy copy_in (id) from stdin") as copy:
+ for rec in gen.records():
+ copy.write_row(rec)
+ assert cur.rowcount == -1
+
+
+def test_copy_query(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ assert cur._query.query == b"copy (select 1) to stdout"
+ assert not cur._query.params
+ list(copy)
+
+
+def test_cant_reenter(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ list(copy)
+
+ with pytest.raises(TypeError):
+ with copy:
+ list(copy)
+
+
+def test_str(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ assert "[ACTIVE]" in str(copy)
+ list(copy)
+
+ assert "[INTRANS]" in str(copy)
+
+
+def test_description(conn):
+ with conn.cursor() as cur:
+ with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy:
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+ list(copy.rows())
+
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_worker_life(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=QueuedLibpqDriver(cur)
+ ) as copy:
+ assert not copy.writer._worker
+ copy.write(globals()[buffer])
+ assert copy.writer._worker
+
+ assert not copy.writer._worker
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_worker_error_propagated(conn, monkeypatch):
+ def copy_to_broken(pgconn, buffer):
+ raise ZeroDivisionError
+ yield
+
+ monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+ cur = conn.cursor()
+ cur.execute("create temp table wat (a text, b text)")
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy wat from stdin", writer=QueuedLibpqDriver(cur)) as copy:
+ copy.write("a,b")
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_connection_writer(conn, format, buffer):
+ cur = conn.cursor()
+ writer = LibpqWriter(cur)
+
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=writer
+ ) as copy:
+ assert copy.writer is writer
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"])
+def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ stmt = sql.SQL(
+ "copy (select {} from {} order by id) to stdout (format {})"
+ ).format(
+ sql.SQL(", ").join(faker.fields_names),
+ faker.table_name,
+ sql.SQL(fmt.name),
+ )
+
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+
+ if method == "read":
+ while True:
+ tmp = copy.read()
+ if not tmp:
+ break
+ elif method == "iter":
+ list(copy)
+ elif method == "row":
+ while True:
+ tmp = copy.read_row()
+ if tmp is None:
+ break
+ elif method == "rows":
+ list(copy.rows())
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin (format {})").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL(fmt.name),
+ )
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ copy.write_row(row)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("mode", ["row", "block", "binary"])
+def test_copy_table_across(conn_cls, dsn, faker, mode):
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ with conn_cls.connect(dsn) as conn1, conn_cls.connect(dsn) as conn2:
+ faker.table_name = sql.Identifier("copy_src")
+ conn1.execute(faker.drop_stmt)
+ conn1.execute(faker.create_stmt)
+ conn1.cursor().executemany(faker.insert_stmt, faker.records)
+
+ faker.table_name = sql.Identifier("copy_tgt")
+ conn2.execute(faker.drop_stmt)
+ conn2.execute(faker.create_stmt)
+
+ fmt = "(format binary)" if mode == "binary" else ""
+ with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1:
+ with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2:
+ if mode == "row":
+ for row in copy1.rows():
+ copy2.write_row(row)
+ else:
+ for data in copy1:
+ copy2.write(data)
+
+ recs = conn2.execute(faker.select_stmt).fetchall()
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+
+def py_to_raw(item, fmt):
+ """Convert from Python type to the expected result from the db"""
+ if fmt == Format.TEXT:
+ if isinstance(item, int):
+ return str(item)
+ else:
+ if isinstance(item, int):
+ # Assume int4
+ return struct.pack("!i", item)
+ elif isinstance(item, str):
+ return item.encode()
+ return item
+
+
+def ensure_table(cur, tabledef, name="copy_in"):
+ cur.execute(f"drop table if exists {name}")
+ cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+ def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+ self.conn = conn
+ self.nrecs = nrecs
+ self.srec = srec
+ self.offset = offset
+ self.block_size = block_size
+
+ def ensure_table(self):
+ cur = self.conn.cursor()
+ ensure_table(cur, "id integer primary key, data text")
+
+ def records(self):
+ for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+ s = c * self.srec
+ yield (i + self.offset, s)
+
+ def file(self):
+ f = StringIO()
+ for i, s in self.records():
+ f.write("%s\t%s\n" % (i, s))
+
+ f.seek(0)
+ return f
+
+ def blocks(self):
+ f = self.file()
+ while True:
+ block = f.read(self.block_size)
+ if not block:
+ break
+ yield block
+
+ def assert_data(self):
+ cur = self.conn.cursor()
+ cur.execute("select id, data from copy_in order by id")
+ for record in self.records():
+ assert record == cur.fetchone()
+
+ assert cur.fetchone() is None
+
+ def sha(self, f):
+ m = hashlib.sha256()
+ while True:
+ block = f.read()
+ if not block:
+ break
+ if isinstance(block, str):
+ block = block.encode()
+ m.update(block)
+ return m.hexdigest()
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py
new file mode 100644
index 0000000..59389dd
--- /dev/null
+++ b/tests/test_copy_async.py
@@ -0,0 +1,892 @@
+import string
+import hashlib
+from io import BytesIO, StringIO
+from random import choice, randrange
+from itertools import cycle
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg import errors as e
+from psycopg.pq import Format
+from psycopg.copy import AsyncCopy
+from psycopg.copy import AsyncWriter, AsyncLibpqWriter, AsyncQueuedLibpqWriter
+from psycopg.types import TypeInfo
+from psycopg.adapt import PyFormat
+from psycopg.types.hstore import register_hstore
+from psycopg.types.numeric import Int4
+
+from .utils import alist, eur, gc_collect, gc_count
+from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa
+from .test_copy import sample_values, sample_records, sample_tabledef
+from .test_copy import py_to_raw, special_chars
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.crdb_skip("copy"),
+]
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_out_read(aconn, format):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ for row in want:
+ got = await copy.read()
+ assert got == row
+ assert aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE
+
+ assert await copy.read() == b""
+ assert await copy.read() == b""
+
+ assert await copy.read() == b""
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_copy_out_iter(aconn, format, row_factory):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ rf = getattr(psycopg.rows, row_factory)
+ cur = aconn.cursor(row_factory=rf)
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ assert await alist(copy) == want
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_copy_out_no_result(aconn, format, row_factory):
+ rf = getattr(psycopg.rows, row_factory)
+ cur = aconn.cursor(row_factory=rf)
+ async with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"):
+ with pytest.raises(e.ProgrammingError):
+ await cur.fetchone()
+
+
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+async def test_copy_out_param(aconn, ph, params):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert await alist(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("typetype", ["names", "oids"])
+async def test_read_rows(aconn, format, typetype):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"""copy (
+ select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
+ ) to stdout (format {format.name})"""
+ ) as copy:
+ copy.set_types(["int4", "text", "float8[]"])
+ row = await copy.read_row()
+ assert (await copy.read_row()) is None
+
+ assert row == (10, "hello", [0.0, 1.0])
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_rows(aconn, format):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ copy.set_types("int4 int4 text".split())
+ rows = await alist(copy.rows())
+
+ assert rows == sample_records
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_set_custom_type(aconn, hstore):
+ command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout"""
+ cur = aconn.cursor()
+
+ async with cur.copy(command) as copy:
+ rows = await alist(copy.rows())
+
+ assert rows == [('"a"=>"1", "b"=>"2"',)]
+
+ register_hstore(await TypeInfo.fetch(aconn, "hstore"), cur)
+ async with cur.copy(command) as copy:
+ copy.set_types(["hstore"])
+ rows = await alist(copy.rows())
+
+ assert rows == [({"a": "1", "b": "2"},)]
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_out_allchars(aconn, format):
+ cur = aconn.cursor()
+ chars = list(map(chr, range(1, 256))) + [eur]
+ await aconn.execute("set client_encoding to utf8")
+ rows = []
+ query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+ chars, sql.SQL(format.name)
+ )
+ async with cur.copy(query) as copy:
+ copy.set_types(["text"])
+ while True:
+ row = await copy.read_row()
+ if not row:
+ break
+ assert len(row) == 1
+ rows.append(row[0])
+
+ assert rows == chars
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_read_row_notypes(aconn, format):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ rows = []
+ while True:
+ row = await copy.read_row()
+ if not row:
+ break
+ rows.append(row)
+
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_rows_notypes(aconn, format):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ rows = await alist(copy.rows())
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", Format)
+async def test_copy_out_badntypes(aconn, format, err):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ copy.set_types([0] * (len(sample_records[0]) + err))
+ with pytest.raises(e.ProgrammingError):
+ await copy.read_row()
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_copy_in_buffers(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ await copy.write(globals()[buffer])
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_copy_in_buffers_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_bad_result(aconn):
+ await aconn.set_autocommit(True)
+
+ cur = aconn.cursor()
+
+ with pytest.raises(e.SyntaxError):
+ async with cur.copy("wat"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ async with cur.copy("select 1"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ async with cur.copy("reset timezone"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ async with cur.copy("copy (select 1) to stdout; select 1") as copy:
+ await alist(copy)
+
+ with pytest.raises(e.ProgrammingError):
+ async with cur.copy("select 1; copy (select 1) to stdout"):
+ pass
+
+
+async def test_copy_in_str(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text.decode())
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_copy_in_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ async with cur.copy("copy copy_in from stdin (format binary)") as copy:
+ await copy.write(sample_text.decode())
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_empty(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ pass
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+async def test_copy_big_size_record(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write_row([data])
+
+ await cur.execute("select data from copy_in limit 1")
+ assert await cur.fetchone() == (data,)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+async def test_copy_big_size_block(aconn, pytype):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write(copy_data)
+
+ await cur.execute("select data from copy_in limit 1")
+ assert await cur.fetchone() == (data,)
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_subclass_adapter(aconn, format):
+ if format == Format.TEXT:
+ from psycopg.types.string import StrDumper as BaseDumper
+ else:
+ from psycopg.types.string import ( # type: ignore[no-redef]
+ StrBinaryDumper as BaseDumper,
+ )
+
+ class MyStrDumper(BaseDumper):
+ def dump(self, obj):
+ return super().dump(obj) * 2
+
+ aconn.adapters.register_dumper(str, MyStrDumper)
+
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(
+ f"copy copy_in (data) from stdin (format {format.name})"
+ ) as copy:
+ await copy.write_row(("hello",))
+
+ await cur.execute("select data from copy_in")
+ rec = await cur.fetchone()
+ assert rec[0] == "hellohello"
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_error_empty(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ raise Exception("mannaggiamiseria")
+
+ assert "mannaggiamiseria" in str(exc.value)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_buffers_with_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_buffers_with_py_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_out_error_with_copy_finished(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
+ await copy.read_row()
+ 1 / 0
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_copy_out_error_with_copy_not_finished(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy(
+ "copy (select generate_series(1, 1000000)) to stdout"
+ ) as copy:
+ await copy.read_row()
+ 1 / 0
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_out_server_error(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(e.DivisionByZero):
+ async with cur.copy(
+ "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+ ) as copy:
+ async for block in copy:
+ pass
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_set_types(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_binary(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+ async with cur.copy(
+ f"copy copy_in (col2, data) from stdin (format {format.name})"
+ ) as copy:
+ for row in sample_records:
+ await copy.write_row((None, row[2]))
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+async def test_copy_in_allchars(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ await aconn.execute("set client_encoding to utf8")
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ for i in range(1, 256):
+ await copy.write_row((i, None, chr(i)))
+ await copy.write_row((ord(eur), None, eur))
+
+ await cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ )
+ data = await cur.fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+async def test_copy_in_format(aconn):
+ file = BytesIO()
+ await aconn.execute("set client_encoding to utf8")
+ cur = aconn.cursor()
+ async with AsyncCopy(cur, writer=AsyncFileWriter(file)) as copy:
+ for i in range(1, 256):
+ await copy.write_row((i, chr(i)))
+
+ file.seek(0)
+ rows = file.read().split(b"\n")
+ assert not rows[-1]
+ del rows[-1]
+
+ for i, row in enumerate(rows, start=1):
+ fields = row.split(b"\t")
+ assert len(fields) == 2
+ assert int(fields[0].decode()) == i
+ if i in special_chars:
+ assert fields[1].decode() == f"\\{special_chars[i]}"
+ else:
+ assert fields[1].decode() == chr(i)
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_file_writer(aconn, format, buffer):
+ file = BytesIO()
+ await aconn.execute("set client_encoding to utf8")
+ cur = aconn.cursor()
+ async with AsyncCopy(cur, binary=format, writer=AsyncFileWriter(file)) as copy:
+ for record in sample_records:
+ await copy.write_row(record)
+
+ file.seek(0)
+ want = globals()[buffer]
+ got = file.read()
+ assert got == want
+
+
+@pytest.mark.slow
+async def test_copy_from_to(aconn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
+ await gen.ensure_table()
+ cur = aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+
+ await gen.assert_data()
+
+ f = BytesIO()
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+async def test_copy_from_to_bytes(aconn, pytype):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
+ await gen.ensure_table()
+ cur = aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(pytype(block.encode()))
+
+ await gen.assert_data()
+
+ f = BytesIO()
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+async def test_copy_from_insane_size(aconn):
+ # Trying to trigger a "would block" error
+ gen = DataGenerator(
+ aconn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+ )
+ await gen.ensure_table()
+ cur = aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+
+ await gen.assert_data()
+
+
+async def test_copy_rowcount(aconn):
+ gen = DataGenerator(aconn, nrecs=3, srec=10)
+ await gen.ensure_table()
+
+ cur = aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+ assert cur.rowcount == 3
+
+ gen = DataGenerator(aconn, nrecs=2, srec=10, offset=3)
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for rec in gen.records():
+ await copy.write_row(rec)
+ assert cur.rowcount == 2
+
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ pass
+ assert cur.rowcount == 5
+
+ with pytest.raises(e.BadCopyFileFormat):
+ async with cur.copy("copy copy_in (id) from stdin") as copy:
+ for rec in gen.records():
+ await copy.write_row(rec)
+ assert cur.rowcount == -1
+
+
+async def test_copy_query(aconn):
+ cur = aconn.cursor()
+ async with cur.copy("copy (select 1) to stdout") as copy:
+ assert cur._query.query == b"copy (select 1) to stdout"
+ assert not cur._query.params
+ await alist(copy)
+
+
+async def test_cant_reenter(aconn):
+ cur = aconn.cursor()
+ async with cur.copy("copy (select 1) to stdout") as copy:
+ await alist(copy)
+
+ with pytest.raises(TypeError):
+ async with copy:
+ await alist(copy)
+
+
+async def test_str(aconn):
+ cur = aconn.cursor()
+ async with cur.copy("copy (select 1) to stdout") as copy:
+ assert "[ACTIVE]" in str(copy)
+ await alist(copy)
+
+ assert "[INTRANS]" in str(copy)
+
+
+async def test_description(aconn):
+ async with aconn.cursor() as cur:
+ async with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy:
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+ await alist(copy.rows())
+
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_worker_life(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(
+ f"copy copy_in from stdin (format {format.name})",
+ writer=AsyncQueuedLibpqWriter(cur),
+ ) as copy:
+ assert not copy.writer._worker
+ await copy.write(globals()[buffer])
+ assert copy.writer._worker
+
+ assert not copy.writer._worker
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_worker_error_propagated(aconn, monkeypatch):
+ def copy_to_broken(pgconn, buffer):
+ raise ZeroDivisionError
+ yield
+
+ monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+ cur = aconn.cursor()
+ await cur.execute("create temp table wat (a text, b text)")
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy(
+ "copy wat from stdin", writer=AsyncQueuedLibpqWriter(cur)
+ ) as copy:
+ await copy.write("a,b")
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_connection_writer(aconn, format, buffer):
+ cur = aconn.cursor()
+ writer = AsyncLibpqWriter(cur)
+
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=writer
+ ) as copy:
+ assert copy.writer is writer
+ await copy.write(globals()[buffer])
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"])
+async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn:
+ async with conn.cursor(binary=fmt) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+ async with faker.find_insert_problem_async(conn):
+ await cur.executemany(faker.insert_stmt, faker.records)
+
+ stmt = sql.SQL(
+ "copy (select {} from {} order by id) to stdout (format {})"
+ ).format(
+ sql.SQL(", ").join(faker.fields_names),
+ faker.table_name,
+ sql.SQL(fmt.name),
+ )
+
+ async with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+
+ if method == "read":
+ while True:
+ tmp = await copy.read()
+ if not tmp:
+ break
+ elif method == "iter":
+ await alist(copy)
+ elif method == "row":
+ while True:
+ tmp = await copy.read_row()
+ if tmp is None:
+ break
+ elif method == "rows":
+ await alist(copy.rows())
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn:
+ async with conn.cursor(binary=fmt) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin (format {})").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL(fmt.name),
+ )
+ async with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ await copy.write_row(row)
+
+ await cur.execute(faker.select_stmt)
+ recs = await cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("mode", ["row", "block", "binary"])
+async def test_copy_table_across(aconn_cls, dsn, faker, mode):
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ connect = aconn_cls.connect
+ async with await connect(dsn) as conn1, await connect(dsn) as conn2:
+ faker.table_name = sql.Identifier("copy_src")
+ await conn1.execute(faker.drop_stmt)
+ await conn1.execute(faker.create_stmt)
+ await conn1.cursor().executemany(faker.insert_stmt, faker.records)
+
+ faker.table_name = sql.Identifier("copy_tgt")
+ await conn2.execute(faker.drop_stmt)
+ await conn2.execute(faker.create_stmt)
+
+ fmt = "(format binary)" if mode == "binary" else ""
+ async with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1:
+ async with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2:
+ if mode == "row":
+ async for row in copy1.rows():
+ await copy2.write_row(row)
+ else:
+ async for data in copy1:
+ await copy2.write(data)
+
+ cur = await conn2.execute(faker.select_stmt)
+ recs = await cur.fetchall()
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+
+async def ensure_table(cur, tabledef, name="copy_in"):
+ await cur.execute(f"drop table if exists {name}")
+ await cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+ def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+ self.conn = conn
+ self.nrecs = nrecs
+ self.srec = srec
+ self.offset = offset
+ self.block_size = block_size
+
+ async def ensure_table(self):
+ cur = self.conn.cursor()
+ await ensure_table(cur, "id integer primary key, data text")
+
+ def records(self):
+ for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+ s = c * self.srec
+ yield (i + self.offset, s)
+
+ def file(self):
+ f = StringIO()
+ for i, s in self.records():
+ f.write("%s\t%s\n" % (i, s))
+
+ f.seek(0)
+ return f
+
+ def blocks(self):
+ f = self.file()
+ while True:
+ block = f.read(self.block_size)
+ if not block:
+ break
+ yield block
+
+ async def assert_data(self):
+ cur = self.conn.cursor()
+ await cur.execute("select id, data from copy_in order by id")
+ for record in self.records():
+ assert record == await cur.fetchone()
+
+ assert await cur.fetchone() is None
+
+ def sha(self, f):
+ m = hashlib.sha256()
+ while True:
+ block = f.read()
+ if not block:
+ break
+ if isinstance(block, str):
+ block = block.encode()
+ m.update(block)
+ return m.hexdigest()
+
+
+class AsyncFileWriter(AsyncWriter):
+ def __init__(self, file):
+ self.file = file
+
+ async def write(self, data):
+ self.file.write(data)
diff --git a/tests/test_cursor.py b/tests/test_cursor.py
new file mode 100644
index 0000000..a667f4f
--- /dev/null
+++ b/tests/test_cursor.py
@@ -0,0 +1,942 @@
+import pickle
+import weakref
+import datetime as dt
+from typing import List, Union
+from contextlib import closing
+
+import pytest
+
+import psycopg
+from psycopg import pq, sql, rows
+from psycopg.adapt import PyFormat
+from psycopg.postgres import types as builtins
+from psycopg.rows import RowMaker
+
+from .utils import gc_collect, gc_count
+from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision
+
+
+def test_init(conn):
+ cur = psycopg.Cursor(conn)
+ cur.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ conn.row_factory = rows.dict_row
+ cur = psycopg.Cursor(conn)
+ cur.execute("select 1 as a")
+ assert cur.fetchone() == {"a": 1}
+
+
+def test_init_factory(conn):
+ cur = psycopg.Cursor(conn, row_factory=rows.dict_row)
+ cur.execute("select 1 as a")
+ assert cur.fetchone() == {"a": 1}
+
+
+def test_close(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.execute("select 'foo'")
+
+ cur.close()
+ assert cur.closed
+
+
+def test_cursor_close_fetchone(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ for _ in range(5):
+ cur.fetchone()
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchone()
+
+
+def test_cursor_close_fetchmany(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchmany(2)) == 2
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchmany(2)
+
+
+def test_cursor_close_fetchall(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchall()) == 10
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchall()
+
+
+def test_context(conn):
+ with conn.cursor() as cur:
+ assert not cur.closed
+
+ assert cur.closed
+
+
+@pytest.mark.slow
+def test_weakref(conn):
+ cur = conn.cursor()
+ w = weakref.ref(cur)
+ cur.close()
+ del cur
+ gc_collect()
+ assert w() is None
+
+
+def test_pgresult(conn):
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert cur.pgresult
+ cur.close()
+ assert not cur.pgresult
+
+
+def test_statusmessage(conn):
+ cur = conn.cursor()
+ assert cur.statusmessage is None
+
+ cur.execute("select generate_series(1, 10)")
+ assert cur.statusmessage == "SELECT 10"
+
+ cur.execute("create table statusmessage ()")
+ assert cur.statusmessage == "CREATE TABLE"
+
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute("wat")
+ assert cur.statusmessage is None
+
+
+def test_execute_many_results(conn):
+ cur = conn.cursor()
+ assert cur.nextset() is None
+
+ rv = cur.execute("select 'foo'; select generate_series(1,3)")
+ assert rv is cur
+ assert cur.fetchall() == [("foo",)]
+ assert cur.rowcount == 1
+ assert cur.nextset()
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.nextset() is None
+
+ cur.close()
+ assert cur.nextset() is None
+
+
+def test_execute_sequence(conn):
+ cur = conn.cursor()
+ rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.pgresult.get_value(0, 1) == b"foo"
+ assert cur.pgresult.get_value(0, 2) is None
+ assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+def test_execute_empty_query(conn, query):
+ cur = conn.cursor()
+ cur.execute(query)
+ assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+
+def test_execute_type_change(conn):
+ # issue #112
+ conn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = conn.cursor()
+ cur.execute(sql, (1,))
+ cur.execute(sql, (100_000,))
+ cur.execute("select num from bug_112 order by num")
+ assert cur.fetchall() == [(1,), (100_000,)]
+
+
+def test_executemany_type_change(conn):
+ conn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = conn.cursor()
+ cur.executemany(sql, [(1,), (100_000,)])
+ cur.execute("select num from bug_112 order by num")
+ assert cur.fetchall() == [(1,), (100_000,)]
+
+
+@pytest.mark.parametrize(
+ "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
+)
+def test_execute_copy(conn, query):
+ cur = conn.cursor()
+ cur.execute("create table testcopy (id int)")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute(query)
+
+
+def test_fetchone(conn):
+ cur = conn.cursor()
+ cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert cur.pgresult.fformat(0) == 0
+
+ row = cur.fetchone()
+ assert row == (1, "foo", None)
+ row = cur.fetchone()
+ assert row is None
+
+
+def test_binary_cursor_execute(conn):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s, %s", [1, None])
+ assert cur.fetchone() == (1, None)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
+
+
+def test_execute_binary(conn):
+ cur = conn.cursor()
+ cur.execute("select %s, %s", [1, None], binary=True)
+ assert cur.fetchone() == (1, None)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
+
+
+def test_binary_cursor_text_override(conn):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s, %s", [1, None], binary=False)
+ assert cur.fetchone() == (1, None)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+def test_query_encode(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ (res,) = cur.execute("select '\u20ac'").fetchone()
+ assert res == "\u20ac"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+def test_query_badenc(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ cur.execute("select '\u20ac'")
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop table if exists execmany;
+ create table execmany (id serial primary key, num integer, data text)
+ """
+ )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+ cur = svcconn.cursor()
+ cur.execute("truncate table execmany")
+
+
+def test_executemany(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(10, "hello"), (20, "world")]
+
+
+def test_executemany_name(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(11, "hello"), (21, "world")]
+
+
+def test_executemany_no_data(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
+ assert cur.rowcount == 0
+
+
+def test_executemany_rowcount(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+def test_executemany_returning(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.fetchone() == (10,)
+ assert cur.nextset()
+ assert cur.fetchone() == (20,)
+ assert cur.nextset() is None
+
+
+def test_executemany_returning_discard(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+ assert cur.nextset() is None
+
+
+def test_executemany_no_result(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.statusmessage.startswith("INSERT")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+ pgresult = cur.pgresult
+ assert cur.nextset()
+ assert cur.statusmessage.startswith("INSERT")
+ assert pgresult is not cur.pgresult
+ assert cur.nextset() is None
+
+
+def test_executemany_rowcount_no_hit(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+ cur.executemany("delete from execmany where id = %s", [])
+ assert cur.rowcount == 0
+ cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+def test_executemany_badquery(conn, query):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_executemany_null_first(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table testmany (a bigint, b bigint)")
+ cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, None], [3, 4]],
+ )
+ with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
+ cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, ""], [3, 4]],
+ )
+
+
+def test_rowcount(conn):
+ cur = conn.cursor()
+
+ cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+ cur.execute("show timezone")
+ assert cur.rowcount == 1
+
+ cur.execute("create table test_rowcount_notuples (id int primary key)")
+ assert cur.rowcount == -1
+
+ cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+
+def test_rownumber(conn):
+ cur = conn.cursor()
+ assert cur.rownumber is None
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ cur.fetchone()
+ assert cur.rownumber == 1
+ cur.fetchone()
+ assert cur.rownumber == 2
+ cur.fetchmany(10)
+ assert cur.rownumber == 12
+ rns: List[int] = []
+ for i in cur:
+ assert cur.rownumber
+ rns.append(cur.rownumber)
+ if len(rns) >= 3:
+ break
+ assert rns == [13, 14, 15]
+ assert len(cur.fetchall()) == 42 - rns[-1]
+ assert cur.rownumber == 42
+
+
+@pytest.mark.parametrize("query", ["", "set timezone to utc"])
+def test_rownumber_none(conn, query):
+ cur = conn.cursor()
+ cur.execute(query)
+ assert cur.rownumber is None
+
+
+def test_rownumber_mixed(conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+select x from generate_series(1, 3) x;
+set timezone to utc;
+select x from generate_series(4, 6) x;
+"""
+ )
+ assert cur.rownumber == 0
+ assert cur.fetchone() == (1,)
+ assert cur.rownumber == 1
+ assert cur.fetchone() == (2,)
+ assert cur.rownumber == 2
+ cur.nextset()
+ assert cur.rownumber is None
+ cur.nextset()
+ assert cur.rownumber == 0
+ assert cur.fetchone() == (4,)
+ assert cur.rownumber == 1
+
+
+def test_iter(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ assert list(cur) == [(1,), (2,), (3,)]
+
+
+def test_iter_stop(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ for rec in cur:
+ assert rec == (1,)
+ break
+
+ for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert cur.fetchone() == (3,)
+ assert list(cur) == []
+
+
+def test_row_factory(conn):
+ cur = conn.cursor(row_factory=my_row_factory)
+
+ cur.execute("reset search_path")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+ cur.execute("select 'foo' as bar")
+ (r,) = cur.fetchone()
+ assert r == "FOObar"
+
+ cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert cur.fetchall() == [["Yy", "Zz"]]
+
+ cur.scroll(-1)
+ cur.row_factory = rows.dict_row
+ assert cur.fetchone() == {"y": "y", "z": "z"}
+
+
+def test_row_factory_none(conn):
+ cur = conn.cursor(row_factory=None)
+ assert cur.row_factory is rows.tuple_row
+ r = cur.execute("select 1 as a, 2 as b").fetchone()
+ assert type(r) is tuple
+ assert r == (1, 2)
+
+
+def test_bad_row_factory(conn):
+ def broken_factory(cur):
+ 1 / 0
+
+ cur = conn.cursor(row_factory=broken_factory)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select 1")
+
+ def broken_maker(cur):
+ def make_row(seq):
+ 1 / 0
+
+ return make_row
+
+ cur = conn.cursor(row_factory=broken_maker)
+ cur.execute("select 1")
+ with pytest.raises(ZeroDivisionError):
+ cur.fetchone()
+
+
+def test_scroll(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.scroll(0)
+
+ cur.execute("select generate_series(0,9)")
+ cur.scroll(2)
+ assert cur.fetchone() == (2,)
+ cur.scroll(2)
+ assert cur.fetchone() == (5,)
+ cur.scroll(2, mode="relative")
+ assert cur.fetchone() == (8,)
+ cur.scroll(-1)
+ assert cur.fetchone() == (8,)
+ cur.scroll(-2)
+ assert cur.fetchone() == (7,)
+ cur.scroll(2, mode="absolute")
+ assert cur.fetchone() == (2,)
+
+ # on the boundary
+ cur.scroll(0, mode="absolute")
+ assert cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ cur.scroll(-1, mode="absolute")
+
+ cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ cur.scroll(-1)
+
+ cur.scroll(9, mode="absolute")
+ assert cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ cur.scroll(10, mode="absolute")
+
+ cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ cur.scroll(1, "wat")
+
+
+def test_query_params_execute(conn):
+ cur = conn.cursor()
+ assert cur._query is None
+
+ cur.execute("select %t, %s::text", [1, None])
+ assert cur._query is not None
+ assert cur._query.query == b"select $1, $2::text"
+ assert cur._query.params == [b"1", None]
+
+ cur.execute("select 1")
+ assert cur._query.query == b"select 1"
+ assert not cur._query.params
+
+ with pytest.raises(psycopg.DataError):
+ cur.execute("select %t::int", ["wat"])
+
+ assert cur._query.query == b"select $1::int"
+ assert cur._query.params == [b"wat"]
+
+
+def test_query_params_executemany(conn):
+ cur = conn.cursor()
+
+ cur.executemany("select %t, %t", [[1, 2], [3, 4]])
+ assert cur._query.query == b"select $1, $2"
+ assert cur._query.params == [b"3", b"4"]
+
+
+def test_stream(conn):
+ cur = conn.cursor()
+ recs = []
+ for rec in cur.stream(
+ "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+ [2],
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+def test_stream_sql(conn):
+ cur = conn.cursor()
+ recs = list(
+ cur.stream(
+ sql.SQL(
+ "select i, '2021-01-01'::date + i from generate_series(1, {}) as i"
+ ).format(2)
+ )
+ )
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+def test_stream_row_factory(conn):
+ cur = conn.cursor(row_factory=rows.dict_row)
+ it = iter(cur.stream("select generate_series(1,2) as a"))
+ assert next(it)["a"] == 1
+ cur.row_factory = rows.namedtuple_row
+ assert next(it).a == 2
+
+
+def test_stream_no_row(conn):
+ cur = conn.cursor()
+ recs = list(cur.stream("select generate_series(2,1) as a"))
+ assert recs == []
+
+
+@pytest.mark.crdb_skip("no col query")
+def test_stream_no_col(conn):
+ cur = conn.cursor()
+ recs = list(cur.stream("select"))
+ assert recs == [()]
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_stream_badq ()",
+ "copy (select 1) to stdout",
+ "wat?",
+ ],
+)
+def test_stream_badquery(conn, query):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ for rec in cur.stream(query):
+ pass
+
+
+def test_stream_error_tx(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ for rec in cur.stream("wat"):
+ pass
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_stream_error_notx(conn):
+ conn.autocommit = True
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ for rec in cur.stream("wat"):
+ pass
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+
+def test_stream_error_python_to_consume(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with closing(cur.stream("select generate_series(1, 10000)")) as gen:
+ for rec in gen:
+ 1 / 0
+ assert conn.info.transaction_status in (
+ conn.TransactionStatus.INTRANS,
+ conn.TransactionStatus.INERROR,
+ )
+
+
+def test_stream_error_python_consumed(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ gen = cur.stream("select 1")
+ for rec in gen:
+ 1 / 0
+ gen.close()
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_stream_close(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.OperationalError):
+ for rec in cur.stream("select generate_series(1, 3)"):
+ if rec[0] == 1:
+ conn.close()
+ else:
+ assert False
+
+ assert conn.closed
+
+
+def test_stream_binary_cursor(conn):
+ cur = conn.cursor(binary=True)
+ recs = []
+ for rec in cur.stream("select x::int4 from generate_series(1, 2) x"):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]])
+
+ assert recs == [(1,), (2,)]
+
+
+def test_stream_execute_binary(conn):
+ cur = conn.cursor()
+ recs = []
+ for rec in cur.stream("select x::int4 from generate_series(1, 2) x", binary=True):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]])
+
+ assert recs == [(1,), (2,)]
+
+
+def test_stream_binary_cursor_text_override(conn):
+ cur = conn.cursor(binary=True)
+ recs = []
+ for rec in cur.stream("select generate_series(1, 2)", binary=False):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == str(rec[0]).encode()
+
+ assert recs == [(1,), (2,)]
+
+
+class TestColumn:
+ def test_description_attribs(self, conn):
+ curs = conn.cursor()
+ curs.execute(
+ """select
+ 3.14::decimal(10,2) as pi,
+ 'hello'::text as hi,
+ '2010-02-18'::date as now
+ """
+ )
+ assert len(curs.description) == 3
+ for c in curs.description:
+ len(c) == 7 # DBAPI happy
+ for i, a in enumerate(
+ """
+ name type_code display_size internal_size precision scale null_ok
+ """.split()
+ ):
+ assert c[i] == getattr(c, a)
+
+ # Won't fill them up
+ assert c.null_ok is None
+
+ c = curs.description[0]
+ assert c.name == "pi"
+ assert c.type_code == builtins["numeric"].oid
+ assert c.display_size is None
+ assert c.internal_size is None
+ assert c.precision == 10
+ assert c.scale == 2
+
+ c = curs.description[1]
+ assert c.name == "hi"
+ assert c.type_code == builtins["text"].oid
+ assert c.display_size is None
+ assert c.internal_size is None
+ assert c.precision is None
+ assert c.scale is None
+
+ c = curs.description[2]
+ assert c.name == "now"
+ assert c.type_code == builtins["date"].oid
+ assert c.display_size is None
+ if is_crdb(conn):
+ assert c.internal_size == 16
+ else:
+ assert c.internal_size == 4
+ assert c.precision is None
+ assert c.scale is None
+
+ def test_description_slice(self, conn):
+ curs = conn.cursor()
+ curs.execute("select 1::int as a")
+ curs.description[0][0:2] == ("a", 23)
+
+ @pytest.mark.parametrize(
+ "type, precision, scale, dsize, isize",
+ [
+ ("text", None, None, None, None),
+ ("varchar", None, None, None, None),
+ ("varchar(42)", None, None, 42, None),
+ ("int4", None, None, None, 4),
+ ("numeric", None, None, None, None),
+ ("numeric(10)", 10, 0, None, None),
+ ("numeric(10, 3)", 10, 3, None, None),
+ ("time", None, None, None, 8),
+ crdb_time_precision("time(4)", 4, None, None, 8),
+ crdb_time_precision("time(10)", 6, None, None, 8),
+ ],
+ )
+ def test_details(self, conn, type, precision, scale, dsize, isize):
+ cur = conn.cursor()
+ cur.execute(f"select null::{type}")
+ col = cur.description[0]
+ repr(col)
+ assert col.precision == precision
+ assert col.scale == scale
+ assert col.display_size == dsize
+ assert col.internal_size == isize
+
+ def test_pickle(self, conn):
+ curs = conn.cursor()
+ curs.execute(
+ """select
+ 3.14::decimal(10,2) as pi,
+ 'hello'::text as hi,
+ '2010-02-18'::date as now
+ """
+ )
+ description = curs.description
+ pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
+ unpickled = pickle.loads(pickled)
+ assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]
+
+ @pytest.mark.crdb_skip("no col query")
+ def test_no_col_query(self, conn):
+ cur = conn.execute("select")
+ assert cur.description == []
+ assert cur.fetchall() == [()]
+
+ def test_description_closed_connection(self, conn):
+ # If we have reasons to break this test we will (e.g. we really need
+ # the connection). In #172 it fails just by accident.
+ cur = conn.execute("select 1::int4 as foo")
+ conn.close()
+ assert len(cur.description) == 1
+ col = cur.description[0]
+ assert col.name == "foo"
+ assert col.type_code == 23
+
+ def test_name_not_a_name(self, conn):
+ cur = conn.cursor()
+ (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone()
+ assert res == "x"
+ assert cur.description[0].name == "foo-bar"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_name_encode(self, conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone()
+ assert res == "x"
+ assert cur.description[0].name == "\u20ac"
+
+
+def test_str(conn):
+ cur = conn.cursor()
+ assert "psycopg.Cursor" in str(cur)
+ assert "[IDLE]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" in str(cur)
+ cur.execute("select 1")
+ assert "[INTRANS]" in str(cur)
+ assert "[TUPLES_OK]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" not in str(cur)
+ cur.close()
+ assert "[closed]" in str(cur)
+ assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
+ faker.format = fmt
+ faker.choose_schema(ncols=5)
+ faker.make_records(10)
+ row_factory = getattr(rows, row_factory)
+
+ def work():
+ with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True):
+ with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ cur.execute(faker.select_stmt)
+
+ if fetch == "one":
+ while True:
+ tmp = cur.fetchone()
+ if tmp is None:
+ break
+ elif fetch == "many":
+ while True:
+ tmp = cur.fetchmany(3)
+ if not tmp:
+ break
+ elif fetch == "all":
+ cur.fetchall()
+ elif fetch == "iter":
+ for rec in cur:
+ pass
+
+ n = []
+ gc_collect()
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+def my_row_factory(
+ cursor: Union[psycopg.Cursor[List[str]], psycopg.AsyncCursor[List[str]]]
+) -> RowMaker[List[str]]:
+ if cursor.description is not None:
+ titles = [c.name for c in cursor.description]
+
+ def mkrow(values):
+ return [f"{value.upper()}{title}" for title, value in zip(titles, values)]
+
+ return mkrow
+ else:
+ return rows.no_result
diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py
new file mode 100644
index 0000000..ac3fdeb
--- /dev/null
+++ b/tests/test_cursor_async.py
@@ -0,0 +1,802 @@
+import pytest
+import weakref
+import datetime as dt
+from typing import List
+
+import psycopg
+from psycopg import pq, sql, rows
+from psycopg.adapt import PyFormat
+
+from .utils import gc_collect, gc_count
+from .test_cursor import my_row_factory
+from .test_cursor import execmany, _execmany # noqa: F401
+from .fix_crdb import crdb_encoding
+
+execmany = execmany # avoid F811 underneath
+pytestmark = pytest.mark.asyncio
+
+
+async def test_init(aconn):
+ cur = psycopg.AsyncCursor(aconn)
+ await cur.execute("select 1")
+ assert (await cur.fetchone()) == (1,)
+
+ aconn.row_factory = rows.dict_row
+ cur = psycopg.AsyncCursor(aconn)
+ await cur.execute("select 1 as a")
+ assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_init_factory(aconn):
+ cur = psycopg.AsyncCursor(aconn, row_factory=rows.dict_row)
+ await cur.execute("select 1 as a")
+ assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_close(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.execute("select 'foo'")
+
+ await cur.close()
+ assert cur.closed
+
+
+async def test_cursor_close_fetchone(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ for _ in range(5):
+ await cur.fetchone()
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchone()
+
+
+async def test_cursor_close_fetchmany(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchmany(2)) == 2
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchmany(2)
+
+
+async def test_cursor_close_fetchall(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchall()) == 10
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchall()
+
+
+async def test_context(aconn):
+ async with aconn.cursor() as cur:
+ assert not cur.closed
+
+ assert cur.closed
+
+
+@pytest.mark.slow
+async def test_weakref(aconn):
+ cur = aconn.cursor()
+ w = weakref.ref(cur)
+ await cur.close()
+ del cur
+ gc_collect()
+ assert w() is None
+
+
+async def test_pgresult(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert cur.pgresult
+ await cur.close()
+ assert not cur.pgresult
+
+
+async def test_statusmessage(aconn):
+ cur = aconn.cursor()
+ assert cur.statusmessage is None
+
+ await cur.execute("select generate_series(1, 10)")
+ assert cur.statusmessage == "SELECT 10"
+
+ await cur.execute("create table statusmessage ()")
+ assert cur.statusmessage == "CREATE TABLE"
+
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.execute("wat")
+ assert cur.statusmessage is None
+
+
+async def test_execute_many_results(aconn):
+ cur = aconn.cursor()
+ assert cur.nextset() is None
+
+ rv = await cur.execute("select 'foo'; select generate_series(1,3)")
+ assert rv is cur
+ assert (await cur.fetchall()) == [("foo",)]
+ assert cur.rowcount == 1
+ assert cur.nextset()
+ assert (await cur.fetchall()) == [(1,), (2,), (3,)]
+ assert cur.rowcount == 3
+ assert cur.nextset() is None
+
+ await cur.close()
+ assert cur.nextset() is None
+
+
+async def test_execute_sequence(aconn):
+ cur = aconn.cursor()
+ rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.pgresult.get_value(0, 1) == b"foo"
+ assert cur.pgresult.get_value(0, 2) is None
+ assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+async def test_execute_empty_query(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute(query)
+ assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+
+
+async def test_execute_type_change(aconn):
+ # issue #112
+ await aconn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = aconn.cursor()
+ await cur.execute(sql, (1,))
+ await cur.execute(sql, (100_000,))
+ await cur.execute("select num from bug_112 order by num")
+ assert (await cur.fetchall()) == [(1,), (100_000,)]
+
+
+async def test_executemany_type_change(aconn):
+ await aconn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = aconn.cursor()
+ await cur.executemany(sql, [(1,), (100_000,)])
+ await cur.execute("select num from bug_112 order by num")
+ assert (await cur.fetchall()) == [(1,), (100_000,)]
+
+
+@pytest.mark.parametrize(
+ "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
+)
+async def test_execute_copy(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute("create table testcopy (id int)")
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.execute(query)
+
+
+async def test_fetchone(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert cur.pgresult.fformat(0) == 0
+
+ row = await cur.fetchone()
+ assert row == (1, "foo", None)
+ row = await cur.fetchone()
+ assert row is None
+
+
+async def test_binary_cursor_execute(aconn):
+ cur = aconn.cursor(binary=True)
+ await cur.execute("select %s, %s", [1, None])
+ assert (await cur.fetchone()) == (1, None)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
+
+
+async def test_execute_binary(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s, %s", [1, None], binary=True)
+ assert (await cur.fetchone()) == (1, None)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
+
+
+async def test_binary_cursor_text_override(aconn):
+ cur = aconn.cursor(binary=True)
+ await cur.execute("select %s, %s", [1, None], binary=False)
+ assert (await cur.fetchone()) == (1, None)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+async def test_query_encode(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ cur = aconn.cursor()
+ await cur.execute("select '\u20ac'")
+ (res,) = await cur.fetchone()
+ assert res == "\u20ac"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+async def test_query_badenc(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ cur = aconn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ await cur.execute("select '\u20ac'")
+
+
+async def test_executemany(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ await cur.execute("select num, data from execmany order by 1")
+ rv = await cur.fetchall()
+ assert rv == [(10, "hello"), (20, "world")]
+
+
+async def test_executemany_name(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ await cur.execute("select num, data from execmany order by 1")
+ rv = await cur.fetchall()
+ assert rv == [(11, "hello"), (21, "world")]
+
+
+async def test_executemany_no_data(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
+ assert cur.rowcount == 0
+
+
+async def test_executemany_rowcount(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+async def test_executemany_returning(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert (await cur.fetchone()) == (10,)
+ assert cur.nextset()
+ assert (await cur.fetchone()) == (20,)
+ assert cur.nextset() is None
+
+
+async def test_executemany_returning_discard(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+ assert cur.nextset() is None
+
+
+async def test_executemany_no_result(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.statusmessage.startswith("INSERT")
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+ pgresult = cur.pgresult
+ assert cur.nextset()
+ assert cur.statusmessage.startswith("INSERT")
+ assert pgresult is not cur.pgresult
+ assert cur.nextset() is None
+
+
+async def test_executemany_rowcount_no_hit(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+ await cur.executemany("delete from execmany where id = %s", [])
+ assert cur.rowcount == 0
+ await cur.executemany(
+ "delete from execmany where id = %s returning num", [(-1,), (-2,)]
+ )
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+async def test_executemany_badquery(aconn, query):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+async def test_executemany_null_first(aconn, fmt_in):
+ cur = aconn.cursor()
+ await cur.execute("create table testmany (a bigint, b bigint)")
+ await cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, None], [3, 4]],
+ )
+ with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
+ await cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, ""], [3, 4]],
+ )
+
+
+async def test_rowcount(aconn):
+ cur = aconn.cursor()
+
+ await cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+ await cur.execute("show timezone")
+ assert cur.rowcount == 1
+
+ await cur.execute("create table test_rowcount_notuples (id int primary key)")
+ assert cur.rowcount == -1
+
+ await cur.execute(
+ "insert into test_rowcount_notuples select generate_series(1, 42)"
+ )
+ assert cur.rowcount == 42
+
+
+async def test_rownumber(aconn):
+ cur = aconn.cursor()
+ assert cur.rownumber is None
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ await cur.fetchone()
+ assert cur.rownumber == 1
+ await cur.fetchone()
+ assert cur.rownumber == 2
+ await cur.fetchmany(10)
+ assert cur.rownumber == 12
+ rns: List[int] = []
+ async for i in cur:
+ assert cur.rownumber
+ rns.append(cur.rownumber)
+ if len(rns) >= 3:
+ break
+ assert rns == [13, 14, 15]
+ assert len(await cur.fetchall()) == 42 - rns[-1]
+ assert cur.rownumber == 42
+
+
+@pytest.mark.parametrize("query", ["", "set timezone to utc"])
+async def test_rownumber_none(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute(query)
+ assert cur.rownumber is None
+
+
+async def test_rownumber_mixed(aconn):
+ cur = aconn.cursor()
+ await cur.execute(
+ """
+select x from generate_series(1, 3) x;
+set timezone to utc;
+select x from generate_series(4, 6) x;
+"""
+ )
+ assert cur.rownumber == 0
+ assert await cur.fetchone() == (1,)
+ assert cur.rownumber == 1
+ assert await cur.fetchone() == (2,)
+ assert cur.rownumber == 2
+ cur.nextset()
+ assert cur.rownumber is None
+ cur.nextset()
+ assert cur.rownumber == 0
+ assert await cur.fetchone() == (4,)
+ assert cur.rownumber == 1
+
+
+async def test_iter(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ res = []
+ async for rec in cur:
+ res.append(rec)
+ assert res == [(1,), (2,), (3,)]
+
+
+async def test_iter_stop(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ async for rec in cur:
+ assert rec == (1,)
+ break
+
+ async for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert (await cur.fetchone()) == (3,)
+ async for rec in cur:
+ assert False
+
+
+async def test_row_factory(aconn):
+ cur = aconn.cursor(row_factory=my_row_factory)
+ await cur.execute("select 'foo' as bar")
+ (r,) = await cur.fetchone()
+ assert r == "FOObar"
+
+ await cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert await cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert await cur.fetchall() == [["Yy", "Zz"]]
+
+ await cur.scroll(-1)
+ cur.row_factory = rows.dict_row
+ assert await cur.fetchone() == {"y": "y", "z": "z"}
+
+
+async def test_row_factory_none(aconn):
+ cur = aconn.cursor(row_factory=None)
+ assert cur.row_factory is rows.tuple_row
+ await cur.execute("select 1 as a, 2 as b")
+ r = await cur.fetchone()
+ assert type(r) is tuple
+ assert r == (1, 2)
+
+
+async def test_bad_row_factory(aconn):
+ def broken_factory(cur):
+ 1 / 0
+
+ cur = aconn.cursor(row_factory=broken_factory)
+ with pytest.raises(ZeroDivisionError):
+ await cur.execute("select 1")
+
+ def broken_maker(cur):
+ def make_row(seq):
+ 1 / 0
+
+ return make_row
+
+ cur = aconn.cursor(row_factory=broken_maker)
+ await cur.execute("select 1")
+ with pytest.raises(ZeroDivisionError):
+ await cur.fetchone()
+
+
+async def test_scroll(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.scroll(0)
+
+ await cur.execute("select generate_series(0,9)")
+ await cur.scroll(2)
+ assert await cur.fetchone() == (2,)
+ await cur.scroll(2)
+ assert await cur.fetchone() == (5,)
+ await cur.scroll(2, mode="relative")
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-1)
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-2)
+ assert await cur.fetchone() == (7,)
+ await cur.scroll(2, mode="absolute")
+ assert await cur.fetchone() == (2,)
+
+ # on the boundary
+ await cur.scroll(0, mode="absolute")
+ assert await cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ await cur.scroll(-1, mode="absolute")
+
+ await cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(-1)
+
+ await cur.scroll(9, mode="absolute")
+ assert await cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ await cur.scroll(10, mode="absolute")
+
+ await cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ await cur.scroll(1, "wat")
+
+
+async def test_query_params_execute(aconn):
+ cur = aconn.cursor()
+ assert cur._query is None
+
+ await cur.execute("select %t, %s::text", [1, None])
+ assert cur._query is not None
+ assert cur._query.query == b"select $1, $2::text"
+ assert cur._query.params == [b"1", None]
+
+ await cur.execute("select 1")
+ assert cur._query.query == b"select 1"
+ assert not cur._query.params
+
+ with pytest.raises(psycopg.DataError):
+ await cur.execute("select %t::int", ["wat"])
+
+ assert cur._query.query == b"select $1::int"
+ assert cur._query.params == [b"wat"]
+
+
+async def test_query_params_executemany(aconn):
+ cur = aconn.cursor()
+
+ await cur.executemany("select %t, %t", [[1, 2], [3, 4]])
+ assert cur._query.query == b"select $1, $2"
+ assert cur._query.params == [b"3", b"4"]
+
+
+async def test_stream(aconn):
+ cur = aconn.cursor()
+ recs = []
+ async for rec in cur.stream(
+ "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+ [2],
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+async def test_stream_sql(aconn):
+ cur = aconn.cursor()
+ recs = []
+ async for rec in cur.stream(
+ sql.SQL(
+ "select i, '2021-01-01'::date + i from generate_series(1, {}) as i"
+ ).format(2)
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+async def test_stream_row_factory(aconn):
+ cur = aconn.cursor(row_factory=rows.dict_row)
+ ait = cur.stream("select generate_series(1,2) as a")
+ assert (await ait.__anext__())["a"] == 1
+ cur.row_factory = rows.namedtuple_row
+ assert (await ait.__anext__()).a == 2
+
+
+async def test_stream_no_row(aconn):
+ cur = aconn.cursor()
+ recs = [rec async for rec in cur.stream("select generate_series(2,1) as a")]
+ assert recs == []
+
+
+@pytest.mark.crdb_skip("no col query")
+async def test_stream_no_col(aconn):
+ cur = aconn.cursor()
+ recs = [rec async for rec in cur.stream("select")]
+ assert recs == [()]
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_stream_badq ()",
+ "copy (select 1) to stdout",
+ "wat?",
+ ],
+)
+async def test_stream_badquery(aconn, query):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ async for rec in cur.stream(query):
+ pass
+
+
+async def test_stream_error_tx(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ async for rec in cur.stream("wat"):
+ pass
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_stream_error_notx(aconn):
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ async for rec in cur.stream("wat"):
+ pass
+ assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE
+
+
+async def test_stream_error_python_to_consume(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ gen = cur.stream("select generate_series(1, 10000)")
+ async for rec in gen:
+ 1 / 0
+
+ await gen.aclose()
+ assert aconn.info.transaction_status in (
+ aconn.TransactionStatus.INTRANS,
+ aconn.TransactionStatus.INERROR,
+ )
+
+
+async def test_stream_error_python_consumed(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ gen = cur.stream("select 1")
+ async for rec in gen:
+ 1 / 0
+
+ await gen.aclose()
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_stream_close(aconn):
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.OperationalError):
+ async for rec in cur.stream("select generate_series(1, 3)"):
+ if rec[0] == 1:
+ await aconn.close()
+ else:
+ assert False
+
+ assert aconn.closed
+
+
+async def test_stream_binary_cursor(aconn):
+ cur = aconn.cursor(binary=True)
+ recs = []
+ async for rec in cur.stream("select x::int4 from generate_series(1, 2) x"):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]])
+
+ assert recs == [(1,), (2,)]
+
+
+async def test_stream_execute_binary(aconn):
+ cur = aconn.cursor()
+ recs = []
+ async for rec in cur.stream(
+ "select x::int4 from generate_series(1, 2) x", binary=True
+ ):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]])
+
+ assert recs == [(1,), (2,)]
+
+
+async def test_stream_binary_cursor_text_override(aconn):
+ cur = aconn.cursor(binary=True)
+ recs = []
+ async for rec in cur.stream("select generate_series(1, 2)", binary=False):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == str(rec[0]).encode()
+
+ assert recs == [(1,), (2,)]
+
+
+async def test_str(aconn):
+ cur = aconn.cursor()
+ assert "psycopg.AsyncCursor" in str(cur)
+ assert "[IDLE]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" in str(cur)
+ await cur.execute("select 1")
+ assert "[INTRANS]" in str(cur)
+ assert "[TUPLES_OK]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" not in str(cur)
+ await cur.close()
+ assert "[closed]" in str(cur)
+ assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
+ faker.format = fmt
+ faker.choose_schema(ncols=5)
+ faker.make_records(10)
+ row_factory = getattr(rows, row_factory)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn, conn.transaction(
+ force_rollback=True
+ ):
+ async with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+ async with faker.find_insert_problem_async(conn):
+ await cur.executemany(faker.insert_stmt, faker.records)
+ await cur.execute(faker.select_stmt)
+
+ if fetch == "one":
+ while True:
+ tmp = await cur.fetchone()
+ if tmp is None:
+ break
+ elif fetch == "many":
+ while True:
+ tmp = await cur.fetchmany(3)
+ if not tmp:
+ break
+ elif fetch == "all":
+ await cur.fetchall()
+ elif fetch == "iter":
+ async for rec in cur:
+ pass
+
+ n = []
+ gc_collect()
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
diff --git a/tests/test_dns.py b/tests/test_dns.py
new file mode 100644
index 0000000..f50092f
--- /dev/null
+++ b/tests/test_dns.py
@@ -0,0 +1,27 @@
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+pytestmark = [pytest.mark.dns]
+
+
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_warning(recwarn):
+ import_dnspython()
+ conninfo = "dbname=foo"
+ params = conninfo_to_dict(conninfo)
+ params = await psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined]
+ params
+ )
+ assert conninfo_to_dict(conninfo) == params
+ assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message)
+
+
+def import_dnspython():
+ try:
+ import dns.rdtypes.IN.A # noqa: F401
+ except ImportError:
+ pytest.skip("dnspython package not available")
+
+ import psycopg._dns # noqa: F401
diff --git a/tests/test_dns_srv.py b/tests/test_dns_srv.py
new file mode 100644
index 0000000..15b3706
--- /dev/null
+++ b/tests/test_dns_srv.py
@@ -0,0 +1,149 @@
+from typing import List, Union
+
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from .test_dns import import_dnspython
+
+pytestmark = [pytest.mark.dns]
+
+samples_ok = [
+ ("", "", None),
+ ("host=_pg._tcp.foo.com", "host=db1.example.com port=5432", None),
+ ("", "host=db1.example.com port=5432", {"PGHOST": "_pg._tcp.foo.com"}),
+ (
+ "host=foo.com,_pg._tcp.foo.com",
+ "host=foo.com,db1.example.com port=,5432",
+ None,
+ ),
+ (
+ "host=_pg._tcp.dot.com,foo.com,_pg._tcp.foo.com",
+ "host=foo.com,db1.example.com port=,5432",
+ None,
+ ),
+ (
+ "host=_pg._tcp.bar.com",
+ "host=db1.example.com,db4.example.com,db3.example.com,db2.example.com"
+ " port=5432,5432,5433,5432",
+ None,
+ ),
+ (
+ "host=service.foo.com port=srv",
+ "host=service.example.com port=15432",
+ None,
+ ),
+ # No resolution
+ (
+ "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
+ "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+]
+
+
+@pytest.mark.flakey("random weight order, might cause wrong order")
+@pytest.mark.parametrize("conninfo, want, env", samples_ok)
+def test_srv(conninfo, want, env, fake_srv, setpgenv):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ params = psycopg._dns.resolve_srv(params) # type: ignore[attr-defined]
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("conninfo, want, env", samples_ok)
+async def test_srv_async(conninfo, want, env, afake_srv, setpgenv):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ params = await (
+ psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined]
+ )
+ assert conninfo_to_dict(want) == params
+
+
+samples_bad = [
+ ("host=_pg._tcp.dot.com", None),
+ ("host=_pg._tcp.foo.com port=1,2", None),
+]
+
+
+@pytest.mark.parametrize("conninfo, env", samples_bad)
+def test_srv_bad(conninfo, env, fake_srv, setpgenv):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.OperationalError):
+ psycopg._dns.resolve_srv(params) # type: ignore[attr-defined]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("conninfo, env", samples_bad)
+async def test_srv_bad_async(conninfo, env, afake_srv, setpgenv):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.OperationalError):
+ await psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined]
+
+
+@pytest.fixture
+def fake_srv(monkeypatch):
+ f = get_fake_srv_function(monkeypatch)
+ monkeypatch.setattr(
+ psycopg._dns.resolver, # type: ignore[attr-defined]
+ "resolve",
+ f,
+ )
+
+
+@pytest.fixture
+def afake_srv(monkeypatch):
+ f = get_fake_srv_function(monkeypatch)
+
+ async def af(qname, rdtype):
+ return f(qname, rdtype)
+
+ monkeypatch.setattr(
+ psycopg._dns.async_resolver, # type: ignore[attr-defined]
+ "resolve",
+ af,
+ )
+
+
+def get_fake_srv_function(monkeypatch):
+ import_dnspython()
+
+ from dns.rdtypes.IN.A import A
+ from dns.rdtypes.IN.SRV import SRV
+ from dns.exception import DNSException
+
+ fake_hosts = {
+ ("_pg._tcp.dot.com", "SRV"): ["0 0 5432 ."],
+ ("_pg._tcp.foo.com", "SRV"): ["0 0 5432 db1.example.com."],
+ ("_pg._tcp.bar.com", "SRV"): [
+ "1 0 5432 db2.example.com.",
+ "1 255 5433 db3.example.com.",
+ "0 0 5432 db1.example.com.",
+ "1 65535 5432 db4.example.com.",
+ ],
+ ("service.foo.com", "SRV"): ["0 0 15432 service.example.com."],
+ }
+
+ def fake_srv_(qname, rdtype):
+ try:
+ ans = fake_hosts[qname, rdtype]
+ except KeyError:
+ raise DNSException(f"unknown test host: {qname} {rdtype}")
+ rv: List[Union[A, SRV]] = []
+
+ if rdtype == "A":
+ for entry in ans:
+ rv.append(A("IN", "A", entry))
+ else:
+ for entry in ans:
+ pri, w, port, target = entry.split()
+ rv.append(SRV("IN", "SRV", int(pri), int(w), int(port), target))
+
+ return rv
+
+ return fake_srv_
diff --git a/tests/test_encodings.py b/tests/test_encodings.py
new file mode 100644
index 0000000..113f0e3
--- /dev/null
+++ b/tests/test_encodings.py
@@ -0,0 +1,57 @@
+import codecs
+import pytest
+
+import psycopg
+from psycopg import _encodings as encodings
+
+
+def test_names_normalised():
+ for name in encodings._py_codecs.values():
+ assert codecs.lookup(name).name == name
+
+
+@pytest.mark.parametrize(
+ "pyenc, pgenc",
+ [
+ ("ascii", "SQL_ASCII"),
+ ("utf8", "UTF8"),
+ ("utf-8", "UTF8"),
+ ("uTf-8", "UTF8"),
+ ("latin9", "LATIN9"),
+ ("iso8859-15", "LATIN9"),
+ ],
+)
+def test_py2pg(pyenc, pgenc):
+ assert encodings.py2pgenc(pyenc) == pgenc.encode()
+
+
+@pytest.mark.parametrize(
+ "pyenc, pgenc",
+ [
+ ("ascii", "SQL_ASCII"),
+ ("utf-8", "UTF8"),
+ ("iso8859-15", "LATIN9"),
+ ],
+)
+def test_pg2py(pyenc, pgenc):
+ assert encodings.pg2pyenc(pgenc.encode()) == pyenc
+
+
+@pytest.mark.parametrize("pgenc", ["MULE_INTERNAL", "EUC_TW"])
+def test_pg2py_missing(pgenc):
+ with pytest.raises(psycopg.NotSupportedError):
+ encodings.pg2pyenc(pgenc.encode())
+
+
+@pytest.mark.parametrize(
+ "conninfo, pyenc",
+ [
+ ("", "utf-8"),
+ ("user=foo, dbname=bar", "utf-8"),
+ ("user=foo, dbname=bar, client_encoding=EUC_JP", "euc_jp"),
+ ("user=foo, dbname=bar, client_encoding=euc-jp", "euc_jp"),
+ ("user=foo, dbname=bar, client_encoding=WAT", "utf-8"),
+ ],
+)
+def test_conninfo_encoding(conninfo, pyenc):
+ assert encodings.conninfo_encoding(conninfo) == pyenc
diff --git a/tests/test_errors.py b/tests/test_errors.py
new file mode 100644
index 0000000..23ad314
--- /dev/null
+++ b/tests/test_errors.py
@@ -0,0 +1,309 @@
+import pickle
+from typing import List
+from weakref import ref
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import errors as e
+
+from .utils import eur, gc_collect
+from .fix_crdb import is_crdb
+
+
+@pytest.mark.crdb_skip("severity_nonlocalized")
+def test_error_diag(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select 1 from wat")
+
+ exc = excinfo.value
+ diag = exc.diag
+ assert diag.sqlstate == "42P01"
+ assert diag.severity_nonlocalized == "ERROR"
+
+
+def test_diag_all_attrs(pgconn):
+ res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR)
+ diag = e.Diagnostic(res)
+ for d in pq.DiagnosticField:
+ val = getattr(diag, d.name.lower())
+ assert val is None or isinstance(val, str)
+
+
+def test_diag_right_attr(pgconn, monkeypatch):
+ res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR)
+ diag = e.Diagnostic(res)
+
+ to_check: pq.DiagnosticField
+ checked: List[pq.DiagnosticField] = []
+
+ def check_val(self, v):
+ nonlocal to_check
+ assert to_check == v
+ checked.append(v)
+ return None
+
+ monkeypatch.setattr(e.Diagnostic, "_error_message", check_val)
+
+ for to_check in pq.DiagnosticField:
+ getattr(diag, to_check.name.lower())
+
+ assert len(checked) == len(pq.DiagnosticField)
+
+
+def test_diag_attr_values(conn):
+ if is_crdb(conn):
+ conn.execute("set experimental_enable_temp_tables = 'on'")
+ conn.execute(
+ """
+ create temp table test_exc (
+ data int constraint chk_eq1 check (data = 1)
+ )"""
+ )
+ with pytest.raises(e.Error) as exc:
+ conn.execute("insert into test_exc values(2)")
+ diag = exc.value.diag
+ assert diag.sqlstate == "23514"
+ assert diag.constraint_name == "chk_eq1"
+ if not is_crdb(conn):
+ assert diag.table_name == "test_exc"
+ assert diag.schema_name and diag.schema_name[:7] == "pg_temp"
+ assert diag.severity_nonlocalized == "ERROR"
+
+
+@pytest.mark.crdb_skip("do")
+@pytest.mark.parametrize("enc", ["utf8", "latin9"])
+def test_diag_encoding(conn, enc):
+ msgs = []
+ conn.pgconn.exec_(b"set client_min_messages to notice")
+ conn.add_notice_handler(lambda diag: msgs.append(diag.message_primary))
+ conn.execute(f"set client_encoding to {enc}")
+ cur = conn.cursor()
+ cur.execute("do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql")
+ assert msgs == [f"hello {eur}"]
+
+
+@pytest.mark.crdb_skip("do")
+@pytest.mark.parametrize("enc", ["utf8", "latin9"])
+def test_error_encoding(conn, enc):
+ with conn.transaction():
+ conn.execute(f"set client_encoding to {enc}")
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute(
+ """
+ do $$begin
+ execute format('insert into "%s" values (1)', chr(8364));
+ end$$ language plpgsql;
+ """
+ )
+
+ diag = excinfo.value.diag
+ assert diag.message_primary and f'"{eur}"' in diag.message_primary
+ assert diag.sqlstate == "42P01"
+
+
+def test_exception_class(conn):
+ cur = conn.cursor()
+
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select * from nonexist")
+
+ assert isinstance(excinfo.value, e.UndefinedTable)
+ assert isinstance(excinfo.value, conn.ProgrammingError)
+
+
+def test_exception_class_fallback(conn):
+ cur = conn.cursor()
+
+ x = e._sqlcodes.pop("42P01")
+ try:
+ with pytest.raises(e.Error) as excinfo:
+ cur.execute("select * from nonexist")
+ finally:
+ e._sqlcodes["42P01"] = x
+
+ assert type(excinfo.value) is conn.ProgrammingError
+
+
+def test_lookup():
+ assert e.lookup("42P01") is e.UndefinedTable
+ assert e.lookup("42p01") is e.UndefinedTable
+ assert e.lookup("UNDEFINED_TABLE") is e.UndefinedTable
+ assert e.lookup("undefined_table") is e.UndefinedTable
+
+ with pytest.raises(KeyError):
+ e.lookup("XXXXX")
+
+
+def test_error_sqlstate():
+ assert e.Error.sqlstate is None
+ assert e.ProgrammingError.sqlstate is None
+ assert e.UndefinedTable.sqlstate == "42P01"
+
+
+def test_error_pickle(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select 1 from wat")
+
+ exc = pickle.loads(pickle.dumps(excinfo.value))
+ assert isinstance(exc, e.UndefinedTable)
+ assert exc.diag.sqlstate == "42P01"
+
+
+def test_diag_pickle(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select 1 from wat")
+
+ diag1 = excinfo.value.diag
+ diag2 = pickle.loads(pickle.dumps(diag1))
+
+ assert isinstance(diag2, type(diag1))
+ for f in pq.DiagnosticField:
+ assert getattr(diag1, f.name.lower()) == getattr(diag2, f.name.lower())
+
+ assert diag2.sqlstate == "42P01"
+
+
+@pytest.mark.slow
+def test_diag_survives_cursor(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.Error) as exc:
+ cur.execute("select * from nosuchtable")
+
+ diag = exc.value.diag
+ del exc
+ w = ref(cur)
+ del cur
+ gc_collect()
+ assert w() is None
+ assert diag.sqlstate == "42P01"
+
+
+def test_diag_independent(conn):
+ conn.autocommit = True
+ cur = conn.cursor()
+
+ with pytest.raises(e.Error) as exc1:
+ cur.execute("l'acqua e' poca e 'a papera nun galleggia")
+
+ with pytest.raises(e.Error) as exc2:
+ cur.execute("select level from water where ducks > 1")
+
+ assert exc1.value.diag.sqlstate == "42601"
+ assert exc2.value.diag.sqlstate == "42P01"
+
+
+@pytest.mark.crdb_skip("deferrable")
+def test_diag_from_commit(conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create temp table test_deferred (
+ data int primary key,
+ ref int references test_deferred (data)
+ deferrable initially deferred)
+ """
+ )
+ cur.execute("insert into test_deferred values (1,2)")
+ with pytest.raises(e.Error) as exc:
+ conn.commit()
+
+ assert exc.value.diag.sqlstate == "23503"
+
+
+@pytest.mark.asyncio
+@pytest.mark.crdb_skip("deferrable")
+async def test_diag_from_commit_async(aconn):
+ cur = aconn.cursor()
+ await cur.execute(
+ """
+ create temp table test_deferred (
+ data int primary key,
+ ref int references test_deferred (data)
+ deferrable initially deferred)
+ """
+ )
+ await cur.execute("insert into test_deferred values (1,2)")
+ with pytest.raises(e.Error) as exc:
+ await aconn.commit()
+
+ assert exc.value.diag.sqlstate == "23503"
+
+
+def test_query_context(conn):
+ with pytest.raises(e.Error) as exc:
+ conn.execute("select * from wat")
+
+ s = str(exc.value)
+ if not is_crdb(conn):
+ assert "from wat" in s, s
+ assert exc.value.diag.message_primary
+ assert exc.value.diag.message_primary in s
+ assert "ERROR" not in s
+ assert not s.endswith("\n")
+
+
+@pytest.mark.crdb_skip("do")
+def test_unknown_sqlstate(conn):
+ code = "PXX99"
+ with pytest.raises(KeyError):
+ e.lookup(code)
+
+ with pytest.raises(e.ProgrammingError) as excinfo:
+ conn.execute(
+ f"""
+ do $$begin
+ raise exception 'made up code' using errcode = '{code}';
+ end$$ language plpgsql
+ """
+ )
+ exc = excinfo.value
+ assert exc.diag.sqlstate == code
+ assert exc.sqlstate == code
+ # Survives pickling too
+ pexc = pickle.loads(pickle.dumps(exc))
+ assert pexc.sqlstate == code
+
+
+def test_pgconn_error(conn_cls):
+ with pytest.raises(psycopg.OperationalError) as excinfo:
+ conn_cls.connect("dbname=nosuchdb")
+
+ exc = excinfo.value
+ assert exc.pgconn
+ assert exc.pgconn.db == b"nosuchdb"
+
+
+def test_pgconn_error_pickle(conn_cls):
+ with pytest.raises(psycopg.OperationalError) as excinfo:
+ conn_cls.connect("dbname=nosuchdb")
+
+ exc = pickle.loads(pickle.dumps(excinfo.value))
+ assert exc.pgconn is None
+
+
+def test_pgresult(conn):
+ with pytest.raises(e.DatabaseError) as excinfo:
+ conn.execute("select 1 from wat")
+
+ exc = excinfo.value
+ assert exc.pgresult
+ assert exc.pgresult.error_field(pq.DiagnosticField.SQLSTATE) == b"42P01"
+
+
+def test_pgresult_pickle(conn):
+ with pytest.raises(e.DatabaseError) as excinfo:
+ conn.execute("select 1 from wat")
+
+ exc = pickle.loads(pickle.dumps(excinfo.value))
+ assert exc.pgresult is None
+ assert exc.diag.sqlstate == "42P01"
+
+
+def test_blank_sqlstate(conn):
+ assert e.get_base_exception("") is e.DatabaseError
diff --git a/tests/test_generators.py b/tests/test_generators.py
new file mode 100644
index 0000000..8aba73f
--- /dev/null
+++ b/tests/test_generators.py
@@ -0,0 +1,156 @@
+from collections import deque
+from functools import partial
+from typing import List
+
+import pytest
+
+import psycopg
+from psycopg import waiting
+from psycopg import pq
+
+
+@pytest.fixture
+def pipeline(pgconn):
+ nb, pgconn.nonblocking = pgconn.nonblocking, True
+ assert pgconn.nonblocking
+ pgconn.enter_pipeline_mode()
+ yield
+ if pgconn.pipeline_status:
+ pgconn.exit_pipeline_mode()
+ pgconn.nonblocking = nb
+
+
+def _run_pipeline_communicate(pgconn, generators, commands, expected_statuses):
+ actual_statuses: List[pq.ExecStatus] = []
+ while len(actual_statuses) != len(expected_statuses):
+ if commands:
+ gen = generators.pipeline_communicate(pgconn, commands)
+ results = waiting.wait(gen, pgconn.socket)
+ for (result,) in results:
+ actual_statuses.append(result.status)
+ else:
+ gen = generators.fetch_many(pgconn)
+ results = waiting.wait(gen, pgconn.socket)
+ for result in results:
+ actual_statuses.append(result.status)
+
+ assert actual_statuses == expected_statuses
+
+
+@pytest.mark.pipeline
+def test_pipeline_communicate_multi_pipeline(pgconn, pipeline, generators):
+ commands = deque(
+ [
+ partial(pgconn.send_query_params, b"select 1", None),
+ pgconn.pipeline_sync,
+ partial(pgconn.send_query_params, b"select 2", None),
+ pgconn.pipeline_sync,
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+
+
+@pytest.mark.pipeline
+def test_pipeline_communicate_no_sync(pgconn, pipeline, generators):
+ numqueries = 10
+ commands = deque(
+ [partial(pgconn.send_query_params, b"select repeat('xyzxz', 12)", None)]
+ * numqueries
+ + [pgconn.send_flush_request]
+ )
+ expected_statuses = [pq.ExecStatus.TUPLES_OK] * numqueries
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+
+
+@pytest.fixture
+def pipeline_demo(pgconn):
+ assert pgconn.pipeline_status == 0
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(
+ b"CREATE UNLOGGED TABLE pg_pipeline(" b" id serial primary key, itemno integer)"
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ yield "pg_pipeline"
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+
+# TODOCRDB: 1 doesn't get rolled back. Open a ticket?
+@pytest.mark.pipeline
+@pytest.mark.crdb("skip", reason="pipeline aborted")
+def test_pipeline_communicate_abort(pgconn, pipeline_demo, pipeline, generators):
+ insert_sql = b"insert into pg_pipeline(itemno) values ($1)"
+ commands = deque(
+ [
+ partial(pgconn.send_query_params, insert_sql, [b"1"]),
+ partial(pgconn.send_query_params, b"select no_such_function(1)", None),
+ partial(pgconn.send_query_params, insert_sql, [b"2"]),
+ pgconn.pipeline_sync,
+ partial(pgconn.send_query_params, insert_sql, [b"3"]),
+ pgconn.pipeline_sync,
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.COMMAND_OK,
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_SYNC,
+ pq.ExecStatus.COMMAND_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+ pgconn.exit_pipeline_mode()
+ res = pgconn.exec_(b"select itemno from pg_pipeline order by itemno")
+ assert res.ntuples == 1
+ assert res.get_value(0, 0) == b"3"
+
+
+@pytest.fixture
+def pipeline_uniqviol(pgconn):
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ assert pgconn.pipeline_status == 0
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline_uniqviol")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(
+ b"CREATE UNLOGGED TABLE pg_pipeline_uniqviol("
+ b" id bigint primary key, idata bigint)"
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(b"BEGIN")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.prepare(
+ b"insertion",
+ b"insert into pg_pipeline_uniqviol values ($1, $2) returning id",
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ return "pg_pipeline_uniqviol"
+
+
+def test_pipeline_communicate_uniqviol(pgconn, pipeline_uniqviol, pipeline, generators):
+ commands = deque(
+ [
+ partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"2", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"3", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"4", b"2"]),
+ partial(pgconn.send_query_params, b"commit", None),
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
diff --git a/tests/test_module.py b/tests/test_module.py
new file mode 100644
index 0000000..794ef0f
--- /dev/null
+++ b/tests/test_module.py
@@ -0,0 +1,57 @@
+import pytest
+
+from psycopg._cmodule import _psycopg
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, want_conninfo",
+ [
+ ((), {}, ""),
+ (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"),
+ ((), {"port": 15432}, "port=15432"),
+ ((), {"user": "foo", "dbname": None}, "user=foo"),
+ ],
+)
+def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
+ # Check the main args passing from psycopg.connect to the conn generator
+ # Details of the params manipulation are in test_conninfo.
+ import psycopg.connection
+
+ orig_connect = psycopg.connection.connect # type: ignore
+
+ got_conninfo = None
+
+ def mock_connect(conninfo):
+ nonlocal got_conninfo
+ got_conninfo = conninfo
+ return orig_connect(dsn)
+
+ monkeypatch.setattr(psycopg.connection, "connect", mock_connect)
+
+ conn = psycopg.connect(*args, **kwargs)
+ assert got_conninfo == want_conninfo
+ conn.close()
+
+
+def test_version(mypy):
+ cp = mypy.run_on_source(
+ """\
+from psycopg import __version__
+assert __version__
+"""
+ )
+ assert not cp.stdout
+
+
+@pytest.mark.skipif(_psycopg is None, reason="C module test")
+def test_version_c(mypy):
+ # can be psycopg_c, psycopg_binary
+ cpackage = _psycopg.__name__.split(".")[0]
+
+ cp = mypy.run_on_source(
+ f"""\
+from {cpackage} import __version__
+assert __version__
+"""
+ )
+ assert not cp.stdout
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 0000000..56fe598
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,577 @@
+import logging
+import concurrent.futures
+from typing import Any
+from operator import attrgetter
+from itertools import groupby
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import errors as e
+
+pytestmark = [
+ pytest.mark.pipeline,
+ pytest.mark.skipif("not psycopg.Pipeline.is_supported()"),
+]
+
+pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted")
+
+
+def test_repr(conn):
+ with conn.pipeline() as p:
+ assert "psycopg.Pipeline" in repr(p)
+ assert "[IDLE, pipeline=ON]" in repr(p)
+
+ conn.close()
+ assert "[BAD]" in repr(p)
+
+
+def test_connection_closed(conn):
+ conn.close()
+ with pytest.raises(e.OperationalError):
+ with conn.pipeline():
+ pass
+
+
+def test_pipeline_status(conn: psycopg.Connection[Any]) -> None:
+ assert conn._pipeline is None
+ with conn.pipeline() as p:
+ assert conn._pipeline is p
+ assert p.status == pq.PipelineStatus.ON
+ assert p.status == pq.PipelineStatus.OFF
+ assert not conn._pipeline
+
+
+def test_pipeline_reenter(conn: psycopg.Connection[Any]) -> None:
+ with conn.pipeline() as p1:
+ with conn.pipeline() as p2:
+ assert p2 is p1
+ assert p1.status == pq.PipelineStatus.ON
+ assert p2 is p1
+ assert p2.status == pq.PipelineStatus.ON
+ assert conn._pipeline is None
+ assert p1.status == pq.PipelineStatus.OFF
+
+
+def test_pipeline_broken_conn_exit(conn: psycopg.Connection[Any]) -> None:
+ with pytest.raises(e.OperationalError):
+ with conn.pipeline():
+ conn.execute("select 1")
+ conn.close()
+ closed = True
+
+ assert closed
+
+
+def test_pipeline_exit_error_noclobber(conn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ with conn.pipeline():
+ conn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 1
+
+
+def test_pipeline_exit_error_noclobber_nested(conn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ with conn.pipeline():
+ with conn.pipeline():
+ conn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 2
+
+
+def test_pipeline_exit_sync_trace(conn, trace):
+ t = trace.trace(conn)
+ with conn.pipeline():
+ pass
+ conn.close()
+ assert len([i for i in t if i.type == "Sync"]) == 1
+
+
+def test_pipeline_nested_sync_trace(conn, trace):
+ t = trace.trace(conn)
+ with conn.pipeline():
+ with conn.pipeline():
+ pass
+ conn.close()
+ assert len([i for i in t if i.type == "Sync"]) == 2
+
+
+def test_cursor_stream(conn):
+ with conn.pipeline(), conn.cursor() as cur:
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.stream("select 1").__next__()
+
+
+def test_server_cursor(conn):
+ with conn.cursor(name="pipeline") as cur, conn.pipeline():
+ with pytest.raises(psycopg.NotSupportedError):
+ cur.execute("select 1")
+
+
+def test_cannot_insert_multiple_commands(conn):
+ with pytest.raises((e.SyntaxError, e.InvalidPreparedStatementDefinition)):
+ with conn.pipeline():
+ conn.execute("select 1; select 2")
+
+
+def test_copy(conn):
+ with conn.pipeline():
+ cur = conn.cursor()
+ with pytest.raises(e.NotSupportedError):
+ with cur.copy("copy (select 1) to stdout"):
+ pass
+
+
+def test_pipeline_processed_at_exit(conn):
+ with conn.cursor() as cur:
+ with conn.pipeline() as p:
+ cur.execute("select 1")
+
+ assert len(p.result_queue) == 1
+
+ assert cur.fetchone() == (1,)
+
+
+def test_pipeline_errors_processed_at_exit(conn):
+ conn.autocommit = True
+ with pytest.raises(e.UndefinedTable):
+ with conn.pipeline():
+ conn.execute("select * from nosuchtable")
+ conn.execute("create table voila ()")
+ cur = conn.execute(
+ "select count(*) from pg_tables where tablename = %s", ("voila",)
+ )
+ (count,) = cur.fetchone()
+ assert count == 0
+
+
+def test_pipeline(conn):
+ with conn.pipeline() as p:
+ c1 = conn.cursor()
+ c2 = conn.cursor()
+ c1.execute("select 1")
+ c2.execute("select 2")
+
+ assert len(p.result_queue) == 2
+
+ (r1,) = c1.fetchone()
+ assert r1 == 1
+
+ (r2,) = c2.fetchone()
+ assert r2 == 2
+
+
+def test_autocommit(conn):
+ conn.autocommit = True
+ with conn.pipeline(), conn.cursor() as c:
+ c.execute("select 1")
+
+ (r,) = c.fetchone()
+ assert r == 1
+
+
+def test_pipeline_aborted(conn):
+ conn.autocommit = True
+ with conn.pipeline() as p:
+ c1 = conn.execute("select 1")
+ with pytest.raises(e.UndefinedTable):
+ conn.execute("select * from doesnotexist").fetchone()
+ with pytest.raises(e.PipelineAborted):
+ conn.execute("select 'aborted'").fetchone()
+ # Sync restore the connection in usable state.
+ p.sync()
+ c2 = conn.execute("select 2")
+
+ (r,) = c1.fetchone()
+ assert r == 1
+
+ (r,) = c2.fetchone()
+ assert r == 2
+
+
+def test_pipeline_commit_aborted(conn):
+ with pytest.raises((e.UndefinedColumn, e.OperationalError)):
+ with conn.pipeline():
+ conn.execute("select error")
+ conn.execute("create table voila ()")
+ conn.commit()
+
+
+def test_sync_syncs_results(conn):
+ with conn.pipeline() as p:
+ cur = conn.execute("select 1")
+ assert cur.statusmessage is None
+ p.sync()
+ assert cur.statusmessage == "SELECT 1"
+
+
+def test_sync_syncs_errors(conn):
+ conn.autocommit = True
+ with conn.pipeline() as p:
+ conn.execute("select 1 from nosuchtable")
+ with pytest.raises(e.UndefinedTable):
+ p.sync()
+
+
+@pipeline_aborted
+def test_errors_raised_on_commit(conn):
+ with conn.pipeline():
+ conn.execute("select 1 from nosuchtable")
+ with pytest.raises(e.UndefinedTable):
+ conn.commit()
+ conn.rollback()
+ cur1 = conn.execute("select 1")
+ cur2 = conn.execute("select 2")
+
+ assert cur1.fetchone() == (1,)
+ assert cur2.fetchone() == (2,)
+
+
+@pytest.mark.flakey("assert fails randomly in CI blocking release")
+def test_errors_raised_on_transaction_exit(conn):
+ here = False
+ with conn.pipeline():
+ with pytest.raises(e.UndefinedTable):
+ with conn.transaction():
+ conn.execute("select 1 from nosuchtable")
+ here = True
+ cur1 = conn.execute("select 1")
+ assert here
+ cur2 = conn.execute("select 2")
+
+ assert cur1.fetchone() == (1,)
+ assert cur2.fetchone() == (2,)
+
+
+@pytest.mark.flakey("assert fails randomly in CI blocking release")
+def test_errors_raised_on_nested_transaction_exit(conn):
+ here = False
+ with conn.pipeline():
+ with conn.transaction():
+ with pytest.raises(e.UndefinedTable):
+ with conn.transaction():
+ conn.execute("select 1 from nosuchtable")
+ here = True
+ cur1 = conn.execute("select 1")
+ assert here
+ cur2 = conn.execute("select 2")
+
+ assert cur1.fetchone() == (1,)
+ assert cur2.fetchone() == (2,)
+
+
+def test_implicit_transaction(conn):
+ conn.autocommit = True
+ with conn.pipeline():
+ assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+ conn.execute("select 'before'")
+ # Transaction is ACTIVE because previous command is not completed
+ # since we have not fetched its results.
+ assert conn.pgconn.transaction_status == pq.TransactionStatus.ACTIVE
+ # Upon entering the nested pipeline through "with transaction():", a
+ # sync() is emitted to restore the transaction state to IDLE, as
+ # expected to emit a BEGIN.
+ with conn.transaction():
+ conn.execute("select 'tx'")
+ cur = conn.execute("select 'after'")
+ assert cur.fetchone() == ("after",)
+
+
+@pytest.mark.crdb_skip("deferrable")
+def test_error_on_commit(conn):
+ conn.execute(
+ """
+ drop table if exists selfref;
+ create table selfref (
+ x serial primary key,
+ y int references selfref (x) deferrable initially deferred)
+ """
+ )
+ conn.commit()
+
+ with conn.pipeline():
+ conn.execute("insert into selfref (y) values (-1)")
+ with pytest.raises(e.ForeignKeyViolation):
+ conn.commit()
+ cur1 = conn.execute("select 1")
+ cur2 = conn.execute("select 2")
+
+ assert cur1.fetchone() == (1,)
+ assert cur2.fetchone() == (2,)
+
+
+def test_fetch_no_result(conn):
+ with conn.pipeline():
+ cur = conn.cursor()
+ with pytest.raises(e.ProgrammingError):
+ cur.fetchone()
+
+
+def test_executemany(conn):
+ conn.autocommit = True
+ conn.execute("drop table if exists execmanypipeline")
+ conn.execute(
+ "create unlogged table execmanypipeline ("
+ " id serial primary key, num integer)"
+ )
+ with conn.pipeline(), conn.cursor() as cur:
+ cur.executemany(
+ "insert into execmanypipeline(num) values (%s) returning num",
+ [(10,), (20,)],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.fetchone() == (10,)
+ assert cur.nextset()
+ assert cur.fetchone() == (20,)
+ assert cur.nextset() is None
+
+
+def test_executemany_no_returning(conn):
+ conn.autocommit = True
+ conn.execute("drop table if exists execmanypipelinenoreturning")
+ conn.execute(
+ "create unlogged table execmanypipelinenoreturning ("
+ " id serial primary key, num integer)"
+ )
+ with conn.pipeline(), conn.cursor() as cur:
+ cur.executemany(
+ "insert into execmanypipelinenoreturning(num) values (%s)",
+ [(10,), (20,)],
+ returning=False,
+ )
+ with pytest.raises(e.ProgrammingError, match="no result available"):
+ cur.fetchone()
+ assert cur.nextset() is None
+ with pytest.raises(e.ProgrammingError, match="no result available"):
+ cur.fetchone()
+ assert cur.nextset() is None
+
+
+@pytest.mark.crdb("skip", reason="temp tables")
+def test_executemany_trace(conn, trace):
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("create temp table trace (id int)")
+ t = trace.trace(conn)
+ with conn.pipeline():
+ cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)])
+ cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)])
+ conn.close()
+ items = list(t)
+ assert items[-1].type == "Terminate"
+ del items[-1]
+ roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ assert roundtrips == ["F", "B"]
+ assert len([i for i in items if i.type == "Sync"]) == 1
+
+
+@pytest.mark.crdb("skip", reason="temp tables")
+def test_executemany_trace_returning(conn, trace):
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("create temp table trace (id int)")
+ t = trace.trace(conn)
+ with conn.pipeline():
+ cur.executemany(
+ "insert into trace (id) values (%s)", [(10,), (20,)], returning=True
+ )
+ cur.executemany(
+ "insert into trace (id) values (%s)", [(10,), (20,)], returning=True
+ )
+ conn.close()
+ items = list(t)
+ assert items[-1].type == "Terminate"
+ del items[-1]
+ roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ assert roundtrips == ["F", "B"] * 3
+ assert items[-2].direction == "F" # last 2 items are F B
+ assert len([i for i in items if i.type == "Sync"]) == 1
+
+
+def test_prepared(conn):
+ conn.autocommit = True
+ with conn.pipeline():
+ c1 = conn.execute("select %s::int", [10], prepare=True)
+ c2 = conn.execute(
+ "select count(*) from pg_prepared_statements where name != ''"
+ )
+
+ (r,) = c1.fetchone()
+ assert r == 10
+
+ (r,) = c2.fetchone()
+ assert r == 1
+
+
+def test_auto_prepare(conn):
+ conn.autocommit = True
+ conn.prepared_threshold = 5
+ with conn.pipeline():
+ cursors = [
+ conn.execute("select count(*) from pg_prepared_statements where name != ''")
+ for i in range(10)
+ ]
+
+ assert len(conn._prepared._names) == 1
+
+ res = [c.fetchone()[0] for c in cursors]
+ assert res == [0] * 5 + [1] * 5
+
+
+def test_transaction(conn):
+ notices = []
+ conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ with conn.pipeline():
+ with conn.transaction():
+ cur = conn.execute("select 'tx'")
+
+ (r,) = cur.fetchone()
+ assert r == "tx"
+
+ with conn.transaction():
+ cur = conn.execute("select 'rb'")
+ raise psycopg.Rollback()
+
+ (r,) = cur.fetchone()
+ assert r == "rb"
+
+ assert not notices
+
+
+def test_transaction_nested(conn):
+ with conn.pipeline():
+ with conn.transaction():
+ outer = conn.execute("select 'outer'")
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction():
+ inner = conn.execute("select 'inner'")
+ 1 / 0
+
+ (r,) = outer.fetchone()
+ assert r == "outer"
+ (r,) = inner.fetchone()
+ assert r == "inner"
+
+
+def test_transaction_nested_no_statement(conn):
+ with conn.pipeline():
+ with conn.transaction():
+ with conn.transaction():
+ cur = conn.execute("select 1")
+
+ (r,) = cur.fetchone()
+ assert r == 1
+
+
+def test_outer_transaction(conn):
+ with conn.transaction():
+ conn.execute("drop table if exists outertx")
+ with conn.transaction():
+ with conn.pipeline():
+ conn.execute("create table outertx as (select 1)")
+ cur = conn.execute("select * from outertx")
+ (r,) = cur.fetchone()
+ assert r == 1
+ cur = conn.execute("select count(*) from pg_tables where tablename = 'outertx'")
+ assert cur.fetchone()[0] == 1
+
+
+def test_outer_transaction_error(conn):
+ with conn.transaction():
+ with pytest.raises((e.UndefinedColumn, e.OperationalError)):
+ with conn.pipeline():
+ conn.execute("select error")
+ conn.execute("create table voila ()")
+
+
+def test_rollback_explicit(conn):
+ conn.autocommit = True
+ with conn.pipeline():
+ with pytest.raises(e.DivisionByZero):
+ cur = conn.execute("select 1 / %s", [0])
+ cur.fetchone()
+ conn.rollback()
+ conn.execute("select 1")
+
+
+def test_rollback_transaction(conn):
+ conn.autocommit = True
+ with pytest.raises(e.DivisionByZero):
+ with conn.pipeline():
+ with conn.transaction():
+ cur = conn.execute("select 1 / %s", [0])
+ cur.fetchone()
+ conn.execute("select 1")
+
+
+def test_message_0x33(conn):
+ # https://github.com/psycopg/psycopg/issues/314
+ notices = []
+ conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ conn.autocommit = True
+ with conn.pipeline():
+ cur = conn.execute("select 'test'")
+ assert cur.fetchone() == ("test",)
+
+ assert not notices
+
+
+def test_transaction_state_implicit_begin(conn, trace):
+ # Regression test to ensure that the transaction state is correct after
+ # the implicit BEGIN statement (in non-autocommit mode).
+ notices = []
+ conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+ t = trace.trace(conn)
+ with conn.pipeline():
+ conn.execute("select 'x'").fetchone()
+ conn.execute("select 'y'")
+ assert not notices
+ assert [
+ e.content[0] for e in t if e.type == "Parse" and b"BEGIN" in e.content[0]
+ ] == [b' "" "BEGIN" 0']
+
+
+def test_concurrency(conn):
+ with conn.transaction():
+ conn.execute("drop table if exists pipeline_concurrency")
+ conn.execute("drop table if exists accessed")
+ with conn.transaction():
+ conn.execute(
+ "create unlogged table pipeline_concurrency ("
+ " id serial primary key,"
+ " value integer"
+ ")"
+ )
+ conn.execute("create unlogged table accessed as (select now() as value)")
+
+ def update(value):
+ cur = conn.execute(
+ "insert into pipeline_concurrency(value) values (%s) returning value",
+ (value,),
+ )
+ conn.execute("update accessed set value = now()")
+ return cur
+
+ conn.autocommit = True
+
+ (before,) = conn.execute("select value from accessed").fetchone()
+
+ values = range(1, 10)
+ with conn.pipeline():
+ with concurrent.futures.ThreadPoolExecutor() as e:
+ cursors = e.map(update, values, timeout=len(values))
+ assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
+
+ (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone()
+ assert s == sum(values)
+ (after,) = conn.execute("select value from accessed").fetchone()
+ assert after > before
diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py
new file mode 100644
index 0000000..2e743cf
--- /dev/null
+++ b/tests/test_pipeline_async.py
@@ -0,0 +1,586 @@
+import asyncio
+import logging
+from typing import Any
+from operator import attrgetter
+from itertools import groupby
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import errors as e
+
+from .test_pipeline import pipeline_aborted
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.pipeline,
+ pytest.mark.skipif("not psycopg.AsyncPipeline.is_supported()"),
+]
+
+
+async def test_repr(aconn):
+ async with aconn.pipeline() as p:
+ assert "psycopg.AsyncPipeline" in repr(p)
+ assert "[IDLE, pipeline=ON]" in repr(p)
+
+ await aconn.close()
+ assert "[BAD]" in repr(p)
+
+
+async def test_connection_closed(aconn):
+ await aconn.close()
+ with pytest.raises(e.OperationalError):
+ async with aconn.pipeline():
+ pass
+
+
+async def test_pipeline_status(aconn: psycopg.AsyncConnection[Any]) -> None:
+ assert aconn._pipeline is None
+ async with aconn.pipeline() as p:
+ assert aconn._pipeline is p
+ assert p.status == pq.PipelineStatus.ON
+ assert p.status == pq.PipelineStatus.OFF
+ assert not aconn._pipeline
+
+
+async def test_pipeline_reenter(aconn: psycopg.AsyncConnection[Any]) -> None:
+ async with aconn.pipeline() as p1:
+ async with aconn.pipeline() as p2:
+ assert p2 is p1
+ assert p1.status == pq.PipelineStatus.ON
+ assert p2 is p1
+ assert p2.status == pq.PipelineStatus.ON
+ assert aconn._pipeline is None
+ assert p1.status == pq.PipelineStatus.OFF
+
+
+async def test_pipeline_broken_conn_exit(aconn: psycopg.AsyncConnection[Any]) -> None:
+ with pytest.raises(e.OperationalError):
+ async with aconn.pipeline():
+ await aconn.execute("select 1")
+ await aconn.close()
+ closed = True
+
+ assert closed
+
+
+async def test_pipeline_exit_error_noclobber(aconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.pipeline():
+ await aconn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 1
+
+
+async def test_pipeline_exit_error_noclobber_nested(aconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.pipeline():
+ async with aconn.pipeline():
+ await aconn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 2
+
+
+async def test_pipeline_exit_sync_trace(aconn, trace):
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ pass
+ await aconn.close()
+ assert len([i for i in t if i.type == "Sync"]) == 1
+
+
+async def test_pipeline_nested_sync_trace(aconn, trace):
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ async with aconn.pipeline():
+ pass
+ await aconn.close()
+ assert len([i for i in t if i.type == "Sync"]) == 2
+
+
+async def test_cursor_stream(aconn):
+ async with aconn.pipeline(), aconn.cursor() as cur:
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.stream("select 1").__anext__()
+
+
+async def test_server_cursor(aconn):
+ async with aconn.cursor(name="pipeline") as cur, aconn.pipeline():
+ with pytest.raises(psycopg.NotSupportedError):
+ await cur.execute("select 1")
+
+
+async def test_cannot_insert_multiple_commands(aconn):
+ with pytest.raises((e.SyntaxError, e.InvalidPreparedStatementDefinition)):
+ async with aconn.pipeline():
+ await aconn.execute("select 1; select 2")
+
+
+async def test_copy(aconn):
+ async with aconn.pipeline():
+ cur = aconn.cursor()
+ with pytest.raises(e.NotSupportedError):
+ async with cur.copy("copy (select 1) to stdout") as copy:
+ await copy.read()
+
+
+async def test_pipeline_processed_at_exit(aconn):
+ async with aconn.cursor() as cur:
+ async with aconn.pipeline() as p:
+ await cur.execute("select 1")
+
+ assert len(p.result_queue) == 1
+
+ assert await cur.fetchone() == (1,)
+
+
+async def test_pipeline_errors_processed_at_exit(aconn):
+ await aconn.set_autocommit(True)
+ with pytest.raises(e.UndefinedTable):
+ async with aconn.pipeline():
+ await aconn.execute("select * from nosuchtable")
+ await aconn.execute("create table voila ()")
+ cur = await aconn.execute(
+ "select count(*) from pg_tables where tablename = %s", ("voila",)
+ )
+ (count,) = await cur.fetchone()
+ assert count == 0
+
+
+async def test_pipeline(aconn):
+ async with aconn.pipeline() as p:
+ c1 = aconn.cursor()
+ c2 = aconn.cursor()
+ await c1.execute("select 1")
+ await c2.execute("select 2")
+
+ assert len(p.result_queue) == 2
+
+ (r1,) = await c1.fetchone()
+ assert r1 == 1
+
+ (r2,) = await c2.fetchone()
+ assert r2 == 2
+
+
+async def test_autocommit(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline(), aconn.cursor() as c:
+ await c.execute("select 1")
+
+ (r,) = await c.fetchone()
+ assert r == 1
+
+
+async def test_pipeline_aborted(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline() as p:
+ c1 = await aconn.execute("select 1")
+ with pytest.raises(e.UndefinedTable):
+ await (await aconn.execute("select * from doesnotexist")).fetchone()
+ with pytest.raises(e.PipelineAborted):
+ await (await aconn.execute("select 'aborted'")).fetchone()
+ # Sync restore the connection in usable state.
+ await p.sync()
+ c2 = await aconn.execute("select 2")
+
+ (r,) = await c1.fetchone()
+ assert r == 1
+
+ (r,) = await c2.fetchone()
+ assert r == 2
+
+
+async def test_pipeline_commit_aborted(aconn):
+ with pytest.raises((e.UndefinedColumn, e.OperationalError)):
+ async with aconn.pipeline():
+ await aconn.execute("select error")
+ await aconn.execute("create table voila ()")
+ await aconn.commit()
+
+
+async def test_sync_syncs_results(aconn):
+ async with aconn.pipeline() as p:
+ cur = await aconn.execute("select 1")
+ assert cur.statusmessage is None
+ await p.sync()
+ assert cur.statusmessage == "SELECT 1"
+
+
+async def test_sync_syncs_errors(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline() as p:
+ await aconn.execute("select 1 from nosuchtable")
+ with pytest.raises(e.UndefinedTable):
+ await p.sync()
+
+
+@pipeline_aborted
+async def test_errors_raised_on_commit(aconn):
+ async with aconn.pipeline():
+ await aconn.execute("select 1 from nosuchtable")
+ with pytest.raises(e.UndefinedTable):
+ await aconn.commit()
+ await aconn.rollback()
+ cur1 = await aconn.execute("select 1")
+ cur2 = await aconn.execute("select 2")
+
+ assert await cur1.fetchone() == (1,)
+ assert await cur2.fetchone() == (2,)
+
+
+@pytest.mark.flakey("assert fails randomly in CI blocking release")
+async def test_errors_raised_on_transaction_exit(aconn):
+ here = False
+ async with aconn.pipeline():
+ with pytest.raises(e.UndefinedTable):
+ async with aconn.transaction():
+ await aconn.execute("select 1 from nosuchtable")
+ here = True
+ cur1 = await aconn.execute("select 1")
+ assert here
+ cur2 = await aconn.execute("select 2")
+
+ assert await cur1.fetchone() == (1,)
+ assert await cur2.fetchone() == (2,)
+
+
+@pytest.mark.flakey("assert fails randomly in CI blocking release")
+async def test_errors_raised_on_nested_transaction_exit(aconn):
+ here = False
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ with pytest.raises(e.UndefinedTable):
+ async with aconn.transaction():
+ await aconn.execute("select 1 from nosuchtable")
+ here = True
+ cur1 = await aconn.execute("select 1")
+ assert here
+ cur2 = await aconn.execute("select 2")
+
+ assert await cur1.fetchone() == (1,)
+ assert await cur2.fetchone() == (2,)
+
+
+async def test_implicit_transaction(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+ await aconn.execute("select 'before'")
+ # Transaction is ACTIVE because previous command is not completed
+ # since we have not fetched its results.
+ assert aconn.pgconn.transaction_status == pq.TransactionStatus.ACTIVE
+ # Upon entering the nested pipeline through "with transaction():", a
+ # sync() is emitted to restore the transaction state to IDLE, as
+ # expected to emit a BEGIN.
+ async with aconn.transaction():
+ await aconn.execute("select 'tx'")
+ cur = await aconn.execute("select 'after'")
+ assert await cur.fetchone() == ("after",)
+
+
+@pytest.mark.crdb_skip("deferrable")
+async def test_error_on_commit(aconn):
+ await aconn.execute(
+ """
+ drop table if exists selfref;
+ create table selfref (
+ x serial primary key,
+ y int references selfref (x) deferrable initially deferred)
+ """
+ )
+ await aconn.commit()
+
+ async with aconn.pipeline():
+ await aconn.execute("insert into selfref (y) values (-1)")
+ with pytest.raises(e.ForeignKeyViolation):
+ await aconn.commit()
+ cur1 = await aconn.execute("select 1")
+ cur2 = await aconn.execute("select 2")
+
+ assert (await cur1.fetchone()) == (1,)
+ assert (await cur2.fetchone()) == (2,)
+
+
+async def test_fetch_no_result(aconn):
+ async with aconn.pipeline():
+ cur = aconn.cursor()
+ with pytest.raises(e.ProgrammingError):
+ await cur.fetchone()
+
+
+async def test_executemany(aconn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("drop table if exists execmanypipeline")
+ await aconn.execute(
+ "create unlogged table execmanypipeline ("
+ " id serial primary key, num integer)"
+ )
+ async with aconn.pipeline(), aconn.cursor() as cur:
+ await cur.executemany(
+ "insert into execmanypipeline(num) values (%s) returning num",
+ [(10,), (20,)],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert (await cur.fetchone()) == (10,)
+ assert cur.nextset()
+ assert (await cur.fetchone()) == (20,)
+ assert cur.nextset() is None
+
+
+async def test_executemany_no_returning(aconn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("drop table if exists execmanypipelinenoreturning")
+ await aconn.execute(
+ "create unlogged table execmanypipelinenoreturning ("
+ " id serial primary key, num integer)"
+ )
+ async with aconn.pipeline(), aconn.cursor() as cur:
+ await cur.executemany(
+ "insert into execmanypipelinenoreturning(num) values (%s)",
+ [(10,), (20,)],
+ returning=False,
+ )
+ with pytest.raises(e.ProgrammingError, match="no result available"):
+ await cur.fetchone()
+ assert cur.nextset() is None
+ with pytest.raises(e.ProgrammingError, match="no result available"):
+ await cur.fetchone()
+ assert cur.nextset() is None
+
+
+@pytest.mark.crdb("skip", reason="temp tables")
+async def test_executemany_trace(aconn, trace):
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("create temp table trace (id int)")
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ await cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)])
+ await cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)])
+ await aconn.close()
+ items = list(t)
+ assert items[-1].type == "Terminate"
+ del items[-1]
+ roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ assert roundtrips == ["F", "B"]
+ assert len([i for i in items if i.type == "Sync"]) == 1
+
+
+@pytest.mark.crdb("skip", reason="temp tables")
+async def test_executemany_trace_returning(aconn, trace):
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("create temp table trace (id int)")
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ await cur.executemany(
+ "insert into trace (id) values (%s)", [(10,), (20,)], returning=True
+ )
+ await cur.executemany(
+ "insert into trace (id) values (%s)", [(10,), (20,)], returning=True
+ )
+ await aconn.close()
+ items = list(t)
+ assert items[-1].type == "Terminate"
+ del items[-1]
+ roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ assert roundtrips == ["F", "B"] * 3
+ assert items[-2].direction == "F" # last 2 items are F B
+ assert len([i for i in items if i.type == "Sync"]) == 1
+
+
+async def test_prepared(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ c1 = await aconn.execute("select %s::int", [10], prepare=True)
+ c2 = await aconn.execute(
+ "select count(*) from pg_prepared_statements where name != ''"
+ )
+
+ (r,) = await c1.fetchone()
+ assert r == 10
+
+ (r,) = await c2.fetchone()
+ assert r == 1
+
+
+async def test_auto_prepare(aconn):
+ aconn.prepared_threshold = 5
+ async with aconn.pipeline():
+ cursors = [
+ await aconn.execute(
+ "select count(*) from pg_prepared_statements where name != ''"
+ )
+ for i in range(10)
+ ]
+
+ assert len(aconn._prepared._names) == 1
+
+ res = [(await c.fetchone())[0] for c in cursors]
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_transaction(aconn):
+ notices = []
+ aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ cur = await aconn.execute("select 'tx'")
+
+ (r,) = await cur.fetchone()
+ assert r == "tx"
+
+ async with aconn.transaction():
+ cur = await aconn.execute("select 'rb'")
+ raise psycopg.Rollback()
+
+ (r,) = await cur.fetchone()
+ assert r == "rb"
+
+ assert not notices
+
+
+async def test_transaction_nested(aconn):
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ outer = await aconn.execute("select 'outer'")
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.transaction():
+ inner = await aconn.execute("select 'inner'")
+ 1 / 0
+
+ (r,) = await outer.fetchone()
+ assert r == "outer"
+ (r,) = await inner.fetchone()
+ assert r == "inner"
+
+
+async def test_transaction_nested_no_statement(aconn):
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ async with aconn.transaction():
+ cur = await aconn.execute("select 1")
+
+ (r,) = await cur.fetchone()
+ assert r == 1
+
+
+async def test_outer_transaction(aconn):
+ async with aconn.transaction():
+ await aconn.execute("drop table if exists outertx")
+ async with aconn.transaction():
+ async with aconn.pipeline():
+ await aconn.execute("create table outertx as (select 1)")
+ cur = await aconn.execute("select * from outertx")
+ (r,) = await cur.fetchone()
+ assert r == 1
+ cur = await aconn.execute(
+ "select count(*) from pg_tables where tablename = 'outertx'"
+ )
+ assert (await cur.fetchone())[0] == 1
+
+
+async def test_outer_transaction_error(aconn):
+ async with aconn.transaction():
+ with pytest.raises((e.UndefinedColumn, e.OperationalError)):
+ async with aconn.pipeline():
+ await aconn.execute("select error")
+ await aconn.execute("create table voila ()")
+
+
+async def test_rollback_explicit(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ with pytest.raises(e.DivisionByZero):
+ cur = await aconn.execute("select 1 / %s", [0])
+ await cur.fetchone()
+ await aconn.rollback()
+ await aconn.execute("select 1")
+
+
+async def test_rollback_transaction(aconn):
+ await aconn.set_autocommit(True)
+ with pytest.raises(e.DivisionByZero):
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ cur = await aconn.execute("select 1 / %s", [0])
+ await cur.fetchone()
+ await aconn.execute("select 1")
+
+
+async def test_message_0x33(aconn):
+ # https://github.com/psycopg/psycopg/issues/314
+ notices = []
+ aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ cur = await aconn.execute("select 'test'")
+ assert (await cur.fetchone()) == ("test",)
+
+ assert not notices
+
+
+async def test_transaction_state_implicit_begin(aconn, trace):
+ # Regression test to ensure that the transaction state is correct after
+ # the implicit BEGIN statement (in non-autocommit mode).
+ notices = []
+ aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ await (await aconn.execute("select 'x'")).fetchone()
+ await aconn.execute("select 'y'")
+ assert not notices
+ assert [
+ e.content[0] for e in t if e.type == "Parse" and b"BEGIN" in e.content[0]
+ ] == [b' "" "BEGIN" 0']
+
+
+async def test_concurrency(aconn):
+ async with aconn.transaction():
+ await aconn.execute("drop table if exists pipeline_concurrency")
+ await aconn.execute("drop table if exists accessed")
+ async with aconn.transaction():
+ await aconn.execute(
+ "create unlogged table pipeline_concurrency ("
+ " id serial primary key,"
+ " value integer"
+ ")"
+ )
+ await aconn.execute("create unlogged table accessed as (select now() as value)")
+
+ async def update(value):
+ cur = await aconn.execute(
+ "insert into pipeline_concurrency(value) values (%s) returning value",
+ (value,),
+ )
+ await aconn.execute("update accessed set value = now()")
+ return cur
+
+ await aconn.set_autocommit(True)
+
+ (before,) = await (await aconn.execute("select value from accessed")).fetchone()
+
+ values = range(1, 10)
+ async with aconn.pipeline():
+ cursors = await asyncio.wait_for(
+ asyncio.gather(*[update(value) for value in values]),
+ timeout=len(values),
+ )
+
+ assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
+
+ (s,) = await (
+ await aconn.execute("select sum(value) from pipeline_concurrency")
+ ).fetchone()
+ assert s == sum(values)
+ (after,) = await (await aconn.execute("select value from accessed")).fetchone()
+ assert after > before
diff --git a/tests/test_prepared.py b/tests/test_prepared.py
new file mode 100644
index 0000000..56c580a
--- /dev/null
+++ b/tests/test_prepared.py
@@ -0,0 +1,277 @@
+"""
+Prepared statements tests
+"""
+
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg.rows import namedtuple_row
+
+
+@pytest.mark.parametrize("value", [None, 0, 3])
+def test_prepare_threshold_init(conn_cls, dsn, value):
+ with conn_cls.connect(dsn, prepare_threshold=value) as conn:
+ assert conn.prepare_threshold == value
+
+
+def test_dont_prepare(conn):
+ cur = conn.cursor()
+ for i in range(10):
+ cur.execute("select %s::int", [i], prepare=False)
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 0
+
+
+def test_do_prepare(conn):
+ cur = conn.cursor()
+ cur.execute("select %s::int", [10], prepare=True)
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 1
+
+
+def test_auto_prepare(conn):
+ res = []
+ for i in range(10):
+ conn.execute("select %s::int", [0])
+ stmts = get_prepared_statements(conn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+def test_dont_prepare_conn(conn):
+ for i in range(10):
+ conn.execute("select %s::int", [i], prepare=False)
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 0
+
+
+def test_do_prepare_conn(conn):
+ conn.execute("select %s::int", [10], prepare=True)
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 1
+
+
+def test_auto_prepare_conn(conn):
+ res = []
+ for i in range(10):
+ conn.execute("select %s", [0])
+ stmts = get_prepared_statements(conn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+def test_prepare_disable(conn):
+ conn.prepare_threshold = None
+ res = []
+ for i in range(10):
+ conn.execute("select %s", [0])
+ stmts = get_prepared_statements(conn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+ assert not conn._prepared._names
+ assert not conn._prepared._counts
+
+
+def test_no_prepare_multi(conn):
+ res = []
+ for i in range(10):
+ conn.execute("select 1; select 2")
+ stmts = get_prepared_statements(conn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+
+
+def test_no_prepare_multi_with_drop(conn):
+ conn.execute("select 1", prepare=True)
+
+ for i in range(10):
+ conn.execute("drop table if exists noprep; create table noprep()")
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 0
+
+
+def test_no_prepare_error(conn):
+ conn.autocommit = True
+ for i in range(10):
+ with pytest.raises(conn.ProgrammingError):
+ conn.execute("select wat")
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_no_prepare ()",
+ pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")),
+ "set timezone = utc",
+ "select num from prepared_test",
+ "insert into prepared_test (num) values (1)",
+ "update prepared_test set num = num * 2",
+ "delete from prepared_test where num > 10",
+ ],
+)
+def test_misc_statement(conn, query):
+ conn.execute("create table prepared_test (num int)", prepare=False)
+ conn.prepare_threshold = 0
+ conn.execute(query)
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 1
+
+
+def test_params_types(conn):
+ conn.execute(
+ "select %s, %s, %s",
+ [dt.date(2020, 12, 10), 42, Decimal(42)],
+ prepare=True,
+ )
+ stmts = get_prepared_statements(conn)
+ want = [stmt.parameter_types for stmt in stmts]
+ assert want == [["date", "smallint", "numeric"]]
+
+
+def test_evict_lru(conn):
+ conn.prepared_max = 5
+ for i in range(10):
+ conn.execute("select 'a'")
+ conn.execute(f"select {i}")
+
+ assert len(conn._prepared._names) == 1
+ assert conn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
+ for i in [9, 8, 7, 6]:
+ assert conn._prepared._counts[f"select {i}".encode(), ()] == 1
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 1
+ assert stmts[0].statement == "select 'a'"
+
+
+def test_evict_lru_deallocate(conn):
+ conn.prepared_max = 5
+ conn.prepare_threshold = 0
+ for i in range(10):
+ conn.execute("select 'a'")
+ conn.execute(f"select {i}")
+
+ assert len(conn._prepared._names) == 5
+ for j in [9, 8, 7, 6, "'a'"]:
+ name = conn._prepared._names[f"select {j}".encode(), ()]
+ assert name.startswith(b"_pg3_")
+
+ stmts = get_prepared_statements(conn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.statement for stmt in stmts]
+ assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]]
+
+
+def test_different_types(conn):
+ conn.prepare_threshold = 0
+ conn.execute("select %s", [None])
+ conn.execute("select %s", [dt.date(2000, 1, 1)])
+ conn.execute("select %s", [42])
+ conn.execute("select %s", [41])
+ conn.execute("select %s", [dt.date(2000, 1, 2)])
+
+ stmts = get_prepared_statements(conn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["text"], ["date"], ["smallint"]]
+
+
+def test_untyped_json(conn):
+ conn.prepare_threshold = 1
+ conn.execute("create table testjson(data jsonb)")
+
+ for i in range(2):
+ conn.execute("insert into testjson (data) values (%s)", ["{}"])
+
+ stmts = get_prepared_statements(conn)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["jsonb"]]
+
+
+def test_change_type_execute(conn):
+ conn.prepare_threshold = 0
+ for i in range(3):
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().execute(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ {"enum_col": ["foo"]},
+ )
+ conn.rollback()
+
+
+def test_change_type_executemany(conn):
+ for i in range(3):
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().executemany(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}],
+ )
+ conn.rollback()
+
+
+@pytest.mark.crdb("skip", reason="can't re-create a type")
+def test_change_type(conn):
+ conn.prepare_threshold = 0
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().execute(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ {"enum_col": ["foo"]},
+ )
+ conn.execute("DROP TABLE preptable")
+ conn.execute("DROP TYPE prepenum")
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().execute(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ {"enum_col": ["foo"]},
+ )
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 3
+
+
+def test_change_type_savepoint(conn):
+ conn.prepare_threshold = 0
+ with conn.transaction():
+ for i in range(3):
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction():
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().execute(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ {"enum_col": ["foo"]},
+ )
+ raise ZeroDivisionError()
+
+
+def get_prepared_statements(conn):
+ cur = conn.cursor(row_factory=namedtuple_row)
+ cur.execute(
+ # CRDB has 'PREPARE name AS' in the statement.
+ r"""
+select name,
+ regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement,
+ prepare_time,
+ parameter_types
+from pg_prepared_statements
+where name != ''
+ """,
+ prepare=False,
+ )
+ return cur.fetchall()
diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py
new file mode 100644
index 0000000..84d948f
--- /dev/null
+++ b/tests/test_prepared_async.py
@@ -0,0 +1,207 @@
+"""
+Prepared statements tests on async connections
+"""
+
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg.rows import namedtuple_row
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.parametrize("value", [None, 0, 3])
+async def test_prepare_threshold_init(aconn_cls, dsn, value):
+ async with await aconn_cls.connect(dsn, prepare_threshold=value) as conn:
+ assert conn.prepare_threshold == value
+
+
+async def test_dont_prepare(aconn):
+ cur = aconn.cursor()
+ for i in range(10):
+ await cur.execute("select %s::int", [i], prepare=False)
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+async def test_do_prepare(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s::int", [10], prepare=True)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_auto_prepare(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s::int", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_dont_prepare_conn(aconn):
+ for i in range(10):
+ await aconn.execute("select %s::int", [i], prepare=False)
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+async def test_do_prepare_conn(aconn):
+ await aconn.execute("select %s::int", [10], prepare=True)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_auto_prepare_conn(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_prepare_disable(aconn):
+ aconn.prepare_threshold = None
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+ assert not aconn._prepared._names
+ assert not aconn._prepared._counts
+
+
+async def test_no_prepare_multi(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select 1; select 2")
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+
+
+async def test_no_prepare_error(aconn):
+ await aconn.set_autocommit(True)
+ for i in range(10):
+ with pytest.raises(aconn.ProgrammingError):
+ await aconn.execute("select wat")
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_no_prepare ()",
+ pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")),
+ "set timezone = utc",
+ "select num from prepared_test",
+ "insert into prepared_test (num) values (1)",
+ "update prepared_test set num = num * 2",
+ "delete from prepared_test where num > 10",
+ ],
+)
+async def test_misc_statement(aconn, query):
+ await aconn.execute("create table prepared_test (num int)", prepare=False)
+ aconn.prepare_threshold = 0
+ await aconn.execute(query)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_params_types(aconn):
+ await aconn.execute(
+ "select %s, %s, %s",
+ [dt.date(2020, 12, 10), 42, Decimal(42)],
+ prepare=True,
+ )
+ stmts = await get_prepared_statements(aconn)
+ want = [stmt.parameter_types for stmt in stmts]
+ assert want == [["date", "smallint", "numeric"]]
+
+
+async def test_evict_lru(aconn):
+ aconn.prepared_max = 5
+ for i in range(10):
+ await aconn.execute("select 'a'")
+ await aconn.execute(f"select {i}")
+
+ assert len(aconn._prepared._names) == 1
+ assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
+ for i in [9, 8, 7, 6]:
+ assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+ assert stmts[0].statement == "select 'a'"
+
+
+async def test_evict_lru_deallocate(aconn):
+ aconn.prepared_max = 5
+ aconn.prepare_threshold = 0
+ for i in range(10):
+ await aconn.execute("select 'a'")
+ await aconn.execute(f"select {i}")
+
+ assert len(aconn._prepared._names) == 5
+ for j in [9, 8, 7, 6, "'a'"]:
+ name = aconn._prepared._names[f"select {j}".encode(), ()]
+ assert name.startswith(b"_pg3_")
+
+ stmts = await get_prepared_statements(aconn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.statement for stmt in stmts]
+ assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]]
+
+
+async def test_different_types(aconn):
+ aconn.prepare_threshold = 0
+ await aconn.execute("select %s", [None])
+ await aconn.execute("select %s", [dt.date(2000, 1, 1)])
+ await aconn.execute("select %s", [42])
+ await aconn.execute("select %s", [41])
+ await aconn.execute("select %s", [dt.date(2000, 1, 2)])
+
+ stmts = await get_prepared_statements(aconn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["text"], ["date"], ["smallint"]]
+
+
+async def test_untyped_json(aconn):
+ aconn.prepare_threshold = 1
+ await aconn.execute("create table testjson(data jsonb)")
+ for i in range(2):
+ await aconn.execute("insert into testjson (data) values (%s)", ["{}"])
+
+ stmts = await get_prepared_statements(aconn)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["jsonb"]]
+
+
+async def get_prepared_statements(aconn):
+ cur = aconn.cursor(row_factory=namedtuple_row)
+ await cur.execute(
+ r"""
+select name,
+ regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement,
+ prepare_time,
+ parameter_types
+from pg_prepared_statements
+where name != ''
+ """,
+ prepare=False,
+ )
+ return await cur.fetchall()
diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py
new file mode 100644
index 0000000..82a5d73
--- /dev/null
+++ b/tests/test_psycopg_dbapi20.py
@@ -0,0 +1,164 @@
+import pytest
+import datetime as dt
+from typing import Any, Dict
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from . import dbapi20
+from . import dbapi20_tpc
+
+
+@pytest.fixture(scope="class")
+def with_dsn(request, session_dsn):
+ request.cls.connect_args = (session_dsn,)
+
+
+@pytest.mark.usefixtures("with_dsn")
+class PsycopgTests(dbapi20.DatabaseAPI20Test):
+ driver = psycopg
+ # connect_args = () # set by the fixture
+ connect_kw_args: Dict[str, Any] = {}
+
+ def test_nextset(self):
+ # tested elsewhere
+ pass
+
+ def test_setoutputsize(self):
+ # no-op
+ pass
+
+
+@pytest.mark.usefixtures("tpc")
+@pytest.mark.usefixtures("with_dsn")
+class PsycopgTPCTests(dbapi20_tpc.TwoPhaseCommitTests):
+ driver = psycopg
+ connect_args = () # set by the fixture
+
+ def connect(self):
+ return psycopg.connect(*self.connect_args)
+
+
+# Shut up warnings
+PsycopgTests.failUnless = PsycopgTests.assertTrue
+PsycopgTPCTests.assertEquals = PsycopgTPCTests.assertEqual
+
+
+@pytest.mark.parametrize(
+ "typename, singleton",
+ [
+ ("bytea", "BINARY"),
+ ("date", "DATETIME"),
+ ("timestamp without time zone", "DATETIME"),
+ ("timestamp with time zone", "DATETIME"),
+ ("time without time zone", "DATETIME"),
+ ("time with time zone", "DATETIME"),
+ ("interval", "DATETIME"),
+ ("integer", "NUMBER"),
+ ("smallint", "NUMBER"),
+ ("bigint", "NUMBER"),
+ ("real", "NUMBER"),
+ ("double precision", "NUMBER"),
+ ("numeric", "NUMBER"),
+ ("decimal", "NUMBER"),
+ ("oid", "ROWID"),
+ ("varchar", "STRING"),
+ ("char", "STRING"),
+ ("text", "STRING"),
+ ],
+)
+def test_singletons(conn, typename, singleton):
+ singleton = getattr(psycopg, singleton)
+ cur = conn.cursor()
+ cur.execute(f"select null::{typename}")
+ oid = cur.description[0].type_code
+ assert singleton == oid
+ assert oid == singleton
+ assert singleton != oid + 10000
+ assert oid + 10000 != singleton
+
+
+@pytest.mark.parametrize(
+ "ticks, want",
+ [
+ (0, "1970-01-01T00:00:00.000000+0000"),
+ (1273173119.99992, "2010-05-06T14:11:59.999920-0500"),
+ ],
+)
+def test_timestamp_from_ticks(ticks, want):
+ s = psycopg.TimestampFromTicks(ticks)
+ want = dt.datetime.strptime(want, "%Y-%m-%dT%H:%M:%S.%f%z")
+ assert s == want
+
+
+@pytest.mark.parametrize(
+ "ticks, want",
+ [
+ (0, "1970-01-01"),
+ # Returned date is local
+ (1273173119.99992, ["2010-05-06", "2010-05-07"]),
+ ],
+)
+def test_date_from_ticks(ticks, want):
+ s = psycopg.DateFromTicks(ticks)
+ if isinstance(want, str):
+ want = [want]
+ want = [dt.datetime.strptime(w, "%Y-%m-%d").date() for w in want]
+ assert s in want
+
+
+@pytest.mark.parametrize(
+ "ticks, want",
+ [(0, "00:00:00.000000"), (1273173119.99992, "00:11:59.999920")],
+)
+def test_time_from_ticks(ticks, want):
+ s = psycopg.TimeFromTicks(ticks)
+ want = dt.datetime.strptime(want, "%H:%M:%S.%f").time()
+ assert s.replace(hour=0) == want
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, want",
+ [
+ ((), {}, ""),
+ (("",), {}, ""),
+ (("host=foo user=bar",), {}, "host=foo user=bar"),
+ (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+ (
+ ("host=foo port=5432",),
+ {"host": "qux", "user": "joe"},
+ "host=qux user=joe port=5432",
+ ),
+ (("host=foo",), {"user": None}, "host=foo"),
+ ],
+)
+def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
+ the_conninfo: str
+
+ def fake_connect(conninfo):
+ nonlocal the_conninfo
+ the_conninfo = conninfo
+ return pgconn
+ yield
+
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ conn = psycopg.connect(*args, **kwargs)
+ assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ conn.close()
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, exctype",
+ [
+ (("host=foo", "host=bar"), {}, TypeError),
+ (("", ""), {}, TypeError),
+ ((), {"nosuchparam": 42}, psycopg.ProgrammingError),
+ ],
+)
+def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype):
+ def fake_connect(conninfo):
+ return pgconn
+ yield
+
+ with pytest.raises(exctype):
+ psycopg.connect(*args, **kwargs)
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644
index 0000000..7263a80
--- /dev/null
+++ b/tests/test_query.py
@@ -0,0 +1,162 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg.adapt import Transformer, PyFormat
+from psycopg._queries import PostgresQuery, _split_query
+
+
+@pytest.mark.parametrize(
+ "input, want",
+ [
+ (b"", [(b"", 0, PyFormat.AUTO)]),
+ (b"foo bar", [(b"foo bar", 0, PyFormat.AUTO)]),
+ (b"foo %% bar", [(b"foo % bar", 0, PyFormat.AUTO)]),
+ (b"%s", [(b"", 0, PyFormat.AUTO), (b"", 0, PyFormat.AUTO)]),
+ (b"%s foo", [(b"", 0, PyFormat.AUTO), (b" foo", 0, PyFormat.AUTO)]),
+ (b"%b foo", [(b"", 0, PyFormat.BINARY), (b" foo", 0, PyFormat.AUTO)]),
+ (b"foo %s", [(b"foo ", 0, PyFormat.AUTO), (b"", 0, PyFormat.AUTO)]),
+ (
+ b"foo %%%s bar",
+ [(b"foo %", 0, PyFormat.AUTO), (b" bar", 0, PyFormat.AUTO)],
+ ),
+ (
+ b"foo %(name)s bar",
+ [(b"foo ", "name", PyFormat.AUTO), (b" bar", 0, PyFormat.AUTO)],
+ ),
+ (
+ b"foo %(name)s %(name)b bar",
+ [
+ (b"foo ", "name", PyFormat.AUTO),
+ (b" ", "name", PyFormat.BINARY),
+ (b" bar", 0, PyFormat.AUTO),
+ ],
+ ),
+ (
+ b"foo %s%b bar %s baz",
+ [
+ (b"foo ", 0, PyFormat.AUTO),
+ (b"", 1, PyFormat.BINARY),
+ (b" bar ", 2, PyFormat.AUTO),
+ (b" baz", 0, PyFormat.AUTO),
+ ],
+ ),
+ ],
+)
+def test_split_query(input, want):
+ assert _split_query(input) == want
+
+
+@pytest.mark.parametrize(
+ "input",
+ [
+ b"foo %d bar",
+ b"foo % bar",
+ b"foo %%% bar",
+ b"foo %(foo)d bar",
+ b"foo %(foo)s bar %s baz",
+ b"foo %(foo) bar",
+ b"foo %(foo bar",
+ b"3%2",
+ ],
+)
+def test_split_query_bad(input):
+ with pytest.raises(psycopg.ProgrammingError):
+ _split_query(input)
+
+
+@pytest.mark.parametrize(
+ "query, params, want, wformats, wparams",
+ [
+ (b"", None, b"", None, None),
+ (b"", [], b"", [], []),
+ (b"%%", [], b"%", [], []),
+ (b"select %t", (1,), b"select $1", [pq.Format.TEXT], [b"1"]),
+ (
+ b"%t %% %t",
+ (1, 2),
+ b"$1 % $2",
+ [pq.Format.TEXT, pq.Format.TEXT],
+ [b"1", b"2"],
+ ),
+ (
+ b"%t %% %t",
+ ("a", 2),
+ b"$1 % $2",
+ [pq.Format.TEXT, pq.Format.TEXT],
+ [b"a", b"2"],
+ ),
+ ],
+)
+def test_pg_query_seq(query, params, want, wformats, wparams):
+ pq = PostgresQuery(Transformer())
+ pq.convert(query, params)
+ assert pq.query == want
+ assert pq.formats == wformats
+ assert pq.params == wparams
+
+
+@pytest.mark.parametrize(
+ "query, params, want, wformats, wparams",
+ [
+ (b"", {}, b"", [], []),
+ (b"hello %%", {"a": 1}, b"hello %", [], []),
+ (
+ b"select %(hello)t",
+ {"hello": 1, "world": 2},
+ b"select $1",
+ [pq.Format.TEXT],
+ [b"1"],
+ ),
+ (
+ b"select %(hi)s %(there)s %(hi)s",
+ {"hi": 0, "there": "a"},
+ b"select $1 $2 $1",
+ [pq.Format.BINARY, pq.Format.TEXT],
+ [b"\x00" * 2, b"a"],
+ ),
+ ],
+)
+def test_pg_query_map(query, params, want, wformats, wparams):
+ pq = PostgresQuery(Transformer())
+ pq.convert(query, params)
+ assert pq.query == want
+ assert pq.formats == wformats
+ assert pq.params == wparams
+
+
+@pytest.mark.parametrize(
+ "query, params",
+ [
+ (b"select %s", {"a": 1}),
+ (b"select %(name)s", [1]),
+ (b"select %s", "a"),
+ (b"select %s", 1),
+ (b"select %s", b"a"),
+ (b"select %s", set()),
+ ],
+)
+def test_pq_query_badtype(query, params):
+ pq = PostgresQuery(Transformer())
+ with pytest.raises(TypeError):
+ pq.convert(query, params)
+
+
+@pytest.mark.parametrize(
+ "query, params",
+ [
+ (b"", [1]),
+ (b"%s", []),
+ (b"%%", [1]),
+ (b"$1", [1]),
+ (b"select %(", {"a": 1}),
+ (b"select %(a", {"a": 1}),
+ (b"select %(a)", {"a": 1}),
+ (b"select %s %(hi)s", [1]),
+ (b"select %(hi)s %(hi)b", {"hi": 1}),
+ ],
+)
+def test_pq_query_badprog(query, params):
+ pq = PostgresQuery(Transformer())
+ with pytest.raises(psycopg.ProgrammingError):
+ pq.convert(query, params)
diff --git a/tests/test_rows.py b/tests/test_rows.py
new file mode 100644
index 0000000..5165b80
--- /dev/null
+++ b/tests/test_rows.py
@@ -0,0 +1,167 @@
+import pytest
+
+import psycopg
+from psycopg import rows
+
+from .utils import eur
+
+
+def test_tuple_row(conn):
+ conn.row_factory = rows.dict_row
+ assert conn.execute("select 1 as a").fetchone() == {"a": 1}
+ cur = conn.cursor(row_factory=rows.tuple_row)
+ row = cur.execute("select 1 as a").fetchone()
+ assert row == (1,)
+ assert type(row) is tuple
+ assert cur._make_row is tuple
+
+
+def test_dict_row(conn):
+ cur = conn.cursor(row_factory=rows.dict_row)
+ cur.execute("select 'bob' as name, 3 as id")
+ assert cur.fetchall() == [{"name": "bob", "id": 3}]
+
+ cur.execute("select 'a' as letter; select 1 as number")
+ assert cur.fetchall() == [{"letter": "a"}]
+ assert cur.nextset()
+ assert cur.fetchall() == [{"number": 1}]
+ assert not cur.nextset()
+
+
+def test_namedtuple_row(conn):
+ rows._make_nt.cache_clear()
+ cur = conn.cursor(row_factory=rows.namedtuple_row)
+ cur.execute("select 'bob' as name, 3 as id")
+ (person1,) = cur.fetchall()
+ assert f"{person1.name} {person1.id}" == "bob 3"
+
+ ci1 = rows._make_nt.cache_info()
+ assert ci1.hits == 0 and ci1.misses == 1
+
+ cur.execute("select 'alice' as name, 1 as id")
+ (person2,) = cur.fetchall()
+ assert type(person2) is type(person1)
+
+ ci2 = rows._make_nt.cache_info()
+ assert ci2.hits == 1 and ci2.misses == 1
+
+ cur.execute("select 'foo', 1 as id")
+ (r0,) = cur.fetchall()
+ assert r0.f_column_ == "foo"
+ assert r0.id == 1
+
+ cur.execute("select 'a' as letter; select 1 as number")
+ (r1,) = cur.fetchall()
+ assert r1.letter == "a"
+ assert cur.nextset()
+ (r2,) = cur.fetchall()
+ assert r2.number == 1
+ assert not cur.nextset()
+ assert type(r1) is not type(r2)
+
+ cur.execute(f'select 1 as üåäö, 2 as _, 3 as "123", 4 as "a-b", 5 as "{eur}eur"')
+ (r3,) = cur.fetchall()
+ assert r3.üåäö == 1
+ assert r3.f_ == 2
+ assert r3.f123 == 3
+ assert r3.a_b == 4
+ assert r3.f_eur == 5
+
+
+def test_class_row(conn):
+ cur = conn.cursor(row_factory=rows.class_row(Person))
+ cur.execute("select 'John' as first, 'Doe' as last")
+ (p,) = cur.fetchall()
+ assert isinstance(p, Person)
+ assert p.first == "John"
+ assert p.last == "Doe"
+ assert p.age is None
+
+ for query in (
+ "select 'John' as first",
+ "select 'John' as first, 'Doe' as last, 42 as wat",
+ ):
+ cur.execute(query)
+ with pytest.raises(TypeError):
+ cur.fetchone()
+
+
+def test_args_row(conn):
+ cur = conn.cursor(row_factory=rows.args_row(argf))
+ cur.execute("select 'John' as first, 'Doe' as last")
+ assert cur.fetchone() == "JohnDoe"
+
+
+def test_kwargs_row(conn):
+ cur = conn.cursor(row_factory=rows.kwargs_row(kwargf))
+ cur.execute("select 'John' as first, 'Doe' as last")
+ (p,) = cur.fetchall()
+ assert isinstance(p, Person)
+ assert p.first == "John"
+ assert p.last == "Doe"
+ assert p.age == 42
+
+
+@pytest.mark.parametrize(
+ "factory",
+ "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(),
+)
+def test_no_result(factory, conn):
+ cur = conn.cursor(row_factory=factory_from_name(factory))
+ cur.execute("reset search_path")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+
+@pytest.mark.crdb_skip("no col query")
+@pytest.mark.parametrize(
+ "factory", "tuple_row dict_row namedtuple_row args_row".split()
+)
+def test_no_column(factory, conn):
+ cur = conn.cursor(row_factory=factory_from_name(factory))
+ cur.execute("select")
+ recs = cur.fetchall()
+ assert len(recs) == 1
+ assert not recs[0]
+
+
+@pytest.mark.crdb("skip")
+def test_no_column_class_row(conn):
+ class Empty:
+ def __init__(self, x=10, y=20):
+ self.x = x
+ self.y = y
+
+ cur = conn.cursor(row_factory=rows.class_row(Empty))
+ cur.execute("select")
+ x = cur.fetchone()
+ assert isinstance(x, Empty)
+ assert x.x == 10
+ assert x.y == 20
+
+
+def factory_from_name(name):
+ factory = getattr(rows, name)
+ if factory is rows.class_row:
+ factory = factory(Person)
+ if factory is rows.args_row:
+ factory = factory(argf)
+ if factory is rows.kwargs_row:
+ factory = factory(argf)
+
+ return factory
+
+
+class Person:
+ def __init__(self, first, last, age=None):
+ self.first = first
+ self.last = last
+ self.age = age
+
+
+def argf(*args):
+ return "".join(map(str, args))
+
+
+def kwargf(**kwargs):
+ return Person(**kwargs, age=42)
diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py
new file mode 100644
index 0000000..f7b6c8e
--- /dev/null
+++ b/tests/test_server_cursor.py
@@ -0,0 +1,525 @@
+import pytest
+
+import psycopg
+from psycopg import rows, errors as e
+from psycopg.pq import Format
+
+pytestmark = pytest.mark.crdb_skip("server-side cursor")
+
+
+def test_init_row_factory(conn):
+ with psycopg.ServerCursor(conn, "foo") as cur:
+ assert cur.name == "foo"
+ assert cur.connection is conn
+ assert cur.row_factory is conn.row_factory
+
+ conn.row_factory = rows.dict_row
+
+ with psycopg.ServerCursor(conn, "bar") as cur:
+ assert cur.name == "bar"
+ assert cur.row_factory is rows.dict_row # type: ignore
+
+ with psycopg.ServerCursor(conn, "baz", row_factory=rows.namedtuple_row) as cur:
+ assert cur.name == "baz"
+ assert cur.row_factory is rows.namedtuple_row # type: ignore
+
+
+def test_init_params(conn):
+ with psycopg.ServerCursor(conn, "foo") as cur:
+ assert cur.scrollable is None
+ assert cur.withhold is False
+
+ with psycopg.ServerCursor(conn, "bar", withhold=True, scrollable=False) as cur:
+ assert cur.scrollable is False
+ assert cur.withhold is True
+
+
+@pytest.mark.crdb_skip("cursor invalid name")
+def test_funny_name(conn):
+ cur = conn.cursor("1-2-3")
+ cur.execute("select generate_series(1, 3) as bar")
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.name == "1-2-3"
+ cur.close()
+
+
+def test_repr(conn):
+ cur = conn.cursor("my-name")
+ assert "psycopg.ServerCursor" in str(cur)
+ assert "my-name" in repr(cur)
+ cur.close()
+
+
+def test_connection(conn):
+ cur = conn.cursor("foo")
+ assert cur.connection is conn
+ cur.close()
+
+
+def test_description(conn):
+ cur = conn.cursor("foo")
+ assert cur.name == "foo"
+ cur.execute("select generate_series(1, 10)::int4 as bar")
+ assert len(cur.description) == 1
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+ assert cur.pgresult.ntuples == 0
+ cur.close()
+
+
+def test_format(conn):
+ cur = conn.cursor("foo")
+ assert cur.format == Format.TEXT
+ cur.close()
+
+ cur = conn.cursor("foo", binary=True)
+ assert cur.format == Format.BINARY
+ cur.close()
+
+
+def test_query_params(conn):
+ with conn.cursor("foo") as cur:
+ assert cur._query is None
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur._query
+ assert b"declare" in cur._query.query.lower()
+ assert b"(1, $1)" in cur._query.query.lower()
+ assert cur._query.params == [bytes([0, 3])] # 3 as binary int2
+
+
+def test_binary_cursor_execute(conn):
+ cur = conn.cursor("foo", binary=True)
+ cur.execute("select generate_series(1, 2)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+ cur.close()
+
+
+def test_execute_binary(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 2)::int4", binary=True)
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+
+ cur.execute("select generate_series(1, 1)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ cur.close()
+
+
+def test_binary_cursor_text_override(conn):
+ cur = conn.cursor("foo", binary=True)
+ cur.execute("select generate_series(1, 2)", binary=False)
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"2"
+
+ cur.execute("select generate_series(1, 2)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ cur.close()
+
+
+def test_close(conn, recwarn):
+ if conn.info.transaction_status == conn.TransactionStatus.INTRANS:
+ # connection dirty from previous failure
+ conn.execute("close foo")
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 10) as bar")
+ cur.close()
+ assert cur.closed
+
+ assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_idempotent(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select 1")
+ cur.fetchall()
+ cur.close()
+ cur.close()
+
+
+def test_close_broken_conn(conn):
+ cur = conn.cursor("foo")
+ conn.close()
+ cur.close()
+ assert cur.closed
+
+
+def test_cursor_close_fetchone(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ for _ in range(5):
+ cur.fetchone()
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchone()
+
+
+def test_cursor_close_fetchmany(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchmany(2)) == 2
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchmany(2)
+
+
+def test_cursor_close_fetchall(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchall()) == 10
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchall()
+
+
+def test_close_noop(conn, recwarn):
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.close()
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_on_error(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select 1")
+ with pytest.raises(e.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+ cur.close()
+
+
+def test_pgresult(conn):
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert cur.pgresult
+ cur.close()
+ assert not cur.pgresult
+
+
+def test_context(conn, recwarn):
+ recwarn.clear()
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, 10) as bar")
+
+ assert cur.closed
+ assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_no_clobber(conn):
+ with pytest.raises(e.DivisionByZero):
+ with conn.cursor("foo") as cur:
+ cur.execute("select 1 / %s", (0,))
+ cur.fetchall()
+
+
+def test_warn_close(conn, recwarn):
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 10) as bar")
+ del cur
+ assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+def test_execute_reuse(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as foo", (3,))
+ assert cur.fetchone() == (1,)
+
+ cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
+ assert cur.fetchone() == ("hello", "world")
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["text"].oid
+ assert cur.description[1].name == "baz"
+
+
+@pytest.mark.parametrize(
+ "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"]
+)
+def test_execute_error(conn, stmt):
+ cur = conn.cursor("foo")
+ with pytest.raises(e.ProgrammingError):
+ cur.execute(stmt)
+ cur.close()
+
+
+def test_executemany(conn):
+ cur = conn.cursor("foo")
+ with pytest.raises(e.NotSupportedError):
+ cur.executemany("select %s", [(1,), (2,)])
+ cur.close()
+
+
+def test_fetchone(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (2,))
+ assert cur.fetchone() == (1,)
+ assert cur.fetchone() == (2,)
+ assert cur.fetchone() is None
+
+
+def test_fetchmany(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (5,))
+ assert cur.fetchmany(3) == [(1,), (2,), (3,)]
+ assert cur.fetchone() == (4,)
+ assert cur.fetchmany(3) == [(5,)]
+ assert cur.fetchmany(3) == []
+
+
+def test_fetchall(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.fetchall() == []
+
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchone() == (1,)
+ assert cur.fetchall() == [(2,), (3,)]
+ assert cur.fetchall() == []
+
+
+def test_nextset(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert not cur.nextset()
+
+
+def test_no_result(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar where false", (3,))
+ assert len(cur.description) == 1
+ assert cur.fetchall() == []
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_standard_row_factory(conn, row_factory):
+ if row_factory == "tuple_row":
+ getter = lambda r: r[0] # noqa: E731
+ elif row_factory == "dict_row":
+ getter = lambda r: r["bar"] # noqa: E731
+ elif row_factory == "namedtuple_row":
+ getter = lambda r: r.bar # noqa: E731
+ else:
+ assert False, row_factory
+
+ row_factory = getattr(rows, row_factory)
+ with conn.cursor("foo", row_factory=row_factory) as cur:
+ cur.execute("select generate_series(1, 5) as bar")
+ assert getter(cur.fetchone()) == 1
+ assert list(map(getter, cur.fetchmany(2))) == [2, 3]
+ assert list(map(getter, cur.fetchall())) == [4, 5]
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_row_factory(conn):
+ n = 0
+
+ def my_row_factory(cur):
+ nonlocal n
+ n += 1
+ return lambda values: [n] + [-v for v in values]
+
+ cur = conn.cursor("foo", row_factory=my_row_factory, scrollable=True)
+ cur.execute("select generate_series(1, 3) as x")
+ recs = cur.fetchall()
+ cur.scroll(0, "absolute")
+ while True:
+ rec = cur.fetchone()
+ if not rec:
+ break
+ recs.append(rec)
+ assert recs == [[1, -1], [1, -2], [1, -3]] * 2
+
+ cur.scroll(0, "absolute")
+ cur.row_factory = rows.dict_row
+ assert cur.fetchone() == {"x": 1}
+ cur.close()
+
+
+def test_rownumber(conn):
+ cur = conn.cursor("foo")
+ assert cur.rownumber is None
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ cur.fetchone()
+ assert cur.rownumber == 1
+ cur.fetchone()
+ assert cur.rownumber == 2
+ cur.fetchmany(10)
+ assert cur.rownumber == 12
+ cur.fetchall()
+ assert cur.rownumber == 42
+ cur.close()
+
+
+def test_iter(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ recs = list(cur)
+ assert recs == [(1,), (2,), (3,)]
+
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchone() == (1,)
+ recs = list(cur)
+ assert recs == [(2,), (3,)]
+
+
+def test_iter_rownumber(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ for row in cur:
+ assert cur.rownumber == row[0]
+
+
+def test_itersize(conn, commands):
+ with conn.cursor("foo") as cur:
+ assert cur.itersize == 100
+ cur.itersize = 2
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ commands.popall() # flush begin and other noise
+
+ list(cur)
+ cmds = commands.popall()
+ assert len(cmds) == 2
+ for cmd in cmds:
+ assert "fetch forward 2" in cmd.lower()
+
+
+def test_cant_scroll_by_default(conn):
+ cur = conn.cursor("tmp")
+ assert cur.scrollable is None
+ with pytest.raises(e.ProgrammingError):
+ cur.scroll(0)
+ cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_scroll(conn):
+ cur = conn.cursor("tmp", scrollable=True)
+ cur.execute("select generate_series(0,9)")
+ cur.scroll(2)
+ assert cur.fetchone() == (2,)
+ cur.scroll(2)
+ assert cur.fetchone() == (5,)
+ cur.scroll(2, mode="relative")
+ assert cur.fetchone() == (8,)
+ cur.scroll(9, mode="absolute")
+ assert cur.fetchone() == (9,)
+
+ with pytest.raises(ValueError):
+ cur.scroll(9, mode="wat")
+ cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_scrollable(conn):
+ curs = conn.cursor("foo", scrollable=True)
+ assert curs.scrollable is True
+ curs.execute("select generate_series(0, 5)")
+ curs.scroll(5)
+ for i in range(4, -1, -1):
+ curs.scroll(-1)
+ assert i == curs.fetchone()[0]
+ curs.scroll(-1)
+ curs.close()
+
+
+def test_non_scrollable(conn):
+ curs = conn.cursor("foo", scrollable=False)
+ assert curs.scrollable is False
+ curs.execute("select generate_series(0, 5)")
+ curs.scroll(5)
+ with pytest.raises(e.OperationalError):
+ curs.scroll(-1)
+ curs.close()
+
+
+@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
+def test_no_hold(conn, kwargs):
+ with conn.cursor("foo", **kwargs) as curs:
+ assert curs.withhold is False
+ curs.execute("select generate_series(0, 2)")
+ assert curs.fetchone() == (0,)
+ conn.commit()
+ with pytest.raises(e.InvalidCursorName):
+ curs.fetchone()
+
+
+@pytest.mark.crdb_skip("cursor with hold")
+def test_hold(conn):
+ with conn.cursor("foo", withhold=True) as curs:
+ assert curs.withhold is True
+ curs.execute("select generate_series(0, 5)")
+ assert curs.fetchone() == (0,)
+ conn.commit()
+ assert curs.fetchone() == (1,)
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
+def test_steal_cursor(conn, row_factory):
+ cur1 = conn.cursor()
+ cur1.execute("declare test cursor for select generate_series(1, 6) as s")
+
+ cur2 = conn.cursor("test", row_factory=getattr(rows, row_factory))
+ # can call fetch without execute
+ rec = cur2.fetchone()
+ assert rec == (1,)
+ if row_factory == "namedtuple_row":
+ assert rec.s == 1
+ assert cur2.fetchmany(3) == [(2,), (3,), (4,)]
+ assert cur2.fetchall() == [(5,), (6,)]
+ cur2.close()
+
+
+def test_stolen_cursor_close(conn):
+ cur1 = conn.cursor()
+ cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = conn.cursor("test")
+ cur2.close()
+
+ cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = conn.cursor("test")
+ cur2.close()
diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py
new file mode 100644
index 0000000..21b4345
--- /dev/null
+++ b/tests/test_server_cursor_async.py
@@ -0,0 +1,543 @@
+import pytest
+
+import psycopg
+from psycopg import rows, errors as e
+from psycopg.pq import Format
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.crdb_skip("server-side cursor"),
+]
+
+
+async def test_init_row_factory(aconn):
+ async with psycopg.AsyncServerCursor(aconn, "foo") as cur:
+ assert cur.name == "foo"
+ assert cur.connection is aconn
+ assert cur.row_factory is aconn.row_factory
+
+ aconn.row_factory = rows.dict_row
+
+ async with psycopg.AsyncServerCursor(aconn, "bar") as cur:
+ assert cur.name == "bar"
+ assert cur.row_factory is rows.dict_row # type: ignore
+
+ async with psycopg.AsyncServerCursor(
+ aconn, "baz", row_factory=rows.namedtuple_row
+ ) as cur:
+ assert cur.name == "baz"
+ assert cur.row_factory is rows.namedtuple_row # type: ignore
+
+
+async def test_init_params(aconn):
+ async with psycopg.AsyncServerCursor(aconn, "foo") as cur:
+ assert cur.scrollable is None
+ assert cur.withhold is False
+
+ async with psycopg.AsyncServerCursor(
+ aconn, "bar", withhold=True, scrollable=False
+ ) as cur:
+ assert cur.scrollable is False
+ assert cur.withhold is True
+
+
+@pytest.mark.crdb_skip("cursor invalid name")
+async def test_funny_name(aconn):
+ cur = aconn.cursor("1-2-3")
+ await cur.execute("select generate_series(1, 3) as bar")
+ assert await cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.name == "1-2-3"
+ await cur.close()
+
+
+async def test_repr(aconn):
+ cur = aconn.cursor("my-name")
+ assert "psycopg.AsyncServerCursor" in str(cur)
+ assert "my-name" in repr(cur)
+ await cur.close()
+
+
+async def test_connection(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.connection is aconn
+ await cur.close()
+
+
+async def test_description(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.name == "foo"
+ await cur.execute("select generate_series(1, 10)::int4 as bar")
+ assert len(cur.description) == 1
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+ assert cur.pgresult.ntuples == 0
+ await cur.close()
+
+
+async def test_format(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.format == Format.TEXT
+ await cur.close()
+
+ cur = aconn.cursor("foo", binary=True)
+ assert cur.format == Format.BINARY
+ await cur.close()
+
+
+async def test_query_params(aconn):
+ async with aconn.cursor("foo") as cur:
+ assert cur._query is None
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur._query is not None
+ assert b"declare" in cur._query.query.lower()
+ assert b"(1, $1)" in cur._query.query.lower()
+ assert cur._query.params == [bytes([0, 3])] # 3 as binary int2
+
+
+async def test_binary_cursor_execute(aconn):
+ cur = aconn.cursor("foo", binary=True)
+ await cur.execute("select generate_series(1, 2)::int4")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+ await cur.close()
+
+
+async def test_execute_binary(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 2)::int4", binary=True)
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+
+ await cur.execute("select generate_series(1, 1)")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ await cur.close()
+
+
+async def test_binary_cursor_text_override(aconn):
+ cur = aconn.cursor("foo", binary=True)
+ await cur.execute("select generate_series(1, 2)", binary=False)
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"2"
+
+ await cur.execute("select generate_series(1, 2)::int4")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ await cur.close()
+
+
+async def test_close(aconn, recwarn):
+ if aconn.info.transaction_status == aconn.TransactionStatus.INTRANS:
+ # connection dirty from previous failure
+ await aconn.execute("close foo")
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 10) as bar")
+ await cur.close()
+ assert cur.closed
+
+ assert not await (
+ await aconn.execute("select * from pg_cursors where name = 'foo'")
+ ).fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_idempotent(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select 1")
+ await cur.fetchall()
+ await cur.close()
+ await cur.close()
+
+
+async def test_close_broken_conn(aconn):
+ cur = aconn.cursor("foo")
+ await aconn.close()
+ await cur.close()
+ assert cur.closed
+
+
+async def test_cursor_close_fetchone(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ for _ in range(5):
+ await cur.fetchone()
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchone()
+
+
+async def test_cursor_close_fetchmany(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchmany(2)) == 2
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchmany(2)
+
+
+async def test_cursor_close_fetchall(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchall()) == 10
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchall()
+
+
+async def test_close_noop(aconn, recwarn):
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.close()
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_on_error(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select 1")
+ with pytest.raises(e.ProgrammingError):
+ await aconn.execute("wat")
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+ await cur.close()
+
+
+async def test_pgresult(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert cur.pgresult
+ await cur.close()
+ assert not cur.pgresult
+
+
+async def test_context(aconn, recwarn):
+ recwarn.clear()
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, 10) as bar")
+
+ assert cur.closed
+ assert not await (
+ await aconn.execute("select * from pg_cursors where name = 'foo'")
+ ).fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_no_clobber(aconn):
+ with pytest.raises(e.DivisionByZero):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select 1 / %s", (0,))
+ await cur.fetchall()
+
+
+async def test_warn_close(aconn, recwarn):
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 10) as bar")
+ del cur
+ assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+async def test_execute_reuse(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as foo", (3,))
+ assert await cur.fetchone() == (1,)
+
+ await cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
+ assert await cur.fetchone() == ("hello", "world")
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["text"].oid
+ assert cur.description[1].name == "baz"
+
+
+@pytest.mark.parametrize(
+ "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"]
+)
+async def test_execute_error(aconn, stmt):
+ cur = aconn.cursor("foo")
+ with pytest.raises(e.ProgrammingError):
+ await cur.execute(stmt)
+ await cur.close()
+
+
+async def test_executemany(aconn):
+ cur = aconn.cursor("foo")
+ with pytest.raises(e.NotSupportedError):
+ await cur.executemany("select %s", [(1,), (2,)])
+ await cur.close()
+
+
+async def test_fetchone(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (2,))
+ assert await cur.fetchone() == (1,)
+ assert await cur.fetchone() == (2,)
+ assert await cur.fetchone() is None
+
+
+async def test_fetchmany(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (5,))
+ assert await cur.fetchmany(3) == [(1,), (2,), (3,)]
+ assert await cur.fetchone() == (4,)
+ assert await cur.fetchmany(3) == [(5,)]
+ assert await cur.fetchmany(3) == []
+
+
+async def test_fetchall(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchall() == [(1,), (2,), (3,)]
+ assert await cur.fetchall() == []
+
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchone() == (1,)
+ assert await cur.fetchall() == [(2,), (3,)]
+ assert await cur.fetchall() == []
+
+
+async def test_nextset(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert not cur.nextset()
+
+
+async def test_no_result(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar where false", (3,))
+ assert len(cur.description) == 1
+ assert (await cur.fetchall()) == []
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_standard_row_factory(aconn, row_factory):
+ if row_factory == "tuple_row":
+ getter = lambda r: r[0] # noqa: E731
+ elif row_factory == "dict_row":
+ getter = lambda r: r["bar"] # noqa: E731
+ elif row_factory == "namedtuple_row":
+ getter = lambda r: r.bar # noqa: E731
+ else:
+ assert False, row_factory
+
+ row_factory = getattr(rows, row_factory)
+ async with aconn.cursor("foo", row_factory=row_factory) as cur:
+ await cur.execute("select generate_series(1, 5) as bar")
+ assert getter(await cur.fetchone()) == 1
+ assert list(map(getter, await cur.fetchmany(2))) == [2, 3]
+ assert list(map(getter, await cur.fetchall())) == [4, 5]
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_row_factory(aconn):
+ n = 0
+
+ def my_row_factory(cur):
+ nonlocal n
+ n += 1
+ return lambda values: [n] + [-v for v in values]
+
+ cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True)
+ await cur.execute("select generate_series(1, 3) as x")
+ recs = await cur.fetchall()
+ await cur.scroll(0, "absolute")
+ while True:
+ rec = await cur.fetchone()
+ if not rec:
+ break
+ recs.append(rec)
+ assert recs == [[1, -1], [1, -2], [1, -3]] * 2
+
+ await cur.scroll(0, "absolute")
+ cur.row_factory = rows.dict_row
+ assert await cur.fetchone() == {"x": 1}
+ await cur.close()
+
+
+async def test_rownumber(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.rownumber is None
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ await cur.fetchone()
+ assert cur.rownumber == 1
+ await cur.fetchone()
+ assert cur.rownumber == 2
+ await cur.fetchmany(10)
+ assert cur.rownumber == 12
+ await cur.fetchall()
+ assert cur.rownumber == 42
+ await cur.close()
+
+
+async def test_iter(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ recs = []
+ async for rec in cur:
+ recs.append(rec)
+ assert recs == [(1,), (2,), (3,)]
+
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchone() == (1,)
+ recs = []
+ async for rec in cur:
+ recs.append(rec)
+ assert recs == [(2,), (3,)]
+
+
+async def test_iter_rownumber(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ async for row in cur:
+ assert cur.rownumber == row[0]
+
+
+async def test_itersize(aconn, acommands):
+ async with aconn.cursor("foo") as cur:
+ assert cur.itersize == 100
+ cur.itersize = 2
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ acommands.popall() # flush begin and other noise
+
+ async for rec in cur:
+ pass
+ cmds = acommands.popall()
+ assert len(cmds) == 2
+ for cmd in cmds:
+ assert "fetch forward 2" in cmd.lower()
+
+
+async def test_cant_scroll_by_default(aconn):
+ cur = aconn.cursor("tmp")
+ assert cur.scrollable is None
+ with pytest.raises(e.ProgrammingError):
+ await cur.scroll(0)
+ await cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_scroll(aconn):
+ cur = aconn.cursor("tmp", scrollable=True)
+ await cur.execute("select generate_series(0,9)")
+ await cur.scroll(2)
+ assert await cur.fetchone() == (2,)
+ await cur.scroll(2)
+ assert await cur.fetchone() == (5,)
+ await cur.scroll(2, mode="relative")
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(9, mode="absolute")
+ assert await cur.fetchone() == (9,)
+
+ with pytest.raises(ValueError):
+ await cur.scroll(9, mode="wat")
+ await cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_scrollable(aconn):
+ curs = aconn.cursor("foo", scrollable=True)
+ assert curs.scrollable is True
+ await curs.execute("select generate_series(0, 5)")
+ await curs.scroll(5)
+ for i in range(4, -1, -1):
+ await curs.scroll(-1)
+ assert i == (await curs.fetchone())[0]
+ await curs.scroll(-1)
+ await curs.close()
+
+
+async def test_non_scrollable(aconn):
+ curs = aconn.cursor("foo", scrollable=False)
+ assert curs.scrollable is False
+ await curs.execute("select generate_series(0, 5)")
+ await curs.scroll(5)
+ with pytest.raises(e.OperationalError):
+ await curs.scroll(-1)
+ await curs.close()
+
+
+@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
+async def test_no_hold(aconn, kwargs):
+ async with aconn.cursor("foo", **kwargs) as curs:
+ assert curs.withhold is False
+ await curs.execute("select generate_series(0, 2)")
+ assert await curs.fetchone() == (0,)
+ await aconn.commit()
+ with pytest.raises(e.InvalidCursorName):
+ await curs.fetchone()
+
+
+@pytest.mark.crdb_skip("cursor with hold")
+async def test_hold(aconn):
+ async with aconn.cursor("foo", withhold=True) as curs:
+ assert curs.withhold is True
+ await curs.execute("select generate_series(0, 5)")
+ assert await curs.fetchone() == (0,)
+ await aconn.commit()
+ assert await curs.fetchone() == (1,)
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
+async def test_steal_cursor(aconn, row_factory):
+ cur1 = aconn.cursor()
+ await cur1.execute(
+ "declare test cursor without hold for select generate_series(1, 6) as s"
+ )
+
+ cur2 = aconn.cursor("test", row_factory=getattr(rows, row_factory))
+ # can call fetch without execute
+ rec = await cur2.fetchone()
+ assert rec == (1,)
+ if row_factory == "namedtuple_row":
+ assert rec.s == 1
+ assert await cur2.fetchmany(3) == [(2,), (3,), (4,)]
+ assert await cur2.fetchall() == [(5,), (6,)]
+ await cur2.close()
+
+
+async def test_stolen_cursor_close(aconn):
+ cur1 = aconn.cursor()
+ await cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = aconn.cursor("test")
+ await cur2.close()
+
+ await cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = aconn.cursor("test")
+ await cur2.close()
diff --git a/tests/test_sql.py b/tests/test_sql.py
new file mode 100644
index 0000000..42b6c63
--- /dev/null
+++ b/tests/test_sql.py
@@ -0,0 +1,604 @@
+# test_sql.py - tests for the psycopg2.sql module
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import datetime as dt
+
+import pytest
+
+from psycopg import pq, sql, ProgrammingError
+from psycopg.adapt import PyFormat
+from psycopg._encodings import py2pgenc
+from psycopg.types import TypeInfo
+from psycopg.types.string import StrDumper
+
+from .utils import eur
+from .fix_crdb import crdb_encoding, crdb_scs_off
+
+
+@pytest.mark.parametrize(
+ "obj, quoted",
+ [
+ ("foo\\bar", " E'foo\\\\bar'"),
+ ("hello", "'hello'"),
+ (42, "42"),
+ (True, "true"),
+ (None, "NULL"),
+ ],
+)
+def test_quote(obj, quoted):
+ assert sql.quote(obj) == quoted
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_quote_roundtrip(conn, scs):
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute(f"set standard_conforming_strings to {scs}")
+
+ for i in range(1, 256):
+ want = chr(i)
+ quoted = sql.quote(want)
+ got = conn.execute(f"select {quoted}::text").fetchone()[0]
+ assert want == got
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages, f"error with {want!r}"
+
+
+@pytest.mark.parametrize("dummy", [crdb_scs_off("off")])
+def test_quote_stable_despite_deranged_libpq(conn, dummy):
+ # Verify the libpq behaviour of PQescapeString using the last setting seen.
+ # Check that we are not affected by it.
+ good_str = " E'\\\\'"
+ good_bytes = " E'\\\\000'::bytea"
+ conn.execute("set standard_conforming_strings to on")
+ assert pq.Escaping().escape_string(b"\\") == b"\\"
+ assert sql.quote("\\") == good_str
+ assert pq.Escaping().escape_bytea(b"\x00") == b"\\000"
+ assert sql.quote(b"\x00") == good_bytes
+
+ conn.execute("set standard_conforming_strings to off")
+ assert pq.Escaping().escape_string(b"\\") == b"\\\\"
+ assert sql.quote("\\") == good_str
+ assert pq.Escaping().escape_bytea(b"\x00") == b"\\\\000"
+ assert sql.quote(b"\x00") == good_bytes
+
+ # Verify that the good values are actually good
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute("set escape_string_warning to on")
+ for scs in ("on", "off"):
+ conn.execute(f"set standard_conforming_strings to {scs}")
+ cur = conn.execute(f"select {good_str}, {good_bytes}::bytea")
+ assert cur.fetchone() == ("\\", b"\x00")
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages
+
+
+class TestSqlFormat:
+ def test_pos(self, conn):
+ s = sql.SQL("select {} from {}").format(
+ sql.Identifier("field"), sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_pos_spec(self, conn):
+ s = sql.SQL("select {0} from {1}").format(
+ sql.Identifier("field"), sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ s = sql.SQL("select {1} from {0}").format(
+ sql.Identifier("table"), sql.Identifier("field")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_dict(self, conn):
+ s = sql.SQL("select {f} from {t}").format(
+ f=sql.Identifier("field"), t=sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_compose_literal(self, conn):
+ s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31)))
+ s1 = s.as_string(conn)
+ assert s1 == "select '2016-12-31'::date;"
+
+ def test_compose_empty(self, conn):
+ s = sql.SQL("select foo;").format()
+ s1 = s.as_string(conn)
+ assert s1 == "select foo;"
+
+ def test_percent_escape(self, conn):
+ s = sql.SQL("42 % {0}").format(sql.Literal(7))
+ s1 = s.as_string(conn)
+ assert s1 == "42 % 7"
+
+ def test_braces_escape(self, conn):
+ s = sql.SQL("{{{0}}}").format(sql.Literal(7))
+ assert s.as_string(conn) == "{7}"
+ s = sql.SQL("{{1,{0}}}").format(sql.Literal(7))
+ assert s.as_string(conn) == "{1,7}"
+
+ def test_compose_badnargs(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {0};").format()
+
+ def test_compose_badnargs_auto(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {};").format()
+ with pytest.raises(ValueError):
+ sql.SQL("select {} {1};").format(10, 20)
+ with pytest.raises(ValueError):
+ sql.SQL("select {0} {};").format(10, 20)
+
+ def test_compose_bad_args_type(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {0};").format(a=10)
+ with pytest.raises(KeyError):
+ sql.SQL("select {x};").format(10)
+
+ def test_no_modifiers(self):
+ with pytest.raises(ValueError):
+ sql.SQL("select {a!r};").format(a=10)
+ with pytest.raises(ValueError):
+ sql.SQL("select {a:<};").format(a=10)
+
+ def test_must_be_adaptable(self, conn):
+ class Foo:
+ pass
+
+ s = sql.SQL("select {0};").format(sql.Literal(Foo()))
+ with pytest.raises(ProgrammingError):
+ s.as_string(conn)
+
+ def test_auto_literal(self, conn):
+ s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1))
+ assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'::date"
+
+ def test_execute(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+ cur.execute(
+ sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
+ sql.Identifier("test_compose"),
+ sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
+ (sql.Placeholder() * 3).join(", "),
+ ),
+ (10, "a", "b", "c"),
+ )
+
+ cur.execute("select * from test_compose")
+ assert cur.fetchall() == [(10, "a", "b", "c")]
+
+ def test_executemany(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+ cur.executemany(
+ sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
+ sql.Identifier("test_compose"),
+ sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
+ (sql.Placeholder() * 3).join(", "),
+ ),
+ [(10, "a", "b", "c"), (20, "d", "e", "f")],
+ )
+
+ cur.execute("select * from test_compose")
+ assert cur.fetchall() == [(10, "a", "b", "c"), (20, "d", "e", "f")]
+
+ @pytest.mark.crdb_skip("copy")
+ def test_copy(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+
+ with cur.copy(
+ sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format(
+ t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")
+ ),
+ ) as copy:
+ copy.write_row((10, "a", "b", "c"))
+ copy.write_row((20, "d", "e", "f"))
+
+ with cur.copy(
+ sql.SQL("copy (select {f} from {t} order by id) to stdout").format(
+ t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")
+ )
+ ) as copy:
+ assert list(copy) == [b"c\n", b"f\n"]
+
+
+class TestIdentifier:
+ def test_class(self):
+ assert issubclass(sql.Identifier, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.Identifier("foo"), sql.Identifier)
+ assert isinstance(sql.Identifier("foo"), sql.Identifier)
+ assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier)
+ with pytest.raises(TypeError):
+ sql.Identifier()
+ with pytest.raises(TypeError):
+ sql.Identifier(10) # type: ignore[arg-type]
+ with pytest.raises(TypeError):
+ sql.Identifier(dt.date(2016, 12, 31)) # type: ignore[arg-type]
+
+ def test_repr(self):
+ obj = sql.Identifier("fo'o")
+ assert repr(obj) == 'Identifier("fo\'o")'
+ assert repr(obj) == str(obj)
+
+ obj = sql.Identifier("fo'o", 'ba"r')
+ assert repr(obj) == "Identifier(\"fo'o\", 'ba\"r')"
+ assert repr(obj) == str(obj)
+
+ def test_eq(self):
+ assert sql.Identifier("foo") == sql.Identifier("foo")
+ assert sql.Identifier("foo", "bar") == sql.Identifier("foo", "bar")
+ assert sql.Identifier("foo") != sql.Identifier("bar")
+ assert sql.Identifier("foo") != "foo"
+ assert sql.Identifier("foo") != sql.SQL("foo")
+
+ @pytest.mark.parametrize(
+ "args, want",
+ [
+ (("foo",), '"foo"'),
+ (("foo", "bar"), '"foo"."bar"'),
+ (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
+ ],
+ )
+ def test_as_string(self, conn, args, want):
+ assert sql.Identifier(*args).as_string(conn) == want
+
+ @pytest.mark.parametrize(
+ "args, want, enc",
+ [
+ crdb_encoding(("foo",), '"foo"', "ascii"),
+ crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
+ crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
+ (("foo", eur), f'"foo"."{eur}"', "utf8"),
+ crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
+ ],
+ )
+ def test_as_bytes(self, conn, args, want, enc):
+ want = want.encode(enc)
+ conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}")
+ assert sql.Identifier(*args).as_bytes(conn) == want
+
+ def test_join(self):
+ assert not hasattr(sql.Identifier("foo"), "join")
+
+
+class TestLiteral:
+ def test_class(self):
+ assert issubclass(sql.Literal, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.Literal("foo"), sql.Literal)
+ assert isinstance(sql.Literal("foo"), sql.Literal)
+ assert isinstance(sql.Literal(b"foo"), sql.Literal)
+ assert isinstance(sql.Literal(42), sql.Literal)
+ assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal)
+
+ def test_repr(self):
+ assert repr(sql.Literal("foo")) == "Literal('foo')"
+ assert str(sql.Literal("foo")) == "Literal('foo')"
+
+ def test_as_string(self, conn):
+ assert sql.Literal(None).as_string(conn) == "NULL"
+ assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
+ assert sql.Literal(42).as_string(conn) == "42"
+ assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date"
+
+ def test_as_bytes(self, conn):
+ assert sql.Literal(None).as_bytes(conn) == b"NULL"
+ assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
+ assert sql.Literal(42).as_bytes(conn) == b"42"
+ assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_as_bytes_encoding(self, conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode(encoding)
+
+ def test_eq(self):
+ assert sql.Literal("foo") == sql.Literal("foo")
+ assert sql.Literal("foo") != sql.Literal("bar")
+ assert sql.Literal("foo") != "foo"
+ assert sql.Literal("foo") != sql.SQL("foo")
+
+ def test_must_be_adaptable(self, conn):
+ class Foo:
+ pass
+
+ with pytest.raises(ProgrammingError):
+ sql.Literal(Foo()).as_string(conn)
+
+ def test_array(self, conn):
+ assert (
+ sql.Literal([dt.date(2000, 1, 1)]).as_string(conn)
+ == "'{2000-01-01}'::date[]"
+ )
+
+ def test_short_name_builtin(self, conn):
+ assert sql.Literal(dt.time(0, 0)).as_string(conn) == "'00:00:00'::time"
+ assert (
+ sql.Literal(dt.datetime(2000, 1, 1)).as_string(conn)
+ == "'2000-01-01 00:00:00'::timestamp"
+ )
+ assert (
+ sql.Literal([dt.datetime(2000, 1, 1)]).as_string(conn)
+ == "'{\"2000-01-01 00:00:00\"}'::timestamp[]"
+ )
+
+ def test_text_literal(self, conn):
+ conn.adapters.register_dumper(str, StrDumper)
+ assert sql.Literal("foo").as_string(conn) == "'foo'"
+
+ @pytest.mark.crdb_skip("composite") # create type, actually
+ @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"])
+ def test_invalid_name(self, conn, name):
+ conn.execute(
+ f"""
+ set client_encoding to utf8;
+ create type "{name}";
+ create function invin(cstring) returns "{name}"
+ language internal immutable strict as 'textin';
+ create function invout("{name}") returns cstring
+ language internal immutable strict as 'textout';
+ create type "{name}" (input=invin, output=invout, like=text);
+ """
+ )
+ info = TypeInfo.fetch(conn, f'"{name}"')
+
+ class InvDumper(StrDumper):
+ oid = info.oid
+
+ def dump(self, obj):
+ rv = super().dump(obj)
+ return b"%s-inv" % rv
+
+ info.register(conn)
+ conn.adapters.register_dumper(str, InvDumper)
+
+ assert sql.Literal("hello").as_string(conn) == f"'hello-inv'::\"{name}\""
+ cur = conn.execute(sql.SQL("select {}").format("hello"))
+ assert cur.fetchone()[0] == "hello-inv"
+
+ assert (
+ sql.Literal(["hello"]).as_string(conn) == f"'{{hello-inv}}'::\"{name}\"[]"
+ )
+ cur = conn.execute(sql.SQL("select {}").format(["hello"]))
+ assert cur.fetchone()[0] == ["hello-inv"]
+
+
+class TestSQL:
+ def test_class(self):
+ assert issubclass(sql.SQL, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.SQL("foo"), sql.SQL)
+ assert isinstance(sql.SQL("foo"), sql.SQL)
+ with pytest.raises(TypeError):
+ sql.SQL(10) # type: ignore[arg-type]
+ with pytest.raises(TypeError):
+ sql.SQL(dt.date(2016, 12, 31)) # type: ignore[arg-type]
+
+ def test_repr(self, conn):
+ assert repr(sql.SQL("foo")) == "SQL('foo')"
+ assert str(sql.SQL("foo")) == "SQL('foo')"
+ assert sql.SQL("foo").as_string(conn) == "foo"
+
+ def test_eq(self):
+ assert sql.SQL("foo") == sql.SQL("foo")
+ assert sql.SQL("foo") != sql.SQL("bar")
+ assert sql.SQL("foo") != "foo"
+ assert sql.SQL("foo") != sql.Literal("foo")
+
+ def test_sum(self, conn):
+ obj = sql.SQL("foo") + sql.SQL("bar")
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foobar"
+
+ def test_sum_inplace(self, conn):
+ obj = sql.SQL("f") + sql.SQL("oo")
+ obj += sql.SQL("bar")
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foobar"
+
+ def test_multiply(self, conn):
+ obj = sql.SQL("foo") * 3
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foofoofoo"
+
+ def test_join(self, conn):
+ obj = sql.SQL(", ").join(
+ [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
+ )
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == '"foo", bar, 42'
+
+ obj = sql.SQL(", ").join(
+ sql.Composed([sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)])
+ )
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == '"foo", bar, 42'
+
+ obj = sql.SQL(", ").join([])
+ assert obj == sql.Composed([])
+
+ def test_as_string(self, conn):
+ assert sql.SQL("foo").as_string(conn) == "foo"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_as_bytes(self, conn, encoding):
+ if encoding:
+ conn.execute(f"set client_encoding to {encoding}")
+
+ assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding)
+
+
+class TestComposed:
+ def test_class(self):
+ assert issubclass(sql.Composed, sql.Composable)
+
+ def test_repr(self):
+ obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
+ assert repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])"""
+ assert str(obj) == repr(obj)
+
+ def test_eq(self):
+ L = [sql.Literal("foo"), sql.Identifier("b'ar")]
+ l2 = [sql.Literal("foo"), sql.Literal("b'ar")]
+ assert sql.Composed(L) == sql.Composed(list(L))
+ assert sql.Composed(L) != L
+ assert sql.Composed(L) != sql.Composed(l2)
+
+ def test_join(self, conn):
+ obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
+ obj = obj.join(", ")
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "'foo', \"b'ar\""
+
+ def test_auto_literal(self, conn):
+ obj = sql.Composed(["fo'o", dt.date(2020, 1, 1)])
+ obj = obj.join(", ")
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'::date"
+
+ def test_sum(self, conn):
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj = obj + sql.Literal("bar")
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "foo 'bar'"
+
+ def test_sum_inplace(self, conn):
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj += sql.Literal("bar")
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "foo 'bar'"
+
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj += sql.Composed([sql.Literal("bar")])
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "foo 'bar'"
+
+ def test_iter(self):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+ it = iter(obj)
+ i = next(it)
+ assert i == sql.SQL("foo")
+ i = next(it)
+ assert i == sql.SQL("bar")
+ with pytest.raises(StopIteration):
+ next(it)
+
+ def test_as_string(self, conn):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+ assert obj.as_string(conn) == "foobar"
+
+ def test_as_bytes(self, conn):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+ assert obj.as_bytes(conn) == b"foobar"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_as_bytes_encoding(self, conn, encoding):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL(eur)])
+ conn.execute(f"set client_encoding to {encoding}")
+ assert obj.as_bytes(conn) == ("foo" + eur).encode(encoding)
+
+
+class TestPlaceholder:
+ def test_class(self):
+ assert issubclass(sql.Placeholder, sql.Composable)
+
+ @pytest.mark.parametrize("format", PyFormat)
+ def test_repr_format(self, conn, format):
+ ph = sql.Placeholder(format=format)
+ add = f"format={format.name}" if format != PyFormat.AUTO else ""
+ assert str(ph) == repr(ph) == f"Placeholder({add})"
+
+ @pytest.mark.parametrize("format", PyFormat)
+ def test_repr_name_format(self, conn, format):
+ ph = sql.Placeholder("foo", format=format)
+ add = f", format={format.name}" if format != PyFormat.AUTO else ""
+ assert str(ph) == repr(ph) == f"Placeholder('foo'{add})"
+
+ def test_bad_name(self):
+ with pytest.raises(ValueError):
+ sql.Placeholder(")")
+
+ def test_eq(self):
+ assert sql.Placeholder("foo") == sql.Placeholder("foo")
+ assert sql.Placeholder("foo") != sql.Placeholder("bar")
+ assert sql.Placeholder("foo") != "foo"
+ assert sql.Placeholder() == sql.Placeholder()
+ assert sql.Placeholder("foo") != sql.Placeholder()
+ assert sql.Placeholder("foo") != sql.Literal("foo")
+
+ @pytest.mark.parametrize("format", PyFormat)
+ def test_as_string(self, conn, format):
+ ph = sql.Placeholder(format=format)
+ assert ph.as_string(conn) == f"%{format.value}"
+
+ ph = sql.Placeholder(name="foo", format=format)
+ assert ph.as_string(conn) == f"%(foo){format.value}"
+
+ @pytest.mark.parametrize("format", PyFormat)
+ def test_as_bytes(self, conn, format):
+ ph = sql.Placeholder(format=format)
+ assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii")
+
+ ph = sql.Placeholder(name="foo", format=format)
+ assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii")
+
+
+class TestValues:
+ def test_null(self, conn):
+ assert isinstance(sql.NULL, sql.SQL)
+ assert sql.NULL.as_string(conn) == "NULL"
+
+ def test_default(self, conn):
+ assert isinstance(sql.DEFAULT, sql.SQL)
+ assert sql.DEFAULT.as_string(conn) == "DEFAULT"
+
+
+def no_e(s):
+ """Drop an eventual E from E'' quotes"""
+ if isinstance(s, memoryview):
+ s = bytes(s)
+
+ if isinstance(s, str):
+ return re.sub(r"\bE'", "'", s)
+ elif isinstance(s, bytes):
+ return re.sub(rb"\bE'", b"'", s)
+ else:
+ raise TypeError(f"not dealing with {type(s).__name__}: {s}")
diff --git a/tests/test_tpc.py b/tests/test_tpc.py
new file mode 100644
index 0000000..91a04e0
--- /dev/null
+++ b/tests/test_tpc.py
@@ -0,0 +1,325 @@
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+pytestmark = pytest.mark.crdb_skip("2-phase commit")
+
+
+def test_tpc_disabled(conn, pipeline):
+ val = int(conn.execute("show max_prepared_transactions").fetchone()[0])
+ if val:
+ pytest.skip("prepared transactions enabled")
+
+ conn.rollback()
+ conn.tpc_begin("x")
+ with pytest.raises(psycopg.NotSupportedError):
+ conn.tpc_prepare()
+
+
+class TestTPC:
+ def test_tpc_commit(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_commit_one_phase(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_commit_recovered(self, conn_cls, conn, dsn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ conn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ with conn_cls.connect(dsn) as conn:
+ xid = conn.xid(1, "gtrid", "bqual")
+ conn.tpc_commit(xid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_rollback(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_rollback')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_rollback()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_tpc_rollback_one_phase(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_rollback()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_tpc_rollback_recovered(self, conn_cls, conn, dsn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ conn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ with conn_cls.connect(dsn) as conn:
+ xid = conn.xid(1, "gtrid", "bqual")
+ conn.tpc_rollback(xid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_status_after_recover(self, conn, tpc):
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ conn.tpc_recover()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ conn.tpc_recover()
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ def test_recovered_xids(self, conn, tpc):
+ # insert a few test xns
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("begin; prepare transaction '1-foo'")
+ cur.execute("begin; prepare transaction '2-bar'")
+
+ # read the values to return
+ cur.execute(
+ """
+ select gid, prepared, owner, database from pg_prepared_xacts
+ where database = %s
+ """,
+ (conn.info.dbname,),
+ )
+ okvals = cur.fetchall()
+ okvals.sort()
+
+ xids = conn.tpc_recover()
+ xids = [xid for xid in xids if xid.database == conn.info.dbname]
+ xids.sort(key=lambda x: x.gtrid)
+
+ # check the values returned
+ assert len(okvals) == len(xids)
+ for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
+ assert xid.gtrid == gid
+ assert xid.prepared == prepared
+ assert xid.owner == owner
+ assert xid.database == database
+
+ def test_xid_encoding(self, conn, tpc):
+ xid = conn.xid(42, "gtrid", "bqual")
+ conn.tpc_begin(xid)
+ conn.tpc_prepare()
+
+ cur = conn.cursor()
+ cur.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (conn.info.dbname,),
+ )
+ assert "42_Z3RyaWQ=_YnF1YWw=" == cur.fetchone()[0]
+
+ @pytest.mark.parametrize(
+ "fid, gtrid, bqual",
+ [
+ (0, "", ""),
+ (42, "gtrid", "bqual"),
+ (0x7FFFFFFF, "x" * 64, "y" * 64),
+ ],
+ )
+ def test_xid_roundtrip(self, conn_cls, conn, dsn, tpc, fid, gtrid, bqual):
+ xid = conn.xid(fid, gtrid, bqual)
+ conn.tpc_begin(xid)
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname]
+
+ assert len(xids) == 1
+ xid = xids[0]
+ conn.tpc_rollback(xid)
+
+ assert xid.format_id == fid
+ assert xid.gtrid == gtrid
+ assert xid.bqual == bqual
+
+ @pytest.mark.parametrize(
+ "tid",
+ [
+ "",
+ "hello, world!",
+ "x" * 199, # PostgreSQL's limit in transaction id length
+ ],
+ )
+ def test_unparsed_roundtrip(self, conn_cls, conn, dsn, tpc, tid):
+ conn.tpc_begin(tid)
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname]
+
+ assert len(xids) == 1
+ xid = xids[0]
+ conn.tpc_rollback(xid)
+
+ assert xid.format_id is None
+ assert xid.gtrid == tid
+ assert xid.bqual is None
+
+ def test_xid_unicode(self, conn_cls, conn, dsn, tpc):
+ x1 = conn.xid(10, "uni", "code")
+ conn.tpc_begin(x1)
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0]
+ assert 10 == xid.format_id
+ assert "uni" == xid.gtrid
+ assert "code" == xid.bqual
+
+ def test_xid_unicode_unparsed(self, conn_cls, conn, dsn, tpc):
+ # We don't expect people shooting snowmen as transaction ids,
+ # so if something explodes in an encode error I don't mind.
+ # Let's just check unicode is accepted as type.
+ conn.execute("set client_encoding to utf8")
+ conn.commit()
+
+ conn.tpc_begin("transaction-id")
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "transaction-id"
+ assert xid.bqual is None
+
+ def test_cancel_fails_prepared(self, conn, tpc):
+ conn.tpc_begin("cancel")
+ conn.tpc_prepare()
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.cancel()
+
+ def test_tpc_recover_non_dbapi_connection(self, conn_cls, conn, dsn, tpc):
+ conn.row_factory = psycopg.rows.dict_row
+ conn.tpc_begin("dict-connection")
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xids = conn.tpc_recover()
+ xid = [x for x in xids if x.database == conn.info.dbname][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "dict-connection"
+ assert xid.bqual is None
+
+
+class TestXidObject:
+ def test_xid_construction(self):
+ x1 = psycopg.Xid(74, "foo", "bar")
+ 74 == x1.format_id
+ "foo" == x1.gtrid
+ "bar" == x1.bqual
+
+ def test_xid_from_string(self):
+ x2 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+ 42 == x2.format_id
+ "gtrid" == x2.gtrid
+ "bqual" == x2.bqual
+
+ x3 = psycopg.Xid.from_string("99_xxx_yyy")
+ None is x3.format_id
+ "99_xxx_yyy" == x3.gtrid
+ None is x3.bqual
+
+ def test_xid_to_string(self):
+ x1 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+ str(x1) == "42_Z3RyaWQ=_YnF1YWw="
+
+ x2 = psycopg.Xid.from_string("99_xxx_yyy")
+ str(x2) == "99_xxx_yyy"
diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py
new file mode 100644
index 0000000..a409a2e
--- /dev/null
+++ b/tests/test_tpc_async.py
@@ -0,0 +1,310 @@
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.crdb_skip("2-phase commit"),
+]
+
+
+async def test_tpc_disabled(aconn, apipeline):
+ cur = await aconn.execute("show max_prepared_transactions")
+ val = int((await cur.fetchone())[0])
+ if val:
+ pytest.skip("prepared transactions enabled")
+
+ await aconn.rollback()
+ await aconn.tpc_begin("x")
+ with pytest.raises(psycopg.NotSupportedError):
+ await aconn.tpc_prepare()
+
+
+class TestTPC:
+ async def test_tpc_commit(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_commit()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_commit_one_phase(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_commit()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_commit_recovered(self, aconn_cls, aconn, dsn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ await aconn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = aconn.xid(1, "gtrid", "bqual")
+ await aconn.tpc_commit(xid)
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_rollback(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_rollback')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_rollback()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_tpc_rollback_one_phase(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_rollback()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_tpc_rollback_recovered(self, aconn_cls, aconn, dsn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ await aconn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = aconn.xid(1, "gtrid", "bqual")
+ await aconn.tpc_rollback(xid)
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_status_after_recover(self, aconn, tpc):
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ await aconn.tpc_recover()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+ await aconn.tpc_recover()
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ async def test_recovered_xids(self, aconn, tpc):
+ # insert a few test xns
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("begin; prepare transaction '1-foo'")
+ await cur.execute("begin; prepare transaction '2-bar'")
+
+ # read the values to return
+ await cur.execute(
+ """
+ select gid, prepared, owner, database from pg_prepared_xacts
+ where database = %s
+ """,
+ (aconn.info.dbname,),
+ )
+ okvals = await cur.fetchall()
+ okvals.sort()
+
+ xids = await aconn.tpc_recover()
+ xids = [xid for xid in xids if xid.database == aconn.info.dbname]
+ xids.sort(key=lambda x: x.gtrid)
+
+ # check the values returned
+ assert len(okvals) == len(xids)
+ for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
+ assert xid.gtrid == gid
+ assert xid.prepared == prepared
+ assert xid.owner == owner
+ assert xid.database == database
+
+ async def test_xid_encoding(self, aconn, tpc):
+ xid = aconn.xid(42, "gtrid", "bqual")
+ await aconn.tpc_begin(xid)
+ await aconn.tpc_prepare()
+
+ cur = aconn.cursor()
+ await cur.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (aconn.info.dbname,),
+ )
+ assert "42_Z3RyaWQ=_YnF1YWw=" == (await cur.fetchone())[0]
+
+ @pytest.mark.parametrize(
+ "fid, gtrid, bqual",
+ [
+ (0, "", ""),
+ (42, "gtrid", "bqual"),
+ (0x7FFFFFFF, "x" * 64, "y" * 64),
+ ],
+ )
+ async def test_xid_roundtrip(self, aconn_cls, aconn, dsn, tpc, fid, gtrid, bqual):
+ xid = aconn.xid(fid, gtrid, bqual)
+ await aconn.tpc_begin(xid)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ]
+ assert len(xids) == 1
+ xid = xids[0]
+ await aconn.tpc_rollback(xid)
+
+ assert xid.format_id == fid
+ assert xid.gtrid == gtrid
+ assert xid.bqual == bqual
+
+ @pytest.mark.parametrize(
+ "tid",
+ [
+ "",
+ "hello, world!",
+ "x" * 199, # PostgreSQL's limit in transaction id length
+ ],
+ )
+ async def test_unparsed_roundtrip(self, aconn_cls, aconn, dsn, tpc, tid):
+ await aconn.tpc_begin(tid)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ]
+ assert len(xids) == 1
+ xid = xids[0]
+ await aconn.tpc_rollback(xid)
+
+ assert xid.format_id is None
+ assert xid.gtrid == tid
+ assert xid.bqual is None
+
+ async def test_xid_unicode(self, aconn_cls, aconn, dsn, tpc):
+ x1 = aconn.xid(10, "uni", "code")
+ await aconn.tpc_begin(x1)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ][0]
+
+ assert 10 == xid.format_id
+ assert "uni" == xid.gtrid
+ assert "code" == xid.bqual
+
+ async def test_xid_unicode_unparsed(self, aconn_cls, aconn, dsn, tpc):
+ # We don't expect people shooting snowmen as transaction ids,
+ # so if something explodes in an encode error I don't mind.
+ # Let's just check unicode is accepted as type.
+ await aconn.execute("set client_encoding to utf8")
+ await aconn.commit()
+
+ await aconn.tpc_begin("transaction-id")
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "transaction-id"
+ assert xid.bqual is None
+
+ async def test_cancel_fails_prepared(self, aconn, tpc):
+ await aconn.tpc_begin("cancel")
+ await aconn.tpc_prepare()
+ with pytest.raises(psycopg.ProgrammingError):
+ aconn.cancel()
+
+ async def test_tpc_recover_non_dbapi_connection(self, aconn_cls, aconn, dsn, tpc):
+ aconn.row_factory = psycopg.rows.dict_row
+ await aconn.tpc_begin("dict-connection")
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = await aconn.tpc_recover()
+ xid = [x for x in xids if x.database == aconn.info.dbname][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "dict-connection"
+ assert xid.bqual is None
diff --git a/tests/test_transaction.py b/tests/test_transaction.py
new file mode 100644
index 0000000..9391e00
--- /dev/null
+++ b/tests/test_transaction.py
@@ -0,0 +1,796 @@
+import sys
+import logging
+from threading import Thread, Event
+
+import pytest
+
+import psycopg
+from psycopg import Rollback
+from psycopg import errors as e
+
+# TODOCRDB: is this the expected behaviour?
+crdb_skip_external_observer = pytest.mark.crdb(
+ "skip", reason="deadlock on observer connection"
+)
+
+
+@pytest.fixture
+def conn(conn, pipeline):
+ return conn
+
+
+@pytest.fixture(autouse=True)
+def create_test_table(svcconn):
+ """Creates a table called 'test_table' for use in tests."""
+ cur = svcconn.cursor()
+ cur.execute("drop table if exists test_table")
+ cur.execute("create table test_table (id text primary key)")
+ yield
+ cur.execute("drop table test_table")
+
+
+def insert_row(conn, value):
+ sql = "INSERT INTO test_table VALUES (%s)"
+ if isinstance(conn, psycopg.Connection):
+ conn.cursor().execute(sql, (value,))
+ else:
+
+ async def f():
+ cur = conn.cursor()
+ await cur.execute(sql, (value,))
+
+ return f()
+
+
+def inserted(conn):
+ """Return the values inserted in the test table."""
+ sql = "SELECT * FROM test_table"
+ if isinstance(conn, psycopg.Connection):
+ rows = conn.cursor().execute(sql).fetchall()
+ return set(v for (v,) in rows)
+ else:
+
+ async def f():
+ cur = conn.cursor()
+ await cur.execute(sql)
+ rows = await cur.fetchall()
+ return set(v for (v,) in rows)
+
+ return f()
+
+
+def in_transaction(conn):
+ if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE:
+ return False
+ elif conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS:
+ return True
+ else:
+ assert False, conn.pgconn.transaction_status
+
+
+def get_exc_info(exc):
+ """Return the exc info for an exception or a success if exc is None"""
+ if not exc:
+ return (None,) * 3
+ try:
+ raise exc
+ except exc:
+ return sys.exc_info()
+
+
+class ExpectedException(Exception):
+ pass
+
+
+def test_basic(conn, pipeline):
+ """Basic use of transaction() to BEGIN and COMMIT a transaction."""
+ assert not in_transaction(conn)
+ with conn.transaction():
+ if pipeline:
+ pipeline.sync()
+ assert in_transaction(conn)
+ assert not in_transaction(conn)
+
+
+def test_exposes_associated_connection(conn):
+ """Transaction exposes its connection as a read-only property."""
+ with conn.transaction() as tx:
+ assert tx.connection is conn
+ with pytest.raises(AttributeError):
+ tx.connection = conn
+
+
+def test_exposes_savepoint_name(conn):
+ """Transaction exposes its savepoint name as a read-only property."""
+ with conn.transaction(savepoint_name="foo") as tx:
+ assert tx.savepoint_name == "foo"
+ with pytest.raises(AttributeError):
+ tx.savepoint_name = "bar"
+
+
+def test_cant_reenter(conn):
+ with conn.transaction() as tx:
+ pass
+
+ with pytest.raises(TypeError):
+ with tx:
+ pass
+
+
+def test_begins_on_enter(conn, pipeline):
+ """Transaction does not begin until __enter__() is called."""
+ tx = conn.transaction()
+ assert not in_transaction(conn)
+ with tx:
+ if pipeline:
+ pipeline.sync()
+ assert in_transaction(conn)
+ assert not in_transaction(conn)
+
+
+def test_commit_on_successful_exit(conn):
+ """Changes are committed on successful exit from the `with` block."""
+ with conn.transaction():
+ insert_row(conn, "foo")
+
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"foo"}
+
+
+def test_rollback_on_exception_exit(conn):
+ """Changes are rolled back if an exception escapes the `with` block."""
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "foo")
+ raise ExpectedException("This discards the insert")
+
+ assert not in_transaction(conn)
+ assert not inserted(conn)
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_context_inerror_rollback_no_clobber(conn_cls, conn, pipeline, dsn, caplog):
+ if pipeline:
+ # Only 'conn' is possibly in pipeline mode, but the transaction and
+ # checks are on 'conn2'.
+ pytest.skip("not applicable")
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ with conn_cls.connect(dsn) as conn2:
+ with conn2.transaction():
+ conn2.execute("select 1")
+ conn.execute(
+ "select pg_terminate_backend(%s::int)",
+ [conn2.pgconn.backend_pid],
+ )
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.crdb_skip("copy")
+def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ conn = conn_cls.connect(dsn)
+ try:
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction():
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+ finally:
+ conn.close()
+
+
+def test_interaction_dbapi_transaction(conn):
+ insert_row(conn, "foo")
+
+ with conn.transaction():
+ insert_row(conn, "bar")
+ raise Rollback
+
+ with conn.transaction():
+ insert_row(conn, "baz")
+
+ assert in_transaction(conn)
+ conn.commit()
+ assert inserted(conn) == {"foo", "baz"}
+
+
+def test_prohibits_use_of_commit_rollback_autocommit(conn):
+ """
+ Within a Transaction block, it is forbidden to touch commit, rollback,
+ or the autocommit setting on the connection, as this would interfere
+ with the transaction scope being managed by the Transaction block.
+ """
+ conn.autocommit = False
+ conn.commit()
+ conn.rollback()
+
+ with conn.transaction():
+ with pytest.raises(e.ProgrammingError):
+ conn.autocommit = False
+ with pytest.raises(e.ProgrammingError):
+ conn.commit()
+ with pytest.raises(e.ProgrammingError):
+ conn.rollback()
+
+ conn.autocommit = False
+ conn.commit()
+ conn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+def test_preserves_autocommit(conn, autocommit):
+ """
+ Connection.autocommit is unchanged both during and after Transaction block.
+ """
+ conn.autocommit = autocommit
+ with conn.transaction():
+ assert conn.autocommit is autocommit
+ assert conn.autocommit is autocommit
+
+
+def test_autocommit_off_but_no_tx_started_successful_exit(conn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are committed
+ """
+ conn.autocommit = False
+ assert not in_transaction(conn)
+ with conn.transaction():
+ insert_row(conn, "new")
+ assert not in_transaction(conn)
+
+ # Changes committed
+ assert inserted(conn) == {"new"}
+ assert inserted(svcconn) == {"new"}
+
+
+def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made within Transaction context are discarded
+ """
+ conn.autocommit = False
+ assert not in_transaction(conn)
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "new")
+ raise ExpectedException()
+ assert not in_transaction(conn)
+
+ # Changes discarded
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+def test_autocommit_off_and_tx_in_progress_successful_exit(conn, pipeline, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are left intact
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ conn.autocommit = False
+ insert_row(conn, "prior")
+ if pipeline:
+ pipeline.sync()
+ assert in_transaction(conn)
+ with conn.transaction():
+ insert_row(conn, "new")
+ assert in_transaction(conn)
+ assert inserted(conn) == {"prior", "new"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+def test_autocommit_off_and_tx_in_progress_exception_exit(conn, pipeline, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made before the Transaction context are left intact
+ * Changes made within Transaction context are discarded
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ conn.autocommit = False
+ insert_row(conn, "prior")
+ if pipeline:
+ pipeline.sync()
+ assert in_transaction(conn)
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "new")
+ raise ExpectedException()
+ assert in_transaction(conn)
+ assert inserted(conn) == {"prior"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn):
+ """Changes from nested transaction contexts are all persisted on exit."""
+ with conn.transaction():
+ insert_row(conn, "outer-before")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ insert_row(conn, "outer-after")
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"outer-before", "inner", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "inner", "outer-after"}
+
+
+def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in outer context escapes.
+ """
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "outer")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(conn)
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in inner context escapes the outer context.
+ """
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "outer")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(conn)
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn):
+ """
+ An exception escaping the inner transaction context causes changes made
+ within that inner context to be discarded, but the error can then be
+ handled in the outer context, allowing changes made in the outer context
+ (both before, and after, the inner context) to be successfully committed.
+ """
+ with conn.transaction():
+ insert_row(conn, "outer-before")
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise ExpectedException()
+ insert_row(conn, "outer-after")
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"outer-before", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+def test_nested_three_levels_successful_exit(conn, svcconn):
+ """Exercise management of more than one savepoint."""
+ with conn.transaction(): # BEGIN
+ insert_row(conn, "one")
+ with conn.transaction(): # SAVEPOINT s1
+ insert_row(conn, "two")
+ with conn.transaction(): # SAVEPOINT s2
+ insert_row(conn, "three")
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"one", "two", "three"}
+ assert inserted(svcconn) == {"one", "two", "three"}
+
+
+def test_named_savepoint_escapes_savepoint_name(conn):
+ with conn.transaction("s-1"):
+ pass
+ with conn.transaction("s1; drop table students"):
+ pass
+
+
+def test_named_savepoints_successful_exit(conn, commands):
+ """
+ Entering a transaction context will do one of these these things:
+ 1. Begin an outer transaction (if one isn't already in progress)
+ 2. Begin an outer transaction and create a savepoint (if one is named)
+ 3. Create a savepoint (if a transaction is already in progress)
+ either using the name provided, or auto-generating a savepoint name.
+
+ ...and exiting the context successfully will "commit" the same.
+ """
+ # Case 1
+ # Using Transaction explicitly because conn.transaction() enters the contetx
+ assert not commands
+ with conn.transaction() as tx:
+ assert commands.popall() == ["BEGIN"]
+ assert not tx.savepoint_name
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 1 (with a transaction already started)
+ conn.cursor().execute("select 1")
+ assert commands.popall() == ["BEGIN"]
+ with conn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_1"']
+ assert tx.savepoint_name == "_pg3_1"
+ assert commands.popall() == ['RELEASE "_pg3_1"']
+ conn.rollback()
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 2
+ with conn.transaction(savepoint_name="foo") as tx:
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
+ assert tx.savepoint_name == "foo"
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name provided)
+ with conn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with conn.transaction(savepoint_name="bar") as tx:
+ assert commands.popall() == ['SAVEPOINT "bar"']
+ assert tx.savepoint_name == "bar"
+ assert commands.popall() == ['RELEASE "bar"']
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name auto-generated)
+ with conn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with conn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_2"']
+ assert tx.savepoint_name == "_pg3_2"
+ assert commands.popall() == ['RELEASE "_pg3_2"']
+ assert commands.popall() == ["COMMIT"]
+
+
+def test_named_savepoints_exception_exit(conn, commands):
+ """
+ Same as the previous test but checks that when exiting the context with an
+ exception, whatever transaction and/or savepoint was started on enter will
+ be rolled-back as appropriate.
+ """
+ # Case 1
+ with pytest.raises(ExpectedException):
+ with conn.transaction() as tx:
+ assert commands.popall() == ["BEGIN"]
+ assert not tx.savepoint_name
+ raise ExpectedException
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 2
+ with pytest.raises(ExpectedException):
+ with conn.transaction(savepoint_name="foo") as tx:
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
+ assert tx.savepoint_name == "foo"
+ raise ExpectedException
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 3 (with savepoint name provided)
+ with conn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with pytest.raises(ExpectedException):
+ with conn.transaction(savepoint_name="bar") as tx:
+ assert commands.popall() == ['SAVEPOINT "bar"']
+ assert tx.savepoint_name == "bar"
+ raise ExpectedException
+ assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"']
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name auto-generated)
+ with conn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with pytest.raises(ExpectedException):
+ with conn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_2"']
+ assert tx.savepoint_name == "_pg3_2"
+ raise ExpectedException
+ assert commands.popall() == [
+ 'ROLLBACK TO "_pg3_2"',
+ 'RELEASE "_pg3_2"',
+ ]
+ assert commands.popall() == ["COMMIT"]
+
+
+def test_named_savepoints_with_repeated_names_works(conn):
+ """
+ Using the same savepoint name repeatedly works correctly, but bypasses
+ some sanity checks.
+ """
+ # Works correctly if no inner transactions are rolled back
+ with conn.transaction(force_rollback=True):
+ with conn.transaction("sp"):
+ insert_row(conn, "tx1")
+ with conn.transaction("sp"):
+ insert_row(conn, "tx2")
+ with conn.transaction("sp"):
+ insert_row(conn, "tx3")
+ assert inserted(conn) == {"tx1", "tx2", "tx3"}
+
+ # Works correctly if one level of inner transaction is rolled back
+ with conn.transaction(force_rollback=True):
+ with conn.transaction("s1"):
+ insert_row(conn, "tx1")
+ with conn.transaction("s1", force_rollback=True):
+ insert_row(conn, "tx2")
+ with conn.transaction("s1"):
+ insert_row(conn, "tx3")
+ assert inserted(conn) == {"tx1"}
+ assert inserted(conn) == {"tx1"}
+
+ # Works correctly if multiple inner transactions are rolled back
+ # (This scenario mandates releasing savepoints after rolling back to them.)
+ with conn.transaction(force_rollback=True):
+ with conn.transaction("s1"):
+ insert_row(conn, "tx1")
+ with conn.transaction("s1") as tx2:
+ insert_row(conn, "tx2")
+ with conn.transaction("s1"):
+ insert_row(conn, "tx3")
+ raise Rollback(tx2)
+ assert inserted(conn) == {"tx1"}
+ assert inserted(conn) == {"tx1"}
+
+
+def test_force_rollback_successful_exit(conn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ with conn.transaction(force_rollback=True):
+ insert_row(conn, "foo")
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+def test_force_rollback_exception_exit(conn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ with pytest.raises(ExpectedException):
+ with conn.transaction(force_rollback=True):
+ insert_row(conn, "foo")
+ raise ExpectedException()
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+def test_explicit_rollback_discards_changes(conn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block exits the block and
+ discards all changes made within that block.
+
+ You can raise any of the following:
+ - Rollback (type)
+ - Rollback() (instance)
+ - Rollback(tx) (instance initialised with reference to the transaction)
+ All of these are equivalent.
+ """
+
+ def assert_no_rows():
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+ with conn.transaction():
+ insert_row(conn, "foo")
+ raise Rollback
+ assert_no_rows()
+
+ with conn.transaction():
+ insert_row(conn, "foo")
+ raise Rollback()
+ assert_no_rows()
+
+ with conn.transaction() as tx:
+ insert_row(conn, "foo")
+ raise Rollback(tx)
+ assert_no_rows()
+
+
+@crdb_skip_external_observer
+def test_explicit_rollback_outer_tx_unaffected(conn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block does not impact an
+ enclosing transaction block.
+ """
+ with conn.transaction():
+ insert_row(conn, "before")
+ with conn.transaction():
+ insert_row(conn, "during")
+ raise Rollback
+ assert in_transaction(conn)
+ assert not inserted(svcconn)
+ insert_row(conn, "after")
+ assert inserted(conn) == {"before", "after"}
+ assert inserted(svcconn) == {"before", "after"}
+
+
+def test_explicit_rollback_of_outer_transaction(conn):
+ """
+ Raising a Rollback exception that references an outer transaction will
+ discard all changes from both inner and outer transaction blocks.
+ """
+ with conn.transaction() as outer_tx:
+ insert_row(conn, "outer")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise Rollback(outer_tx)
+ assert False, "This line of code should be unreachable."
+ assert not inserted(conn)
+
+
+@crdb_skip_external_observer
+def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn):
+ """
+ Rolling-back an enclosing transaction does not impact an outer transaction.
+ """
+ with conn.transaction():
+ insert_row(conn, "outer-before")
+ with conn.transaction() as tx_enclosing:
+ insert_row(conn, "enclosing")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise Rollback(tx_enclosing)
+ insert_row(conn, "outer-after")
+
+ assert inserted(conn) == {"outer-before", "outer-after"}
+ assert not inserted(svcconn) # Not yet committed
+ # Changes committed
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+def test_str(conn, pipeline):
+ with conn.transaction() as tx:
+ if pipeline:
+ assert "[INTRANS, pipeline=ON]" in str(tx)
+ else:
+ assert "[INTRANS]" in str(tx)
+ assert "(active)" in str(tx)
+ assert "'" not in str(tx)
+ with conn.transaction("wat") as tx2:
+ if pipeline:
+ assert "[INTRANS, pipeline=ON]" in str(tx2)
+ else:
+ assert "[INTRANS]" in str(tx2)
+ assert "'wat'" in str(tx2)
+
+ if pipeline:
+ assert "[IDLE, pipeline=ON]" in str(tx)
+ else:
+ assert "[IDLE]" in str(tx)
+ assert "(terminated)" in str(tx)
+
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction() as tx:
+ 1 / 0
+
+ assert "(terminated)" in str(tx)
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+def test_out_of_order_exit(conn, exit_error):
+ conn.autocommit = True
+
+ t1 = conn.transaction()
+ t1.__enter__()
+
+ t2 = conn.transaction()
+ t2.__enter__()
+
+ with pytest.raises(e.ProgrammingError):
+ t1.__exit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ t2.__exit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+def test_out_of_order_implicit_begin(conn, exit_error):
+ conn.execute("select 1")
+
+ t1 = conn.transaction()
+ t1.__enter__()
+
+ t2 = conn.transaction()
+ t2.__enter__()
+
+ with pytest.raises(e.ProgrammingError):
+ t1.__exit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ t2.__exit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+def test_out_of_order_exit_same_name(conn, exit_error):
+ conn.autocommit = True
+
+ t1 = conn.transaction("save")
+ t1.__enter__()
+ t2 = conn.transaction("save")
+ t2.__enter__()
+
+ with pytest.raises(e.ProgrammingError):
+ t1.__exit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ t2.__exit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("what", ["commit", "rollback", "error"])
+def test_concurrency(conn, what):
+ conn.autocommit = True
+
+ evs = [Event() for i in range(3)]
+
+ def worker(unlock, wait_on):
+ with pytest.raises(e.ProgrammingError) as ex:
+ with conn.transaction():
+ unlock.set()
+ wait_on.wait()
+ conn.execute("select 1")
+
+ if what == "error":
+ 1 / 0
+ elif what == "rollback":
+ raise Rollback()
+ else:
+ assert what == "commit"
+
+ if what == "error":
+ assert "transaction rollback" in str(ex.value)
+ assert isinstance(ex.value.__context__, ZeroDivisionError)
+ elif what == "rollback":
+ assert "transaction rollback" in str(ex.value)
+ assert isinstance(ex.value.__context__, Rollback)
+ else:
+ assert "transaction commit" in str(ex.value)
+
+ # Start a first transaction in a thread
+ t1 = Thread(target=worker, kwargs={"unlock": evs[0], "wait_on": evs[1]})
+ t1.start()
+ evs[0].wait()
+
+ # Start a nested transaction in a thread
+ t2 = Thread(target=worker, kwargs={"unlock": evs[1], "wait_on": evs[2]})
+ t2.start()
+
+ # Terminate the first transaction before the second does
+ t1.join()
+ evs[2].set()
+ t2.join()
diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py
new file mode 100644
index 0000000..55e1c9c
--- /dev/null
+++ b/tests/test_transaction_async.py
@@ -0,0 +1,743 @@
+import asyncio
+import logging
+
+import pytest
+
+from psycopg import Rollback
+from psycopg import errors as e
+from psycopg._compat import create_task
+
+from .test_transaction import in_transaction, insert_row, inserted, get_exc_info
+from .test_transaction import ExpectedException, crdb_skip_external_observer
+from .test_transaction import create_test_table # noqa # autouse fixture
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+async def aconn(aconn, apipeline):
+ return aconn
+
+
+async def test_basic(aconn, apipeline):
+ """Basic use of transaction() to BEGIN and COMMIT a transaction."""
+ assert not in_transaction(aconn)
+ async with aconn.transaction():
+ if apipeline:
+ await apipeline.sync()
+ assert in_transaction(aconn)
+ assert not in_transaction(aconn)
+
+
+async def test_exposes_associated_connection(aconn):
+ """Transaction exposes its connection as a read-only property."""
+ async with aconn.transaction() as tx:
+ assert tx.connection is aconn
+ with pytest.raises(AttributeError):
+ tx.connection = aconn
+
+
+async def test_exposes_savepoint_name(aconn):
+ """Transaction exposes its savepoint name as a read-only property."""
+ async with aconn.transaction(savepoint_name="foo") as tx:
+ assert tx.savepoint_name == "foo"
+ with pytest.raises(AttributeError):
+ tx.savepoint_name = "bar"
+
+
+async def test_cant_reenter(aconn):
+ async with aconn.transaction() as tx:
+ pass
+
+ with pytest.raises(TypeError):
+ async with tx:
+ pass
+
+
+async def test_begins_on_enter(aconn, apipeline):
+ """Transaction does not begin until __enter__() is called."""
+ tx = aconn.transaction()
+ assert not in_transaction(aconn)
+ async with tx:
+ if apipeline:
+ await apipeline.sync()
+ assert in_transaction(aconn)
+ assert not in_transaction(aconn)
+
+
+async def test_commit_on_successful_exit(aconn):
+ """Changes are committed on successful exit from the `with` block."""
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"foo"}
+
+
+async def test_rollback_on_exception_exit(aconn):
+ """Changes are rolled back if an exception escapes the `with` block."""
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise ExpectedException("This discards the insert")
+
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_context_inerror_rollback_no_clobber(
+ aconn_cls, aconn, apipeline, dsn, caplog
+):
+ if apipeline:
+ # Only 'aconn' is possibly in pipeline mode, but the transaction and
+ # checks are on 'conn2'.
+ pytest.skip("not applicable")
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ async with await aconn_cls.connect(dsn) as conn2:
+ async with conn2.transaction():
+ await conn2.execute("select 1")
+ await aconn.execute(
+ "select pg_terminate_backend(%s::int)",
+ [conn2.pgconn.backend_pid],
+ )
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.crdb_skip("copy")
+async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ conn = await aconn_cls.connect(dsn)
+ try:
+ with pytest.raises(ZeroDivisionError):
+ async with conn.transaction():
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+ finally:
+ await conn.close()
+
+
+async def test_interaction_dbapi_transaction(aconn):
+ await insert_row(aconn, "foo")
+
+ async with aconn.transaction():
+ await insert_row(aconn, "bar")
+ raise Rollback
+
+ async with aconn.transaction():
+ await insert_row(aconn, "baz")
+
+ assert in_transaction(aconn)
+ await aconn.commit()
+ assert await inserted(aconn) == {"foo", "baz"}
+
+
+async def test_prohibits_use_of_commit_rollback_autocommit(aconn):
+ """
+ Within a Transaction block, it is forbidden to touch commit, rollback,
+ or the autocommit setting on the connection, as this would interfere
+ with the transaction scope being managed by the Transaction block.
+ """
+ await aconn.set_autocommit(False)
+ await aconn.commit()
+ await aconn.rollback()
+
+ async with aconn.transaction():
+ with pytest.raises(e.ProgrammingError):
+ await aconn.set_autocommit(False)
+ with pytest.raises(e.ProgrammingError):
+ await aconn.commit()
+ with pytest.raises(e.ProgrammingError):
+ await aconn.rollback()
+
+ await aconn.set_autocommit(False)
+ await aconn.commit()
+ await aconn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+async def test_preserves_autocommit(aconn, autocommit):
+ """
+ Connection.autocommit is unchanged both during and after Transaction block.
+ """
+ await aconn.set_autocommit(autocommit)
+ async with aconn.transaction():
+ assert aconn.autocommit is autocommit
+ assert aconn.autocommit is autocommit
+
+
+async def test_autocommit_off_but_no_tx_started_successful_exit(aconn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are committed
+ """
+ await aconn.set_autocommit(False)
+ assert not in_transaction(aconn)
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ assert not in_transaction(aconn)
+
+ # Changes committed
+ assert await inserted(aconn) == {"new"}
+ assert inserted(svcconn) == {"new"}
+
+
+async def test_autocommit_off_but_no_tx_started_exception_exit(aconn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made within Transaction context are discarded
+ """
+ await aconn.set_autocommit(False)
+ assert not in_transaction(aconn)
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+
+ # Changes discarded
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+async def test_autocommit_off_and_tx_in_progress_successful_exit(
+ aconn, apipeline, svcconn
+):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are left intact
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ await aconn.set_autocommit(False)
+ await insert_row(aconn, "prior")
+ if apipeline:
+ await apipeline.sync()
+ assert in_transaction(aconn)
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ assert in_transaction(aconn)
+ assert await inserted(aconn) == {"prior", "new"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+async def test_autocommit_off_and_tx_in_progress_exception_exit(
+ aconn, apipeline, svcconn
+):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made before the Transaction context are left intact
+ * Changes made within Transaction context are discarded
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ await aconn.set_autocommit(False)
+ await insert_row(aconn, "prior")
+ if apipeline:
+ await apipeline.sync()
+ assert in_transaction(aconn)
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ raise ExpectedException()
+ assert in_transaction(aconn)
+ assert await inserted(aconn) == {"prior"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+async def test_nested_all_changes_persisted_on_successful_exit(aconn, svcconn):
+ """Changes from nested transaction contexts are all persisted on exit."""
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ await insert_row(aconn, "outer-after")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"outer-before", "inner", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "inner", "outer-after"}
+
+
+async def test_nested_all_changes_discarded_on_outer_exception(aconn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in outer context escapes.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_nested_all_changes_discarded_on_inner_exception(aconn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in inner context escapes the outer context.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_nested_inner_scope_exception_handled_in_outer_scope(aconn, svcconn):
+ """
+ An exception escaping the inner transaction context causes changes made
+ within that inner context to be discarded, but the error can then be
+ handled in the outer context, allowing changes made in the outer context
+ (both before, and after, the inner context) to be successfully committed.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ await insert_row(aconn, "outer-after")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"outer-before", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+async def test_nested_three_levels_successful_exit(aconn, svcconn):
+ """Exercise management of more than one savepoint."""
+ async with aconn.transaction(): # BEGIN
+ await insert_row(aconn, "one")
+ async with aconn.transaction(): # SAVEPOINT s1
+ await insert_row(aconn, "two")
+ async with aconn.transaction(): # SAVEPOINT s2
+ await insert_row(aconn, "three")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"one", "two", "three"}
+ assert inserted(svcconn) == {"one", "two", "three"}
+
+
+async def test_named_savepoint_escapes_savepoint_name(aconn):
+ async with aconn.transaction("s-1"):
+ pass
+ async with aconn.transaction("s1; drop table students"):
+ pass
+
+
+async def test_named_savepoints_successful_exit(aconn, acommands):
+ """
+ Entering a transaction context will do one of these these things:
+ 1. Begin an outer transaction (if one isn't already in progress)
+ 2. Begin an outer transaction and create a savepoint (if one is named)
+ 3. Create a savepoint (if a transaction is already in progress)
+ either using the name provided, or auto-generating a savepoint name.
+
+ ...and exiting the context successfully will "commit" the same.
+ """
+ commands = acommands
+
+ # Case 1
+ # Using Transaction explicitly because conn.transaction() enters the contetx
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ["BEGIN"]
+ assert not tx.savepoint_name
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 1 (with a transaction already started)
+ await aconn.cursor().execute("select 1")
+ assert commands.popall() == ["BEGIN"]
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_1"']
+ assert tx.savepoint_name == "_pg3_1"
+
+ assert commands.popall() == ['RELEASE "_pg3_1"']
+ await aconn.rollback()
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 2
+ async with aconn.transaction(savepoint_name="foo") as tx:
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
+ assert tx.savepoint_name == "foo"
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name provided)
+ async with aconn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ async with aconn.transaction(savepoint_name="bar") as tx:
+ assert commands.popall() == ['SAVEPOINT "bar"']
+ assert tx.savepoint_name == "bar"
+ assert commands.popall() == ['RELEASE "bar"']
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name auto-generated)
+ async with aconn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_2"']
+ assert tx.savepoint_name == "_pg3_2"
+ assert commands.popall() == ['RELEASE "_pg3_2"']
+ assert commands.popall() == ["COMMIT"]
+
+
+async def test_named_savepoints_exception_exit(aconn, acommands):
+ """
+ Same as the previous test but checks that when exiting the context with an
+ exception, whatever transaction and/or savepoint was started on enter will
+ be rolled-back as appropriate.
+ """
+ commands = acommands
+
+ # Case 1
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ["BEGIN"]
+ assert not tx.savepoint_name
+ raise ExpectedException
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 2
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction(savepoint_name="foo") as tx:
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
+ assert tx.savepoint_name == "foo"
+ raise ExpectedException
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 3 (with savepoint name provided)
+ async with aconn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction(savepoint_name="bar") as tx:
+ assert commands.popall() == ['SAVEPOINT "bar"']
+ assert tx.savepoint_name == "bar"
+ raise ExpectedException
+ assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"']
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name auto-generated)
+ async with aconn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_2"']
+ assert tx.savepoint_name == "_pg3_2"
+ raise ExpectedException
+ assert commands.popall() == [
+ 'ROLLBACK TO "_pg3_2"',
+ 'RELEASE "_pg3_2"',
+ ]
+ assert commands.popall() == ["COMMIT"]
+
+
+async def test_named_savepoints_with_repeated_names_works(aconn):
+ """
+ Using the same savepoint name repeatedly works correctly, but bypasses
+ some sanity checks.
+ """
+ # Works correctly if no inner transactions are rolled back
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx3")
+ assert await inserted(aconn) == {"tx1", "tx2", "tx3"}
+
+ # Works correctly if one level of inner transaction is rolled back
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("s1", force_rollback=True):
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx3")
+ assert await inserted(aconn) == {"tx1"}
+ assert await inserted(aconn) == {"tx1"}
+
+ # Works correctly if multiple inner transactions are rolled back
+ # (This scenario mandates releasing savepoints after rolling back to them.)
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("s1") as tx2:
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx3")
+ raise Rollback(tx2)
+ assert await inserted(aconn) == {"tx1"}
+ assert await inserted(aconn) == {"tx1"}
+
+
+async def test_force_rollback_successful_exit(aconn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ async with aconn.transaction(force_rollback=True):
+ await insert_row(aconn, "foo")
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_force_rollback_exception_exit(aconn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction(force_rollback=True):
+ await insert_row(aconn, "foo")
+ raise ExpectedException()
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+async def test_explicit_rollback_discards_changes(aconn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block exits the block and
+ discards all changes made within that block.
+
+ You can raise any of the following:
+ - Rollback (type)
+ - Rollback() (instance)
+ - Rollback(tx) (instance initialised with reference to the transaction)
+ All of these are equivalent.
+ """
+
+ async def assert_no_rows():
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise Rollback
+ await assert_no_rows()
+
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise Rollback()
+ await assert_no_rows()
+
+ async with aconn.transaction() as tx:
+ await insert_row(aconn, "foo")
+ raise Rollback(tx)
+ await assert_no_rows()
+
+
+@crdb_skip_external_observer
+async def test_explicit_rollback_outer_tx_unaffected(aconn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block does not impact an
+ enclosing transaction block.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "before")
+ async with aconn.transaction():
+ await insert_row(aconn, "during")
+ raise Rollback
+ assert in_transaction(aconn)
+ assert not inserted(svcconn)
+ await insert_row(aconn, "after")
+ assert await inserted(aconn) == {"before", "after"}
+ assert inserted(svcconn) == {"before", "after"}
+
+
+async def test_explicit_rollback_of_outer_transaction(aconn):
+ """
+ Raising a Rollback exception that references an outer transaction will
+ discard all changes from both inner and outer transaction blocks.
+ """
+ async with aconn.transaction() as outer_tx:
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise Rollback(outer_tx)
+ assert False, "This line of code should be unreachable."
+ assert not await inserted(aconn)
+
+
+@crdb_skip_external_observer
+async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(aconn, svcconn):
+ """
+ Rolling-back an enclosing transaction does not impact an outer transaction.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ async with aconn.transaction() as tx_enclosing:
+ await insert_row(aconn, "enclosing")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise Rollback(tx_enclosing)
+ await insert_row(aconn, "outer-after")
+
+ assert await inserted(aconn) == {"outer-before", "outer-after"}
+ assert not inserted(svcconn) # Not yet committed
+ # Changes committed
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+async def test_str(aconn, apipeline):
+ async with aconn.transaction() as tx:
+ if apipeline:
+ assert "[INTRANS]" not in str(tx)
+ await apipeline.sync()
+ assert "[INTRANS, pipeline=ON]" in str(tx)
+ else:
+ assert "[INTRANS]" in str(tx)
+ assert "(active)" in str(tx)
+ assert "'" not in str(tx)
+ async with aconn.transaction("wat") as tx2:
+ if apipeline:
+ assert "[INTRANS, pipeline=ON]" in str(tx2)
+ else:
+ assert "[INTRANS]" in str(tx2)
+ assert "'wat'" in str(tx2)
+
+ if apipeline:
+ assert "[IDLE, pipeline=ON]" in str(tx)
+ else:
+ assert "[IDLE]" in str(tx)
+ assert "(terminated)" in str(tx)
+
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.transaction() as tx:
+ 1 / 0
+
+ assert "(terminated)" in str(tx)
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+async def test_out_of_order_exit(aconn, exit_error):
+ await aconn.set_autocommit(True)
+
+ t1 = aconn.transaction()
+ await t1.__aenter__()
+
+ t2 = aconn.transaction()
+ await t2.__aenter__()
+
+ with pytest.raises(e.ProgrammingError):
+ await t1.__aexit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ await t2.__aexit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+async def test_out_of_order_implicit_begin(aconn, exit_error):
+ await aconn.execute("select 1")
+
+ t1 = aconn.transaction()
+ await t1.__aenter__()
+
+ t2 = aconn.transaction()
+ await t2.__aenter__()
+
+ with pytest.raises(e.ProgrammingError):
+ await t1.__aexit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ await t2.__aexit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+async def test_out_of_order_exit_same_name(aconn, exit_error):
+ await aconn.set_autocommit(True)
+
+ t1 = aconn.transaction("save")
+ await t1.__aenter__()
+ t2 = aconn.transaction("save")
+ await t2.__aenter__()
+
+ with pytest.raises(e.ProgrammingError):
+ await t1.__aexit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ await t2.__aexit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("what", ["commit", "rollback", "error"])
+async def test_concurrency(aconn, what):
+ await aconn.set_autocommit(True)
+
+ evs = [asyncio.Event() for i in range(3)]
+
+ async def worker(unlock, wait_on):
+ with pytest.raises(e.ProgrammingError) as ex:
+ async with aconn.transaction():
+ unlock.set()
+ await wait_on.wait()
+ await aconn.execute("select 1")
+
+ if what == "error":
+ 1 / 0
+ elif what == "rollback":
+ raise Rollback()
+ else:
+ assert what == "commit"
+
+ if what == "error":
+ assert "transaction rollback" in str(ex.value)
+ assert isinstance(ex.value.__context__, ZeroDivisionError)
+ elif what == "rollback":
+ assert "transaction rollback" in str(ex.value)
+ assert isinstance(ex.value.__context__, Rollback)
+ else:
+ assert "transaction commit" in str(ex.value)
+
+ # Start a first transaction in a task
+ t1 = create_task(worker(unlock=evs[0], wait_on=evs[1]))
+ await evs[0].wait()
+
+ # Start a nested transaction in a task
+ t2 = create_task(worker(unlock=evs[1], wait_on=evs[2]))
+
+ # Terminate the first transaction before the second does
+ await asyncio.gather(t1)
+ evs[2].set()
+ await asyncio.gather(t2)
diff --git a/tests/test_typeinfo.py b/tests/test_typeinfo.py
new file mode 100644
index 0000000..d0e57e6
--- /dev/null
+++ b/tests/test_typeinfo.py
@@ -0,0 +1,145 @@
+import pytest
+
+import psycopg
+from psycopg import sql
+from psycopg.pq import TransactionStatus
+from psycopg.types import TypeInfo
+
+
+@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+def test_fetch(conn, name, status):
+ status = getattr(TransactionStatus, status)
+ if status == TransactionStatus.INTRANS:
+ conn.execute("select 1")
+
+ assert conn.info.transaction_status == status
+ info = TypeInfo.fetch(conn, name)
+ assert conn.info.transaction_status == status
+
+ assert info.name == "text"
+ # TODO: add the schema?
+ # assert info.schema == "pg_catalog"
+
+ assert info.oid == psycopg.adapters.types["text"].oid
+ assert info.array_oid == psycopg.adapters.types["text"].array_oid
+ assert info.regtype == "text"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+async def test_fetch_async(aconn, name, status):
+ status = getattr(TransactionStatus, status)
+ if status == TransactionStatus.INTRANS:
+ await aconn.execute("select 1")
+
+ assert aconn.info.transaction_status == status
+ info = await TypeInfo.fetch(aconn, name)
+ assert aconn.info.transaction_status == status
+
+ assert info.name == "text"
+ # assert info.schema == "pg_catalog"
+ assert info.oid == psycopg.adapters.types["text"].oid
+ assert info.array_oid == psycopg.adapters.types["text"].array_oid
+
+
+@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+def test_fetch_not_found(conn, name, status):
+ status = getattr(TransactionStatus, status)
+ if status == TransactionStatus.INTRANS:
+ conn.execute("select 1")
+
+ assert conn.info.transaction_status == status
+ info = TypeInfo.fetch(conn, name)
+ assert conn.info.transaction_status == status
+ assert info is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+async def test_fetch_not_found_async(aconn, name, status):
+ status = getattr(TransactionStatus, status)
+ if status == TransactionStatus.INTRANS:
+ await aconn.execute("select 1")
+
+ assert aconn.info.transaction_status == status
+ info = await TypeInfo.fetch(aconn, name)
+ assert aconn.info.transaction_status == status
+
+ assert info is None
+
+
+@pytest.mark.crdb_skip("composite")
+@pytest.mark.parametrize(
+ "name", ["testschema.testtype", sql.Identifier("testschema", "testtype")]
+)
+def test_fetch_by_schema_qualified_string(conn, name):
+ conn.execute("create schema if not exists testschema")
+ conn.execute("create type testschema.testtype as (foo text)")
+
+ info = TypeInfo.fetch(conn, name)
+ assert info.name == "testtype"
+ # assert info.schema == "testschema"
+ cur = conn.execute(
+ """
+ select oid, typarray from pg_type
+ where oid = 'testschema.testtype'::regtype
+ """
+ )
+ assert cur.fetchone() == (info.oid, info.array_oid)
+
+
+@pytest.mark.parametrize(
+ "name",
+ [
+ "text",
+ # TODO: support these?
+ # "pg_catalog.text",
+ # sql.Identifier("text"),
+ # sql.Identifier("pg_catalog", "text"),
+ ],
+)
+def test_registry_by_builtin_name(conn, name):
+ info = psycopg.adapters.types[name]
+ assert info.name == "text"
+ assert info.oid == 25
+
+
+def test_registry_empty():
+ r = psycopg.types.TypesRegistry()
+ assert r.get("text") is None
+ with pytest.raises(KeyError):
+ r["text"]
+
+
+@pytest.mark.parametrize("oid, aoid", [(1, 2), (1, 0), (0, 2), (0, 0)])
+def test_registry_invalid_oid(oid, aoid):
+ r = psycopg.types.TypesRegistry()
+ ti = psycopg.types.TypeInfo("test", oid, aoid)
+ r.add(ti)
+ assert r["test"] is ti
+ if oid:
+ assert r[oid] is ti
+ if aoid:
+ assert r[aoid] is ti
+ with pytest.raises(KeyError):
+ r[0]
+
+
+def test_registry_copy():
+ r = psycopg.types.TypesRegistry(psycopg.postgres.types)
+ assert r.get("text") is r["text"] is r[25]
+ assert r["text"].oid == 25
+
+
+def test_registry_isolated():
+ orig = psycopg.postgres.types
+ tinfo = orig["text"]
+ r = psycopg.types.TypesRegistry(orig)
+ tdummy = psycopg.types.TypeInfo("dummy", tinfo.oid, tinfo.array_oid)
+ r.add(tdummy)
+ assert r[25] is r["dummy"] is tdummy
+ assert orig[25] is r["text"] is tinfo
diff --git a/tests/test_typing.py b/tests/test_typing.py
new file mode 100644
index 0000000..fff9cec
--- /dev/null
+++ b/tests/test_typing.py
@@ -0,0 +1,449 @@
+import os
+
+import pytest
+
+HERE = os.path.dirname(os.path.abspath(__file__))
+
+
+@pytest.mark.parametrize(
+ "filename",
+ ["adapters_example.py", "typing_example.py"],
+)
+def test_typing_example(mypy, filename):
+ cp = mypy.run_on_file(os.path.join(HERE, filename))
+ errors = cp.stdout.decode("utf8", "replace").splitlines()
+ assert not errors
+ assert cp.returncode == 0
+
+
+@pytest.mark.parametrize(
+ "conn, type",
+ [
+ (
+ "psycopg.connect()",
+ "psycopg.Connection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.tuple_row)",
+ "psycopg.Connection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.Connection[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.namedtuple_row)",
+ "psycopg.Connection[NamedTuple]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.class_row(Thing))",
+ "psycopg.Connection[Thing]",
+ ),
+ (
+ "psycopg.connect(row_factory=thing_row)",
+ "psycopg.Connection[Thing]",
+ ),
+ (
+ "psycopg.Connection.connect()",
+ "psycopg.Connection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.Connection.connect(row_factory=rows.dict_row)",
+ "psycopg.Connection[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.AsyncConnection[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_connection_type(conn, type, mypy):
+ stmts = f"obj = {conn}"
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize(
+ "conn, curs, type",
+ [
+ (
+ "psycopg.connect()",
+ "conn.cursor()",
+ "psycopg.Cursor[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "conn.cursor()",
+ "psycopg.Cursor[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "conn.cursor(row_factory=rows.namedtuple_row)",
+ "psycopg.Cursor[NamedTuple]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.class_row(Thing))",
+ "conn.cursor()",
+ "psycopg.Cursor[Thing]",
+ ),
+ (
+ "psycopg.connect(row_factory=thing_row)",
+ "conn.cursor()",
+ "psycopg.Cursor[Thing]",
+ ),
+ (
+ "psycopg.connect()",
+ "conn.cursor(row_factory=thing_row)",
+ "psycopg.Cursor[Thing]",
+ ),
+ # Async cursors
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "conn.cursor()",
+ "psycopg.AsyncCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "conn.cursor(row_factory=thing_row)",
+ "psycopg.AsyncCursor[Thing]",
+ ),
+ # Server-side cursors
+ (
+ "psycopg.connect()",
+ "conn.cursor(name='foo')",
+ "psycopg.ServerCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "conn.cursor(name='foo')",
+ "psycopg.ServerCursor[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect()",
+ "conn.cursor(name='foo', row_factory=rows.dict_row)",
+ "psycopg.ServerCursor[Dict[str, Any]]",
+ ),
+ # Async server-side cursors
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "conn.cursor(name='foo')",
+ "psycopg.AsyncServerCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+ "conn.cursor(name='foo')",
+ "psycopg.AsyncServerCursor[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "conn.cursor(name='foo', row_factory=rows.dict_row)",
+ "psycopg.AsyncServerCursor[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_cursor_type(conn, curs, type, mypy):
+ stmts = f"""\
+conn = {conn}
+obj = {curs}
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize(
+ "conn, curs, type",
+ [
+ (
+ "psycopg.connect()",
+ "psycopg.Cursor(conn)",
+ "psycopg.Cursor[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.Cursor(conn)",
+ "psycopg.Cursor[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.Cursor(conn, row_factory=rows.namedtuple_row)",
+ "psycopg.Cursor[NamedTuple]",
+ ),
+ # Async cursors
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncCursor(conn)",
+ "psycopg.AsyncCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.AsyncCursor(conn)",
+ "psycopg.AsyncCursor[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncCursor(conn, row_factory=thing_row)",
+ "psycopg.AsyncCursor[Thing]",
+ ),
+ # Server-side cursors
+ (
+ "psycopg.connect()",
+ "psycopg.ServerCursor(conn, 'foo')",
+ "psycopg.ServerCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.ServerCursor(conn, name='foo')",
+ "psycopg.ServerCursor[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.ServerCursor(conn, 'foo', row_factory=rows.namedtuple_row)",
+ "psycopg.ServerCursor[NamedTuple]",
+ ),
+ # Async server-side cursors
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncServerCursor(conn, name='foo')",
+ "psycopg.AsyncServerCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.AsyncServerCursor(conn, name='foo')",
+ "psycopg.AsyncServerCursor[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncServerCursor(conn, name='foo', row_factory=rows.dict_row)",
+ "psycopg.AsyncServerCursor[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_cursor_type_init(conn, curs, type, mypy):
+ stmts = f"""\
+conn = {conn}
+obj = {curs}
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize(
+ "curs, type",
+ [
+ (
+ "conn.cursor()",
+ "Optional[Tuple[Any, ...]]",
+ ),
+ (
+ "conn.cursor(row_factory=rows.dict_row)",
+ "Optional[Dict[str, Any]]",
+ ),
+ (
+ "conn.cursor(row_factory=thing_row)",
+ "Optional[Thing]",
+ ),
+ ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_fetchone_type(conn_class, server_side, curs, type, mypy):
+ await_ = "await" if "Async" in conn_class else ""
+ if server_side:
+ curs = curs.replace("(", "(name='foo',", 1)
+ stmts = f"""\
+conn = {await_} psycopg.{conn_class}.connect()
+curs = {curs}
+obj = {await_} curs.fetchone()
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize(
+ "curs, type",
+ [
+ (
+ "conn.cursor()",
+ "Tuple[Any, ...]",
+ ),
+ (
+ "conn.cursor(row_factory=rows.dict_row)",
+ "Dict[str, Any]",
+ ),
+ (
+ "conn.cursor(row_factory=thing_row)",
+ "Thing",
+ ),
+ ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_iter_type(conn_class, server_side, curs, type, mypy):
+ if "Async" in conn_class:
+ async_ = "async "
+ await_ = "await "
+ else:
+ async_ = await_ = ""
+
+ if server_side:
+ curs = curs.replace("(", "(name='foo',", 1)
+ stmts = f"""\
+conn = {await_}psycopg.{conn_class}.connect()
+curs = {curs}
+{async_}for obj in curs:
+ pass
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize("method", ["fetchmany", "fetchall"])
+@pytest.mark.parametrize(
+ "curs, type",
+ [
+ (
+ "conn.cursor()",
+ "List[Tuple[Any, ...]]",
+ ),
+ (
+ "conn.cursor(row_factory=rows.dict_row)",
+ "List[Dict[str, Any]]",
+ ),
+ (
+ "conn.cursor(row_factory=thing_row)",
+ "List[Thing]",
+ ),
+ ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy):
+ await_ = "await" if "Async" in conn_class else ""
+ if server_side:
+ curs = curs.replace("(", "(name='foo',", 1)
+ stmts = f"""\
+conn = {await_} psycopg.{conn_class}.connect()
+curs = {curs}
+obj = {await_} curs.{method}()
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_cur_subclass_execute(mypy, conn_class, server_side):
+ async_ = "async " if "Async" in conn_class else ""
+ await_ = "await" if "Async" in conn_class else ""
+ cur_base_class = "".join(
+ [
+ "Async" if "Async" in conn_class else "",
+ "Server" if server_side else "",
+ "Cursor",
+ ]
+ )
+ cur_name = "'foo'" if server_side else ""
+
+ src = f"""\
+from typing import Any, cast
+import psycopg
+from psycopg.rows import Row, TupleRow
+
+class MyCursor(psycopg.{cur_base_class}[Row]):
+ pass
+
+{async_}def test() -> None:
+ conn = {await_} psycopg.{conn_class}.connect()
+
+ cur: MyCursor[TupleRow]
+ reveal_type(cur)
+
+ cur = cast(MyCursor[TupleRow], conn.cursor({cur_name}))
+ {async_}with cur as cur2:
+ reveal_type(cur2)
+ cur3 = {await_} cur2.execute("")
+ reveal_type(cur3)
+"""
+ cp = mypy.run_on_source(src)
+ out = cp.stdout.decode("utf8", "replace").splitlines()
+ assert len(out) == 3
+ types = [mypy.get_revealed(line) for line in out]
+ assert types[0] == types[1]
+ assert types[0] == types[2]
+
+
+def _test_reveal(stmts, type, mypy):
+ ignore = "" if type.startswith("Optional") else "# type: ignore[assignment]"
+ stmts = "\n".join(f" {line}" for line in stmts.splitlines())
+
+ src = f"""\
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
+from typing import Tuple, Union
+import psycopg
+from psycopg import rows
+
+class Thing:
+ def __init__(self, **kwargs: Any) -> None:
+ self.kwargs = kwargs
+
+def thing_row(
+ cur: Union[psycopg.Cursor[Any], psycopg.AsyncCursor[Any]],
+) -> Callable[[Sequence[Any]], Thing]:
+ assert cur.description
+ names = [d.name for d in cur.description]
+
+ def make_row(t: Sequence[Any]) -> Thing:
+ return Thing(**dict(zip(names, t)))
+
+ return make_row
+
+async def tmp() -> None:
+{stmts}
+ reveal_type(obj)
+
+ref: {type} = None {ignore}
+reveal_type(ref)
+"""
+ cp = mypy.run_on_source(src)
+ out = cp.stdout.decode("utf8", "replace").splitlines()
+ assert len(out) == 2, "\n".join(out)
+ got, want = [mypy.get_revealed(line) for line in out]
+ assert got == want
+
+
+@pytest.mark.xfail(reason="https://github.com/psycopg/psycopg/issues/308")
+@pytest.mark.parametrize(
+ "conn, type",
+ [
+ (
+ "MyConnection.connect()",
+ "MyConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "MyConnection.connect(row_factory=rows.tuple_row)",
+ "MyConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "MyConnection.connect(row_factory=rows.dict_row)",
+ "MyConnection[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_generic_connect(conn, type, mypy):
+ src = f"""
+from typing import Any, Dict, Tuple
+import psycopg
+from psycopg import rows
+
+class MyConnection(psycopg.Connection[rows.Row]):
+ pass
+
+obj = {conn}
+reveal_type(obj)
+
+ref: {type} = None # type: ignore[assignment]
+reveal_type(ref)
+"""
+ cp = mypy.run_on_source(src)
+ out = cp.stdout.decode("utf8", "replace").splitlines()
+ assert len(out) == 2, "\n".join(out)
+ got, want = [mypy.get_revealed(line) for line in out]
+ assert got == want
diff --git a/tests/test_waiting.py b/tests/test_waiting.py
new file mode 100644
index 0000000..63237e8
--- /dev/null
+++ b/tests/test_waiting.py
@@ -0,0 +1,159 @@
+import select # noqa: used in pytest.mark.skipif
+import socket
+import sys
+
+import pytest
+
+import psycopg
+from psycopg import waiting
+from psycopg import generators
+from psycopg.pq import ConnStatus, ExecStatus
+
+skip_if_not_linux = pytest.mark.skipif(
+ not sys.platform.startswith("linux"), reason="non-Linux platform"
+)
+
+waitfns = [
+ "wait",
+ "wait_selector",
+ pytest.param(
+ "wait_select", marks=pytest.mark.skipif("not hasattr(select, 'select')")
+ ),
+ pytest.param(
+ "wait_epoll", marks=pytest.mark.skipif("not hasattr(select, 'epoll')")
+ ),
+ pytest.param("wait_c", marks=pytest.mark.skipif("not psycopg._cmodule._psycopg")),
+]
+
+timeouts = [pytest.param({}, id="blank")]
+timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]]
+
+
+@pytest.mark.parametrize("timeout", timeouts)
+def test_wait_conn(dsn, timeout):
+ gen = generators.connect(dsn)
+ conn = waiting.wait_conn(gen, **timeout)
+ assert conn.status == ConnStatus.OK
+
+
+def test_wait_conn_bad(dsn):
+ gen = generators.connect("dbname=nosuchdb")
+ with pytest.raises(psycopg.OperationalError):
+ waiting.wait_conn(gen)
+
+
+@pytest.mark.parametrize("waitfn", waitfns)
+@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@skip_if_not_linux
+def test_wait_ready(waitfn, wait, ready):
+ waitfn = getattr(waiting, waitfn)
+
+ def gen():
+ r = yield wait
+ return r
+
+ with socket.socket() as s:
+ r = waitfn(gen(), s.fileno())
+ assert r & ready
+
+
+@pytest.mark.parametrize("waitfn", waitfns)
+@pytest.mark.parametrize("timeout", timeouts)
+def test_wait(pgconn, waitfn, timeout):
+ waitfn = getattr(waiting, waitfn)
+
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ (res,) = waitfn(gen, pgconn.socket, **timeout)
+ assert res.status == ExecStatus.TUPLES_OK
+
+
+@pytest.mark.parametrize("waitfn", waitfns)
+def test_wait_bad(pgconn, waitfn):
+ waitfn = getattr(waiting, waitfn)
+
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ waitfn(gen, pgconn.socket)
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+ "sys.platform == 'win32'", reason="win32 works ok, but FDs are mysterious"
+)
+@pytest.mark.parametrize("waitfn", waitfns)
+def test_wait_large_fd(dsn, waitfn):
+ waitfn = getattr(waiting, waitfn)
+
+ files = []
+ try:
+ try:
+ for i in range(1100):
+ files.append(open(__file__))
+ except OSError:
+ pytest.skip("can't open the number of files needed for the test")
+
+ pgconn = psycopg.pq.PGconn.connect(dsn.encode())
+ try:
+ assert pgconn.socket > 1024
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ if waitfn is waiting.wait_select:
+ with pytest.raises(ValueError):
+ waitfn(gen, pgconn.socket)
+ else:
+ (res,) = waitfn(gen, pgconn.socket)
+ assert res.status == ExecStatus.TUPLES_OK
+ finally:
+ pgconn.finish()
+ finally:
+ for f in files:
+ f.close()
+
+
+@pytest.mark.parametrize("timeout", timeouts)
+@pytest.mark.asyncio
+async def test_wait_conn_async(dsn, timeout):
+ gen = generators.connect(dsn)
+ conn = await waiting.wait_conn_async(gen, **timeout)
+ assert conn.status == ConnStatus.OK
+
+
+@pytest.mark.asyncio
+async def test_wait_conn_async_bad(dsn):
+ gen = generators.connect("dbname=nosuchdb")
+ with pytest.raises(psycopg.OperationalError):
+ await waiting.wait_conn_async(gen)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@skip_if_not_linux
+async def test_wait_ready_async(wait, ready):
+ def gen():
+ r = yield wait
+ return r
+
+ with socket.socket() as s:
+ r = await waiting.wait_async(gen(), s.fileno())
+ assert r & ready
+
+
+@pytest.mark.asyncio
+async def test_wait_async(pgconn):
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ (res,) = await waiting.wait_async(gen, pgconn.socket)
+ assert res.status == ExecStatus.TUPLES_OK
+
+
+@pytest.mark.asyncio
+async def test_wait_async_bad(pgconn):
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ socket = pgconn.socket
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ await waiting.wait_async(gen, socket)
diff --git a/tests/test_windows.py b/tests/test_windows.py
new file mode 100644
index 0000000..09e61ba
--- /dev/null
+++ b/tests/test_windows.py
@@ -0,0 +1,23 @@
+import pytest
+import asyncio
+import sys
+
+from psycopg.errors import InterfaceError
+
+
+@pytest.mark.skipif(sys.platform != "win32", reason="windows only test")
+def test_windows_error(aconn_cls, dsn):
+ loop = asyncio.ProactorEventLoop() # type: ignore[attr-defined]
+
+ async def go():
+ with pytest.raises(
+ InterfaceError,
+ match="Psycopg cannot use the 'ProactorEventLoop'",
+ ):
+ await aconn_cls.connect(dsn)
+
+ try:
+ loop.run_until_complete(go())
+ finally:
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ loop.close()
diff --git a/tests/types/__init__.py b/tests/types/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/types/__init__.py
diff --git a/tests/types/test_array.py b/tests/types/test_array.py
new file mode 100644
index 0000000..74c17a6
--- /dev/null
+++ b/tests/types/test_array.py
@@ -0,0 +1,338 @@
+from typing import List, Any
+from decimal import Decimal
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import PyFormat, Transformer, Dumper
+from psycopg.types import TypeInfo
+from psycopg._compat import prod
+from psycopg.postgres import types as builtins
+
+
+tests_str = [
+ ([[[[[["a"]]]]]], "{{{{{{a}}}}}}"),
+ ([[[[[[None]]]]]], "{{{{{{NULL}}}}}}"),
+ ([[[[[["NULL"]]]]]], '{{{{{{"NULL"}}}}}}'),
+ (["foo", "bar", "baz"], "{foo,bar,baz}"),
+ (["foo", None, "baz"], "{foo,null,baz}"),
+ (["foo", "null", "", "baz"], '{foo,"null","",baz}'),
+ (
+ [["foo", "bar"], ["baz", "qux"], ["quux", "quuux"]],
+ "{{foo,bar},{baz,qux},{quux,quuux}}",
+ ),
+ (
+ [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]],
+ r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}',
+ ),
+]
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("type", ["text", "int4"])
+def test_dump_empty_list(conn, fmt_in, type):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}::{type}[] = %s::{type}[]", ([], "{}"))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("obj, want", tests_str)
+def test_dump_list_str(conn, obj, want, fmt_in):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}::text[] = %s::text[]", (obj, want))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_empty_list_str(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::text[]", ([],))
+ assert cur.fetchone()[0] == []
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("want, obj", tests_str)
+def test_load_list_str(conn, obj, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::text[]", (obj,))
+ assert cur.fetchone()[0] == want
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_all_chars(conn, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 256):
+ c = chr(i)
+ cur.execute(f"select %{fmt_in.value}::text[]", ([c],))
+ assert cur.fetchone()[0] == [c]
+
+ a = list(map(chr, range(1, 256)))
+ a.append("\u20ac")
+ cur.execute(f"select %{fmt_in.value}::text[]", (a,))
+ assert cur.fetchone()[0] == a
+
+ s = "".join(a)
+ cur.execute(f"select %{fmt_in.value}::text[]", ([s],))
+ assert cur.fetchone()[0] == [s]
+
+
+tests_int = [
+ ([10, 20, -30], "{10,20,-30}"),
+ ([10, None, 30], "{10,null,30}"),
+ ([[10, 20], [30, 40]], "{{10,20},{30,40}}"),
+]
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("obj, want", tests_int)
+def test_dump_list_int(conn, obj, want):
+ cur = conn.cursor()
+ cur.execute("select %s::int[] = %s::int[]", (obj, want))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.parametrize(
+ "input",
+ [
+ [["a"], ["b", "c"]],
+ [["a"], []],
+ [[["a"]], ["b"]],
+ # [["a"], [["b"]]], # todo, but expensive (an isinstance per item)
+ # [True, b"a"], # TODO expensive too
+ ],
+)
+def test_bad_binary_array(input):
+ tx = Transformer()
+ with pytest.raises(psycopg.DataError):
+ tx.get_dumper(input, PyFormat.BINARY).dump(input)
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("want, obj", tests_int)
+def test_load_list_int(conn, obj, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::int[]", (obj,))
+ assert cur.fetchone()[0] == want
+
+ stmt = sql.SQL("copy (select {}::int[]) to stdout (format {})").format(
+ obj, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["int4[]"])
+ (got,) = copy.read_row()
+
+ assert got == want
+
+
+@pytest.mark.crdb_skip("composite")
+def test_array_register(conn):
+ conn.execute("create table mytype (data text)")
+ cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""")
+ res = cur.fetchone()
+ assert res[0] == "(foo)"
+ assert res[1] == "{(foo)}"
+
+ info = TypeInfo.fetch(conn, "mytype")
+ info.register(conn)
+
+ cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""")
+ res = cur.fetchone()
+ assert res[0] == "(foo)"
+ assert res[1] == ["(foo)"]
+
+
+@pytest.mark.crdb("skip", reason="aclitem")
+def test_array_of_unknown_builtin(conn):
+ user = conn.execute("select user").fetchone()[0]
+ # we cannot load this type, but we understand it is an array
+ val = f"{user}=arwdDxt/{user}"
+ cur = conn.execute(f"select '{val}'::aclitem, array['{val}']::aclitem[]")
+ res = cur.fetchone()
+ assert cur.description[0].type_code == builtins["aclitem"].oid
+ assert res[0] == val
+ assert cur.description[1].type_code == builtins["aclitem"].array_oid
+ assert res[1] == [val]
+
+
+@pytest.mark.parametrize(
+ "num, type",
+ [
+ (0, "int2"),
+ (2**15 - 1, "int2"),
+ (-(2**15), "int2"),
+ (2**15, "int4"),
+ (2**31 - 1, "int4"),
+ (-(2**31), "int4"),
+ (2**31, "int8"),
+ (2**63 - 1, "int8"),
+ (-(2**63), "int8"),
+ (2**63, "numeric"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_numbers_array(num, type, fmt_in):
+ for array in ([num], [1, num]):
+ tx = Transformer()
+ dumper = tx.get_dumper(array, fmt_in)
+ dumper.dump(array)
+ assert dumper.oid == builtins[type].array_oid
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split())
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ if wrapper is Decimal:
+ want_cls = Decimal
+ else:
+ assert wrapper.__mro__[1] in (int, float)
+ want_cls = wrapper.__mro__[1]
+
+ obj = [wrapper(1), wrapper(0), wrapper(-1), None]
+ cur = conn.cursor(binary=fmt_out)
+ got = cur.execute(f"select %{fmt_in.value}", [obj]).fetchone()[0]
+ assert got == obj
+ for i in got:
+ if i is not None:
+ assert type(i) is want_cls
+
+
+def test_mix_types(conn):
+ with pytest.raises(psycopg.DataError):
+ conn.execute("select %s", ([1, 0.5],))
+
+ with pytest.raises(psycopg.DataError):
+ conn.execute("select %s", ([1, Decimal("0.5")],))
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_empty_list_mix(conn, fmt_in):
+ objs = list(range(3))
+ conn.execute("create table testarrays (col1 bigint[], col2 bigint[])")
+ # pro tip: don't get confused with the types
+ f1, f2 = conn.execute(
+ f"insert into testarrays values (%{fmt_in.value}, %{fmt_in.value}) returning *",
+ (objs, []),
+ ).fetchone()
+ assert f1 == objs
+ assert f2 == []
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_empty_list(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table test (id serial primary key, data date[])")
+ with conn.transaction():
+ cur.execute(
+ f"insert into test (data) values (%{fmt_in.value}) returning id", ([],)
+ )
+ id = cur.fetchone()[0]
+ cur.execute("select data from test")
+ assert cur.fetchone() == ([],)
+
+ # test untyped list in a filter
+ cur.execute(f"select data from test where id = any(%{fmt_in.value})", ([id],))
+ assert cur.fetchone()
+ cur.execute(f"select data from test where id = any(%{fmt_in.value})", ([],))
+ assert not cur.fetchone()
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_empty_list_after_choice(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table test (id serial primary key, data float[])")
+ cur.executemany(
+ f"insert into test (data) values (%{fmt_in.value})", [([1.0],), ([],)]
+ )
+ cur.execute("select data from test order by id")
+ assert cur.fetchall() == [([1.0],), ([],)]
+
+
+@pytest.mark.crdb_skip("geometric types")
+def test_dump_list_no_comma_separator(conn):
+ class Box:
+ def __init__(self, x1, y1, x2, y2):
+ self.coords = (x1, y1, x2, y2)
+
+ class BoxDumper(Dumper):
+
+ format = pq.Format.TEXT
+ oid = psycopg.postgres.types["box"].oid
+
+ def dump(self, box):
+ return ("(%s,%s),(%s,%s)" % box.coords).encode()
+
+ conn.adapters.register_dumper(Box, BoxDumper)
+
+ cur = conn.execute("select (%s::box)::text", (Box(1, 2, 3, 4),))
+ got = cur.fetchone()[0]
+ assert got == "(3,4),(1,2)"
+
+ cur = conn.execute(
+ "select (%s::box[])::text", ([Box(1, 2, 3, 4), Box(5, 4, 3, 2)],)
+ )
+ got = cur.fetchone()[0]
+ assert got == "{(3,4),(1,2);(5,4),(3,2)}"
+
+
+@pytest.mark.crdb_skip("geometric types")
+def test_load_array_no_comma_separator(conn):
+ cur = conn.execute("select '{(2,2),(1,1);(5,6),(3,4)}'::box[]")
+ # Not parsed at the moment, but split ok on ; separator
+ assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"]
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_nested_array(conn, fmt_out):
+ dims = [3, 4, 5, 6]
+ a: List[Any] = list(range(prod(dims)))
+ for dim in dims[-1:0:-1]:
+ a = [a[i : i + dim] for i in range(0, len(a), dim)]
+
+ assert a[2][3][4][5] == prod(dims) - 1
+
+ sa = str(a).replace("[", "{").replace("]", "}")
+ got = conn.execute("select %s::int[][][][]", [sa], binary=fmt_out).fetchone()[0]
+ assert got == a
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize(
+ "obj, want",
+ [
+ ("'[0:1]={a,b}'::text[]", ["a", "b"]),
+ ("'[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}'::int[]", [[[1, 2, 3], [4, 5, 6]]]),
+ ],
+)
+def test_array_with_bounds(conn, obj, want, fmt_out):
+ got = conn.execute(f"select {obj}", binary=fmt_out).fetchone()[0]
+ assert got == want
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_all_chars_with_bounds(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 256):
+ c = chr(i)
+ cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", ([c],))
+ assert cur.fetchone()[0] == ["a", "b", c]
+
+ a = list(map(chr, range(1, 256)))
+ a.append("\u20ac")
+ cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", (a,))
+ assert cur.fetchone()[0] == ["a", "b"] + a
+
+ s = "".join(a)
+ cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", ([s],))
+ assert cur.fetchone()[0] == ["a", "b", s]
diff --git a/tests/types/test_bool.py b/tests/types/test_bool.py
new file mode 100644
index 0000000..edd4dad
--- /dev/null
+++ b/tests/types/test_bool.py
@@ -0,0 +1,47 @@
+import pytest
+
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import Transformer, PyFormat
+from psycopg.postgres import types as builtins
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("b", [True, False])
+def test_roundtrip_bool(conn, b, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ result = cur.execute(f"select %{fmt_in.value}", (b,)).fetchone()[0]
+ assert cur.pgresult.fformat(0) == fmt_out
+ if b is not None:
+ assert cur.pgresult.ftype(0) == builtins["bool"].oid
+ assert result is b
+
+ result = cur.execute(f"select %{fmt_in.value}", ([b],)).fetchone()[0]
+ assert cur.pgresult.fformat(0) == fmt_out
+ if b is not None:
+ assert cur.pgresult.ftype(0) == builtins["bool"].array_oid
+ assert result[0] is b
+
+
+@pytest.mark.parametrize("val", [True, False])
+def test_quote_bool(conn, val):
+
+ tx = Transformer()
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == str(val).lower().encode(
+ "ascii"
+ )
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val)))
+ assert cur.fetchone()[0] is val
+
+
+def test_quote_none(conn):
+
+ tx = Transformer()
+ assert tx.get_dumper(None, PyFormat.TEXT).quote(None) == b"NULL"
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None)))
+ assert cur.fetchone()[0] is None
diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py
new file mode 100644
index 0000000..47beecf
--- /dev/null
+++ b/tests/types/test_composite.py
@@ -0,0 +1,396 @@
+import pytest
+
+from psycopg import pq, postgres, sql
+from psycopg.adapt import PyFormat
+from psycopg.postgres import types as builtins
+from psycopg.types.range import Range
+from psycopg.types.composite import CompositeInfo, register_composite
+from psycopg.types.composite import TupleDumper, TupleBinaryDumper
+
+from ..utils import eur
+from ..fix_crdb import is_crdb, crdb_skip_message
+
+
+pytestmark = pytest.mark.crdb_skip("composite")
+
+tests_str = [
+ ("", ()),
+ # Funnily enough there's no way to represent (None,) in Postgres
+ ("null", ()),
+ ("null,null", (None, None)),
+ ("null, ''", (None, "")),
+ (
+ "42,'foo','ba,r','ba''z','qu\"x'",
+ ("42", "foo", "ba,r", "ba'z", 'qu"x'),
+ ),
+ ("'foo''', '''foo', '\"bar', 'bar\"' ", ("foo'", "'foo", '"bar', 'bar"')),
+]
+
+
+@pytest.mark.parametrize("rec, want", tests_str)
+def test_load_record(conn, want, rec):
+ cur = conn.cursor()
+ res = cur.execute(f"select row({rec})").fetchone()[0]
+ assert res == want
+
+
+@pytest.mark.parametrize("rec, obj", tests_str)
+def test_dump_tuple(conn, rec, obj):
+ cur = conn.cursor()
+ fields = [f"f{i} text" for i in range(len(obj))]
+ cur.execute(
+ f"""
+ drop type if exists tmptype;
+ create type tmptype as ({', '.join(fields)});
+ """
+ )
+ info = CompositeInfo.fetch(conn, "tmptype")
+ register_composite(info, conn)
+
+ res = conn.execute("select %s::tmptype", [obj]).fetchone()[0]
+ assert res == obj
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_all_chars(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 256):
+ res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0]
+ assert res == (chr(i),)
+
+ cur.execute("select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256)))
+ res = cur.fetchone()[0]
+ assert res == tuple(map(chr, range(1, 256)))
+
+ s = "".join(map(chr, range(1, 256)))
+ res = cur.execute("select row(%s::text)", [s]).fetchone()[0]
+ assert res == (s,)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty_range(conn, fmt_in):
+ conn.execute(
+ """
+ drop type if exists tmptype;
+ create type tmptype as (num integer, range daterange, nums integer[])
+ """
+ )
+ info = CompositeInfo.fetch(conn, "tmptype")
+ register_composite(info, conn)
+
+ cur = conn.execute(
+ f"select pg_typeof(%{fmt_in.value})",
+ [info.python_type(10, Range(empty=True), [])],
+ )
+ assert cur.fetchone()[0] == "tmptype"
+
+
+@pytest.mark.parametrize(
+ "rec, want",
+ [
+ ("", ()),
+ ("null", (None,)), # Unlike text format, this is a thing
+ ("null,null", (None, None)),
+ ("null, ''", (None, b"")),
+ (
+ "42,'foo','ba,r','ba''z','qu\"x'",
+ (42, b"foo", b"ba,r", b"ba'z", b'qu"x'),
+ ),
+ (
+ "'foo''', '''foo', '\"bar', 'bar\"' ",
+ (b"foo'", b"'foo", b'"bar', b'bar"'),
+ ),
+ (
+ "10::int, null::text, 20::float, null::text, 'foo'::text, 'bar'::bytea ",
+ (10, None, 20.0, None, "foo", b"bar"),
+ ),
+ ],
+)
+def test_load_record_binary(conn, want, rec):
+ cur = conn.cursor(binary=True)
+ res = cur.execute(f"select row({rec})").fetchone()[0]
+ assert res == want
+ for o1, o2 in zip(res, want):
+ assert type(o1) is type(o2)
+
+
+@pytest.fixture(scope="session")
+def testcomp(svcconn):
+ if is_crdb(svcconn):
+ pytest.skip(crdb_skip_message("composite"))
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ create schema if not exists testschema;
+
+ drop type if exists testcomp cascade;
+ drop type if exists testschema.testcomp cascade;
+
+ create type testcomp as (foo text, bar int8, baz float8);
+ create type testschema.testcomp as (foo text, bar int8, qux bool);
+ """
+ )
+ return CompositeInfo.fetch(svcconn, "testcomp")
+
+
+fetch_cases = [
+ (
+ "testcomp",
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ "testschema.testcomp",
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+ (
+ sql.Identifier("testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ sql.Identifier("testschema", "testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+]
+
+
+@pytest.mark.parametrize("name, fields", fetch_cases)
+def test_fetch_info(conn, testcomp, name, fields):
+ info = CompositeInfo.fetch(conn, name)
+ assert info.name == "testcomp"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.field_names) == 3
+ assert len(info.field_types) == 3
+ for i, (name, t) in enumerate(fields):
+ assert info.field_names[i] == name
+ assert info.field_types[i] == builtins[t].oid
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, fields", fetch_cases)
+async def test_fetch_info_async(aconn, testcomp, name, fields):
+ info = await CompositeInfo.fetch(aconn, name)
+ assert info.name == "testcomp"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.field_names) == 3
+ assert len(info.field_types) == 3
+ for i, (name, t) in enumerate(fields):
+ assert info.field_names[i] == name
+ assert info.field_types[i] == builtins[t].oid
+
+
+@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT])
+def test_dump_tuple_all_chars(conn, fmt_in, testcomp):
+ cur = conn.cursor()
+ for i in range(1, 256):
+ (res,) = cur.execute(
+ f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in.value}::testcomp",
+ (i, (chr(i), 1, 1.0)),
+ ).fetchone()
+ assert res is True
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_composite_all_chars(conn, fmt_in, testcomp):
+ cur = conn.cursor()
+ register_composite(testcomp, cur)
+ factory = testcomp.python_type
+ for i in range(1, 256):
+ obj = factory(chr(i), 1, 1.0)
+ (res,) = cur.execute(
+ f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in.value}", (i, obj)
+ ).fetchone()
+ assert res is True
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_composite_null(conn, fmt_in, testcomp):
+ cur = conn.cursor()
+ register_composite(testcomp, cur)
+ factory = testcomp.python_type
+
+ obj = factory("foo", 1, None)
+ rec = cur.execute(
+ f"""
+ select row('foo', 1, NULL)::testcomp = %(obj){fmt_in.value},
+ %(obj){fmt_in.value}::text
+ """,
+ {"obj": obj},
+ ).fetchone()
+ assert rec[0] is True, rec[1]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_composite(conn, testcomp, fmt_out):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ register_composite(info, conn)
+
+ cur = conn.cursor(binary=fmt_out)
+ res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+ assert res.foo == "hello"
+ assert res.bar == 10
+ assert res.baz == 20.0
+ assert isinstance(res.baz, float)
+
+ res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
+ assert len(res) == 1
+ assert res[0].baz == 30.0
+ assert isinstance(res[0].baz, float)
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_composite_factory(conn, testcomp, fmt_out):
+ info = CompositeInfo.fetch(conn, "testcomp")
+
+ class MyThing:
+ def __init__(self, *args):
+ self.foo, self.bar, self.baz = args
+
+ register_composite(info, conn, factory=MyThing)
+ assert info.python_type is MyThing
+
+ cur = conn.cursor(binary=fmt_out)
+ res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+ assert isinstance(res, MyThing)
+ assert res.baz == 20.0
+ assert isinstance(res.baz, float)
+
+ res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
+ assert len(res) == 1
+ assert res[0].baz == 30.0
+ assert isinstance(res[0].baz, float)
+
+
+def test_register_scope(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ register_composite(info)
+ for fmt in pq.Format:
+ for oid in (info.oid, info.array_oid):
+ assert postgres.adapters._loaders[fmt].pop(oid)
+
+ for f in PyFormat:
+ assert postgres.adapters._dumpers[f].pop(info.python_type)
+
+ cur = conn.cursor()
+ register_composite(info, cur)
+ for fmt in pq.Format:
+ for oid in (info.oid, info.array_oid):
+ assert oid not in postgres.adapters._loaders[fmt]
+ assert oid not in conn.adapters._loaders[fmt]
+ assert oid in cur.adapters._loaders[fmt]
+
+ register_composite(info, conn)
+ for fmt in pq.Format:
+ for oid in (info.oid, info.array_oid):
+ assert oid not in postgres.adapters._loaders[fmt]
+ assert oid in conn.adapters._loaders[fmt]
+
+
+def test_type_dumper_registered(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ register_composite(info, conn)
+ assert issubclass(info.python_type, tuple)
+ assert info.python_type.__name__ == "testcomp"
+ d = conn.adapters.get_dumper(info.python_type, "s")
+ assert issubclass(d, TupleDumper)
+ assert d is not TupleDumper
+
+ tc = info.python_type("foo", 42, 3.14)
+ cur = conn.execute("select pg_typeof(%s)", [tc])
+ assert cur.fetchone()[0] == "testcomp"
+
+
+def test_type_dumper_registered_binary(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ register_composite(info, conn)
+ assert issubclass(info.python_type, tuple)
+ assert info.python_type.__name__ == "testcomp"
+ d = conn.adapters.get_dumper(info.python_type, "b")
+ assert issubclass(d, TupleBinaryDumper)
+ assert d is not TupleBinaryDumper
+
+ tc = info.python_type("foo", 42, 3.14)
+ cur = conn.execute("select pg_typeof(%b)", [tc])
+ assert cur.fetchone()[0] == "testcomp"
+
+
+def test_callable_dumper_not_registered(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+
+ def fac(*args):
+ return args + (args[-1],)
+
+ register_composite(info, conn, factory=fac)
+ assert info.python_type is None
+
+ # but the loader is registered
+ cur = conn.execute("select '(foo,42,3.14)'::testcomp")
+ assert cur.fetchone()[0] == ("foo", 42, 3.14, 3.14)
+
+
+def test_no_info_error(conn):
+ with pytest.raises(TypeError, match="composite"):
+ register_composite(None, conn) # type: ignore[arg-type]
+
+
+def test_invalid_fields_names(conn):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(
+ f"""
+ create type "a-b" as ("c-d" text, "{eur}" int);
+ create type "-x-{eur}" as ("w-ww" "a-b", "0" int);
+ """
+ )
+ ab = CompositeInfo.fetch(conn, '"a-b"')
+ x = CompositeInfo.fetch(conn, f'"-x-{eur}"')
+ register_composite(ab, conn)
+ register_composite(x, conn)
+ obj = x.python_type(ab.python_type("foo", 10), 20)
+ conn.execute(f"""create table meh (wat "-x-{eur}")""")
+ conn.execute("insert into meh values (%s)", [obj])
+ got = conn.execute("select wat from meh").fetchone()[0]
+ assert obj == got
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "1", "'"])
+def test_literal_invalid_name(conn, name):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(
+ sql.SQL("create type {name} as (foo text)").format(name=sql.Identifier(name))
+ )
+ info = CompositeInfo.fetch(conn, sql.Identifier(name).as_string(conn))
+ register_composite(info, conn)
+ obj = info.python_type("hello")
+ assert sql.Literal(obj).as_string(conn) == f"'(hello)'::\"{name}\""
+ cur = conn.execute(sql.SQL("select {}").format(obj))
+ got = cur.fetchone()[0]
+ assert got == obj
+ assert type(got) is type(obj)
+
+
+@pytest.mark.parametrize(
+ "name, attr",
+ [
+ ("a-b", "a_b"),
+ (f"{eur}", "f_"),
+ ("üåäö", "üåäö"),
+ ("order", "order"),
+ ("1", "f1"),
+ ],
+)
+def test_literal_invalid_attr(conn, name, attr):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(
+ sql.SQL("create type test_attr as ({name} text)").format(
+ name=sql.Identifier(name)
+ )
+ )
+ info = CompositeInfo.fetch(conn, "test_attr")
+ register_composite(info, conn)
+ obj = info.python_type("hello")
+ assert getattr(obj, attr) == "hello"
+ cur = conn.execute(sql.SQL("select {}").format(obj))
+ got = cur.fetchone()[0]
+ assert got == obj
+ assert type(got) is type(obj)
diff --git a/tests/types/test_datetime.py b/tests/types/test_datetime.py
new file mode 100644
index 0000000..11fe493
--- /dev/null
+++ b/tests/types/test_datetime.py
@@ -0,0 +1,813 @@
+import datetime as dt
+
+import pytest
+
+from psycopg import DataError, pq, sql
+from psycopg.adapt import PyFormat
+
+crdb_skip_datestyle = pytest.mark.crdb("skip", reason="set datestyle/intervalstyle")
+crdb_skip_negative_interval = pytest.mark.crdb("skip", reason="negative interval")
+crdb_skip_invalid_tz = pytest.mark.crdb(
+ "skip", reason="crdb doesn't allow invalid timezones"
+)
+
+datestyles_in = [
+ pytest.param(datestyle, marks=crdb_skip_datestyle)
+ for datestyle in ["DMY", "MDY", "YMD"]
+]
+datestyles_out = [
+ pytest.param(datestyle, marks=crdb_skip_datestyle)
+ for datestyle in ["ISO", "Postgres", "SQL", "German"]
+]
+
+intervalstyles = [
+ pytest.param(datestyle, marks=crdb_skip_datestyle)
+ for datestyle in ["sql_standard", "postgres", "postgres_verbose", "iso_8601"]
+]
+
+
+class TestDate:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "0001-01-01"),
+ ("1000,1,1", "1000-01-01"),
+ ("2000,1,1", "2000-01-01"),
+ ("2000,12,31", "2000-12-31"),
+ ("3000,1,1", "3000-01-01"),
+ ("max", "9999-12-31"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_date(self, conn, val, expr, fmt_in):
+ val = as_date(val)
+ cur = conn.cursor()
+ cur.execute(f"select '{expr}'::date = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(
+ sql.SQL("select {}::date = {}").format(
+ sql.Literal(val), sql.Placeholder(format=fmt_in)
+ ),
+ (val,),
+ )
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_dump_date_datestyle(self, conn, datestyle_in):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = ISO,{datestyle_in}")
+ cur.execute("select 'epoch'::date + 1 = %t", (dt.date(1970, 1, 2),))
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "0001-01-01"),
+ ("1000,1,1", "1000-01-01"),
+ ("2000,1,1", "2000-01-01"),
+ ("2000,12,31", "2000-12-31"),
+ ("3000,1,1", "3000-01-01"),
+ ("max", "9999-12-31"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_date(self, conn, val, expr, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select '{expr}'::date")
+ assert cur.fetchone()[0] == as_date(val)
+
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ def test_load_date_datestyle(self, conn, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select '2000-01-02'::date")
+ assert cur.fetchone()[0] == dt.date(2000, 1, 2)
+
+ @pytest.mark.parametrize("val", ["min", "max"])
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ def test_load_date_overflow(self, conn, val, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select %t + %s::int", (as_date(val), -1 if val == "min" else 1))
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ @pytest.mark.parametrize("val", ["min", "max"])
+ def test_load_date_overflow_binary(self, conn, val):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s + %s::int", (as_date(val), -1 if val == "min" else 1))
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ overflow_samples = [
+ ("-infinity", "date too small"),
+ ("1000-01-01 BC", "date too small"),
+ ("10000-01-01", "date too large"),
+ ("infinity", "date too large"),
+ ]
+
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_load_overflow_message(self, conn, datestyle_out, val, msg):
+ cur = conn.cursor()
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select %s::date", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_load_overflow_message_binary(self, conn, val, msg):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s::date", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ def test_infinity_date_example(self, conn):
+ # NOTE: this is an example in the docs. Make sure it doesn't regress when
+ # adding binary datetime adapters
+ from datetime import date
+ from psycopg.types.datetime import DateLoader, DateDumper
+
+ class InfDateDumper(DateDumper):
+ def dump(self, obj):
+ if obj == date.max:
+ return b"infinity"
+ else:
+ return super().dump(obj)
+
+ class InfDateLoader(DateLoader):
+ def load(self, data):
+ if data == b"infinity":
+ return date.max
+ else:
+ return super().load(data)
+
+ cur = conn.cursor()
+ cur.adapters.register_dumper(date, InfDateDumper)
+ cur.adapters.register_loader("date", InfDateLoader)
+
+ rec = cur.execute(
+ "SELECT %s::text, %s::text", [date(2020, 12, 31), date.max]
+ ).fetchone()
+ assert rec == ("2020-12-31", "infinity")
+ rec = cur.execute("select '2020-12-31'::date, 'infinity'::date").fetchone()
+ assert rec == (date(2020, 12, 31), date(9999, 12, 31))
+
+
+class TestDatetime:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "0001-01-01 00:00"),
+ ("258,1,8,1,12,32,358261", "0258-1-8 1:12:32.358261"),
+ ("1000,1,1,0,0", "1000-01-01 00:00"),
+ ("2000,1,1,0,0", "2000-01-01 00:00"),
+ ("2000,1,2,3,4,5,6", "2000-01-02 03:04:05.000006"),
+ ("2000,1,2,3,4,5,678", "2000-01-02 03:04:05.000678"),
+ ("2000,1,2,3,0,0,456789", "2000-01-02 03:00:00.456789"),
+ ("2000,1,1,0,0,0,1", "2000-01-01 00:00:00.000001"),
+ ("2034,02,03,23,34,27,951357", "2034-02-03 23:34:27.951357"),
+ ("2200,1,1,0,0,0,1", "2200-01-01 00:00:00.000001"),
+ ("2300,1,1,0,0,0,1", "2300-01-01 00:00:00.000001"),
+ ("7000,1,1,0,0,0,1", "7000-01-01 00:00:00.000001"),
+ ("max", "9999-12-31 23:59:59.999999"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_datetime(self, conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute("set timezone to '+02:00'")
+ cur.execute(f"select %{fmt_in.value}", (as_dt(val),))
+ cur.execute(f"select '{expr}'::timestamp = %{fmt_in.value}", (as_dt(val),))
+ cur.execute(
+ f"""
+ select '{expr}'::timestamp = %(val){fmt_in.value},
+ '{expr}', %(val){fmt_in.value}::text
+ """,
+ {"val": as_dt(val)},
+ )
+ ok, want, got = cur.fetchone()
+ assert ok, (want, got)
+
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_dump_datetime_datestyle(self, conn, datestyle_in):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = ISO, {datestyle_in}")
+ cur.execute(
+ "select 'epoch'::timestamp + '1d 3h 4m 5s'::interval = %t",
+ (dt.datetime(1970, 1, 2, 3, 4, 5),),
+ )
+ assert cur.fetchone()[0] is True
+
+ load_datetime_samples = [
+ ("min", "0001-01-01"),
+ ("1000,1,1", "1000-01-01"),
+ ("2000,1,1", "2000-01-01"),
+ ("2000,1,2,3,4,5,6", "2000-01-02 03:04:05.000006"),
+ ("2000,1,2,3,4,5,678", "2000-01-02 03:04:05.000678"),
+ ("2000,1,2,3,0,0,456789", "2000-01-02 03:00:00.456789"),
+ ("2000,12,31", "2000-12-31"),
+ ("3000,1,1", "3000-01-01"),
+ ("max", "9999-12-31 23:59:59.999999"),
+ ]
+
+ @pytest.mark.parametrize("val, expr", load_datetime_samples)
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_load_datetime(self, conn, val, expr, datestyle_in, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
+ cur.execute("set timezone to '+02:00'")
+ cur.execute(f"select '{expr}'::timestamp")
+ assert cur.fetchone()[0] == as_dt(val)
+
+ @pytest.mark.parametrize("val, expr", load_datetime_samples)
+ def test_load_datetime_binary(self, conn, val, expr):
+ cur = conn.cursor(binary=True)
+ cur.execute("set timezone to '+02:00'")
+ cur.execute(f"select '{expr}'::timestamp")
+ assert cur.fetchone()[0] == as_dt(val)
+
+ @pytest.mark.parametrize("val", ["min", "max"])
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ def test_load_datetime_overflow(self, conn, val, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute(
+ "select %t::timestamp + %s * '1s'::interval",
+ (as_dt(val), -1 if val == "min" else 1),
+ )
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ @pytest.mark.parametrize("val", ["min", "max"])
+ def test_load_datetime_overflow_binary(self, conn, val):
+ cur = conn.cursor(binary=True)
+ cur.execute(
+ "select %t::timestamp + %s * '1s'::interval",
+ (as_dt(val), -1 if val == "min" else 1),
+ )
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ overflow_samples = [
+ ("-infinity", "timestamp too small"),
+ ("1000-01-01 12:00 BC", "timestamp too small"),
+ ("10000-01-01 12:00", "timestamp too large"),
+ ("infinity", "timestamp too large"),
+ ]
+
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_overflow_message(self, conn, datestyle_out, val, msg):
+ cur = conn.cursor()
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select %s::timestamp", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_overflow_message_binary(self, conn, val, msg):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s::timestamp", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ @crdb_skip_datestyle
+ def test_load_all_month_names(self, conn):
+ cur = conn.cursor(binary=False)
+ cur.execute("set datestyle = 'Postgres'")
+ for i in range(12):
+ d = dt.datetime(2000, i + 1, 15)
+ cur.execute("select %s", [d])
+ assert cur.fetchone()[0] == d
+
+
+class TestDateTimeTz:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min~-2", "0001-01-01 00:00-02:00"),
+ ("min~-12", "0001-01-01 00:00-12:00"),
+ (
+ "258,1,8,1,12,32,358261~1:2:3",
+ "0258-1-8 1:12:32.358261+01:02:03",
+ ),
+ ("1000,1,1,0,0~2", "1000-01-01 00:00+2"),
+ ("2000,1,1,0,0~2", "2000-01-01 00:00+2"),
+ ("2000,1,1,0,0~12", "2000-01-01 00:00+12"),
+ ("2000,1,1,0,0~-12", "2000-01-01 00:00-12"),
+ ("2000,1,1,0,0~01:02:03", "2000-01-01 00:00+01:02:03"),
+ ("2000,1,1,0,0~-01:02:03", "2000-01-01 00:00-01:02:03"),
+ ("2000,12,31,23,59,59,999999~2", "2000-12-31 23:59:59.999999+2"),
+ (
+ "2034,02,03,23,34,27,951357~-4:27",
+ "2034-02-03 23:34:27.951357-04:27",
+ ),
+ ("2300,1,1,0,0,0,1~1", "2300-01-01 00:00:00.000001+1"),
+ ("3000,1,1,0,0~2", "3000-01-01 00:00+2"),
+ ("7000,1,1,0,0,0,1~-1:2:3", "7000-01-01 00:00:00.000001-01:02:03"),
+ ("max~2", "9999-12-31 23:59:59.999999"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_datetimetz(self, conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute("set timezone to '-02:00'")
+ cur.execute(
+ f"""
+ select '{expr}'::timestamptz = %(val){fmt_in.value},
+ '{expr}', %(val){fmt_in.value}::text
+ """,
+ {"val": as_dt(val)},
+ )
+ ok, want, got = cur.fetchone()
+ assert ok, (want, got)
+
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_dump_datetimetz_datestyle(self, conn, datestyle_in):
+ tzinfo = dt.timezone(dt.timedelta(hours=2))
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = ISO, {datestyle_in}")
+ cur.execute("set timezone to '-02:00'")
+ cur.execute(
+ "select 'epoch'::timestamptz + '1d 3h 4m 5.678s'::interval = %t",
+ (dt.datetime(1970, 1, 2, 5, 4, 5, 678000, tzinfo=tzinfo),),
+ )
+ assert cur.fetchone()[0] is True
+
+ load_datetimetz_samples = [
+ ("2000,1,1~2", "2000-01-01", "-02:00"),
+ ("2000,1,2,3,4,5,6~2", "2000-01-02 03:04:05.000006", "-02:00"),
+ ("2000,1,2,3,4,5,678~1", "2000-01-02 03:04:05.000678", "Europe/Rome"),
+ ("2000,7,2,3,4,5,678~2", "2000-07-02 03:04:05.000678", "Europe/Rome"),
+ ("2000,1,2,3,0,0,456789~2", "2000-01-02 03:00:00.456789", "-02:00"),
+ ("2000,1,2,3,0,0,456789~-2", "2000-01-02 03:00:00.456789", "+02:00"),
+ ("2000,12,31~2", "2000-12-31", "-02:00"),
+ ("1900,1,1~05:21:10", "1900-01-01", "Asia/Calcutta"),
+ ]
+
+ @crdb_skip_datestyle
+ @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples)
+ @pytest.mark.parametrize("datestyle_out", ["ISO"])
+ def test_load_datetimetz(self, conn, val, expr, timezone, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, DMY")
+ cur.execute(f"set timezone to '{timezone}'")
+ got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0]
+ assert got == as_dt(val)
+
+ @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples)
+ def test_load_datetimetz_binary(self, conn, val, expr, timezone):
+ cur = conn.cursor(binary=True)
+ cur.execute(f"set timezone to '{timezone}'")
+ got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0]
+ assert got == as_dt(val)
+
+ @pytest.mark.xfail # parse timezone names
+ @crdb_skip_datestyle
+ @pytest.mark.parametrize("val, expr", [("2000,1,1~2", "2000-01-01")])
+ @pytest.mark.parametrize("datestyle_out", ["SQL", "Postgres", "German"])
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_load_datetimetz_tzname(self, conn, val, expr, datestyle_in, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
+ cur.execute("set timezone to '-02:00'")
+ cur.execute(f"select '{expr}'::timestamptz")
+ assert cur.fetchone()[0] == as_dt(val)
+
+ @pytest.mark.parametrize(
+ "tzname, expr, tzoff",
+ [
+ ("UTC", "2000-1-1", 0),
+ ("UTC", "2000-7-1", 0),
+ ("Europe/Rome", "2000-1-1", 3600),
+ ("Europe/Rome", "2000-7-1", 7200),
+ ("Europe/Rome", "1000-1-1", 2996),
+ pytest.param("NOSUCH0", "2000-1-1", 0, marks=crdb_skip_invalid_tz),
+ pytest.param("0", "2000-1-1", 0, marks=crdb_skip_invalid_tz),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_datetimetz_tz(self, conn, fmt_out, tzname, expr, tzoff):
+ conn.execute("select set_config('TimeZone', %s, true)", [tzname])
+ cur = conn.cursor(binary=fmt_out)
+ ts = cur.execute("select %s::timestamptz", [expr]).fetchone()[0]
+ assert ts.utcoffset().total_seconds() == tzoff
+
+ @pytest.mark.parametrize(
+ "val, type",
+ [
+ ("2000,1,2,3,4,5,6", "timestamp"),
+ ("2000,1,2,3,4,5,6~0", "timestamptz"),
+ ("2000,1,2,3,4,5,6~2", "timestamptz"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_datetime_tz_or_not_tz(self, conn, val, type, fmt_in):
+ val = as_dt(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ select pg_typeof(%{fmt_in.value})::regtype = %s::regtype, %{fmt_in.value}
+ """,
+ [val, type, val],
+ )
+ rec = cur.fetchone()
+ assert rec[0] is True, type
+ assert rec[1] == val
+
+ @pytest.mark.crdb_skip("copy")
+ def test_load_copy(self, conn):
+ cur = conn.cursor(binary=False)
+ with cur.copy(
+ """
+ copy (
+ select
+ '2000-01-01 01:02:03.123456-10:20'::timestamptz,
+ '11111111'::int4
+ ) to stdout
+ """
+ ) as copy:
+ copy.set_types(["timestamptz", "int4"])
+ rec = copy.read_row()
+
+ tz = dt.timezone(-dt.timedelta(hours=10, minutes=20))
+ want = dt.datetime(2000, 1, 1, 1, 2, 3, 123456, tzinfo=tz)
+ assert rec[0] == want
+ assert rec[1] == 11111111
+
+ overflow_samples = [
+ ("-infinity", "timestamp too small"),
+ ("1000-01-01 12:00+00 BC", "timestamp too small"),
+ ("10000-01-01 12:00+00", "timestamp too large"),
+ ("infinity", "timestamp too large"),
+ ]
+
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_overflow_message(self, conn, datestyle_out, val, msg):
+ cur = conn.cursor()
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select %s::timestamptz", (val,))
+ if datestyle_out == "ISO":
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+ else:
+ with pytest.raises(NotImplementedError):
+ cur.fetchone()[0]
+
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_overflow_message_binary(self, conn, val, msg):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s::timestamptz", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ @pytest.mark.parametrize(
+ "valname, tzval, tzname",
+ [
+ ("max", "-06", "America/Chicago"),
+ ("min", "+09:18:59", "Asia/Tokyo"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_max_with_timezone(self, conn, fmt_out, valname, tzval, tzname):
+ # This happens e.g. in Django when it caches forever.
+ # e.g. see Django test cache.tests.DBCacheTests.test_forever_timeout
+ val = getattr(dt.datetime, valname).replace(microsecond=0)
+ tz = dt.timezone(as_tzoffset(tzval))
+ want = val.replace(tzinfo=tz)
+
+ conn.execute("set timezone to '%s'" % tzname)
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::timestamptz", [str(val) + tzval])
+ got = cur.fetchone()[0]
+
+ assert got == want
+
+ extra = "1 day" if valname == "max" else "-1 day"
+ with pytest.raises(DataError):
+ cur.execute(
+ "select %s::timestamptz + %s::interval",
+ [str(val) + tzval, extra],
+ )
+ got = cur.fetchone()[0]
+
+
+class TestTime:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "00:00"),
+ ("10,20,30,40", "10:20:30.000040"),
+ ("max", "23:59:59.999999"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_time(self, conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ select '{expr}'::time = %(val){fmt_in.value},
+ '{expr}'::time::text, %(val){fmt_in.value}::text
+ """,
+ {"val": as_time(val)},
+ )
+ ok, want, got = cur.fetchone()
+ assert ok, (got, want)
+
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "00:00"),
+ ("1,2", "01:02"),
+ ("10,20", "10:20"),
+ ("10,20,30", "10:20:30"),
+ ("10,20,30,40", "10:20:30.000040"),
+ ("max", "23:59:59.999999"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_time(self, conn, val, expr, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select '{expr}'::time")
+ assert cur.fetchone()[0] == as_time(val)
+
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_time_24(self, conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select '24:00'::time")
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+
+class TestTimeTz:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min~-10", "00:00-10:00"),
+ ("min~+12", "00:00+12:00"),
+ ("10,20,30,40~-2", "10:20:30.000040-02:00"),
+ ("10,20,30,40~0", "10:20:30.000040Z"),
+ ("10,20,30,40~+2:30", "10:20:30.000040+02:30"),
+ ("max~-12", "23:59:59.999999-12:00"),
+ ("max~+12", "23:59:59.999999+12:00"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_timetz(self, conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute("set timezone to '-02:00'")
+ cur.execute(f"select '{expr}'::timetz = %{fmt_in.value}", (as_time(val),))
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize(
+ "val, expr, timezone",
+ [
+ ("0,0~-12", "00:00", "12:00"),
+ ("0,0~12", "00:00", "-12:00"),
+ ("3,4,5,6~2", "03:04:05.000006", "-02:00"),
+ ("3,4,5,6~7:8", "03:04:05.000006", "-07:08"),
+ ("3,0,0,456789~2", "03:00:00.456789", "-02:00"),
+ ("3,0,0,456789~-2", "03:00:00.456789", "+02:00"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_timetz(self, conn, val, timezone, expr, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"set timezone to '{timezone}'")
+ cur.execute(f"select '{expr}'::timetz")
+ assert cur.fetchone()[0] == as_time(val)
+
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_timetz_24(self, conn, fmt_out):
+ cur = conn.cursor()
+ cur.execute("select '24:00'::timetz")
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ @pytest.mark.parametrize(
+ "val, type",
+ [
+ ("3,4,5,6", "time"),
+ ("3,4,5,6~0", "timetz"),
+ ("3,4,5,6~2", "timetz"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_time_tz_or_not_tz(self, conn, val, type, fmt_in):
+ val = as_time(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ select pg_typeof(%{fmt_in.value})::regtype = %s::regtype, %{fmt_in.value}
+ """,
+ [val, type, val],
+ )
+ rec = cur.fetchone()
+ assert rec[0] is True, type
+ assert rec[1] == val
+
+ @pytest.mark.crdb_skip("copy")
+ def test_load_copy(self, conn):
+ cur = conn.cursor(binary=False)
+ with cur.copy(
+ """
+ copy (
+ select
+ '01:02:03.123456-10:20'::timetz,
+ '11111111'::int4
+ ) to stdout
+ """
+ ) as copy:
+ copy.set_types(["timetz", "int4"])
+ rec = copy.read_row()
+
+ tz = dt.timezone(-dt.timedelta(hours=10, minutes=20))
+ want = dt.time(1, 2, 3, 123456, tzinfo=tz)
+ assert rec[0] == want
+ assert rec[1] == 11111111
+
+
+class TestInterval:
+ dump_timedelta_samples = [
+ ("min", "-999999999 days"),
+ ("1d", "1 day"),
+ pytest.param("-1d", "-1 day", marks=crdb_skip_negative_interval),
+ ("1s", "1 s"),
+ pytest.param("-1s", "-1 s", marks=crdb_skip_negative_interval),
+ pytest.param("-1m", "-0.000001 s", marks=crdb_skip_negative_interval),
+ ("1m", "0.000001 s"),
+ ("max", "999999999 days 23:59:59.999999"),
+ ]
+
+ @pytest.mark.parametrize("val, expr", dump_timedelta_samples)
+ @pytest.mark.parametrize("intervalstyle", intervalstyles)
+ def test_dump_interval(self, conn, val, expr, intervalstyle):
+ cur = conn.cursor()
+ cur.execute(f"set IntervalStyle to '{intervalstyle}'")
+ cur.execute(f"select '{expr}'::interval = %t", (as_td(val),))
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize("val, expr", dump_timedelta_samples)
+ def test_dump_interval_binary(self, conn, val, expr):
+ cur = conn.cursor()
+ cur.execute(f"select '{expr}'::interval = %b", (as_td(val),))
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("1s", "1 sec"),
+ ("-1s", "-1 sec"),
+ ("60s", "1 min"),
+ ("3600s", "1 hour"),
+ ("1s,1000m", "1.001 sec"),
+ ("1s,1m", "1.000001 sec"),
+ ("1d", "1 day"),
+ ("-10d", "-10 day"),
+ ("1d,1s,1m", "1 day 1.000001 sec"),
+ ("-86399s,-999999m", "-23:59:59.999999"),
+ ("-3723s,-400000m", "-1:2:3.4"),
+ ("3723s,400000m", "1:2:3.4"),
+ ("86399s,999999m", "23:59:59.999999"),
+ ("30d", "30 day"),
+ ("365d", "1 year"),
+ ("-365d", "-1 year"),
+ ("-730d", "-2 years"),
+ ("1460d", "4 year"),
+ ("30d", "1 month"),
+ ("-30d", "-1 month"),
+ ("60d", "2 month"),
+ ("-90d", "-3 month"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_interval(self, conn, val, expr, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select '{expr}'::interval")
+ assert cur.fetchone()[0] == as_td(val)
+
+ @crdb_skip_datestyle
+ @pytest.mark.xfail # weird interval outputs
+ @pytest.mark.parametrize("val, expr", [("1d,1s", "1 day 1 sec")])
+ @pytest.mark.parametrize(
+ "intervalstyle",
+ ["sql_standard", "postgres_verbose", "iso_8601"],
+ )
+ def test_load_interval_intervalstyle(self, conn, val, expr, intervalstyle):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set IntervalStyle to '{intervalstyle}'")
+ cur.execute(f"select '{expr}'::interval")
+ assert cur.fetchone()[0] == as_td(val)
+
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ @pytest.mark.parametrize("val", ["min", "max"])
+ def test_load_interval_overflow(self, conn, val, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(
+ "select %s + %s * '1s'::interval",
+ (as_td(val), -1 if val == "min" else 1),
+ )
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ @pytest.mark.crdb_skip("copy")
+ def test_load_copy(self, conn):
+ cur = conn.cursor(binary=False)
+ with cur.copy(
+ """
+ copy (
+ select
+ '1 days +00:00:01.000001'::interval,
+ 'foo bar'::text
+ ) to stdout
+ """
+ ) as copy:
+ copy.set_types(["interval", "text"])
+ rec = copy.read_row()
+
+ want = dt.timedelta(days=1, seconds=1, microseconds=1)
+ assert rec[0] == want
+ assert rec[1] == "foo bar"
+
+
+#
+# Support
+#
+
+
+def as_date(s):
+ return dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s)
+
+
+def as_time(s):
+ if "~" in s:
+ s, off = s.split("~")
+ else:
+ off = None
+
+ if "," in s:
+ rv = dt.time(*map(int, s.split(","))) # type: ignore[arg-type]
+ else:
+ rv = getattr(dt.time, s)
+ if off:
+ rv = rv.replace(tzinfo=as_tzinfo(off))
+
+ return rv
+
+
+def as_dt(s):
+ if "~" not in s:
+ return as_naive_dt(s)
+
+ s, off = s.split("~")
+ rv = as_naive_dt(s)
+ off = as_tzoffset(off)
+ rv = (rv - off).replace(tzinfo=dt.timezone.utc)
+ return rv
+
+
+def as_naive_dt(s):
+ if "," in s:
+ rv = dt.datetime(*map(int, s.split(","))) # type: ignore[arg-type]
+ else:
+ rv = getattr(dt.datetime, s)
+
+ return rv
+
+
+def as_tzoffset(s):
+ if s.startswith("-"):
+ mul = -1
+ s = s[1:]
+ else:
+ mul = 1
+
+ fields = ("hours", "minutes", "seconds")
+ return mul * dt.timedelta(**dict(zip(fields, map(int, s.split(":")))))
+
+
+def as_tzinfo(s):
+ off = as_tzoffset(s)
+ return dt.timezone(off)
+
+
+def as_td(s):
+ if s in ("min", "max"):
+ return getattr(dt.timedelta, s)
+
+ suffixes = {"d": "days", "s": "seconds", "m": "microseconds"}
+ kwargs = {}
+ for part in s.split(","):
+ kwargs[suffixes[part[-1]]] = int(part[:-1])
+
+ return dt.timedelta(**kwargs)
diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py
new file mode 100644
index 0000000..8dfb6d4
--- /dev/null
+++ b/tests/types/test_enum.py
@@ -0,0 +1,363 @@
+from enum import Enum, auto
+
+import pytest
+
+from psycopg import pq, sql, errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types import TypeInfo
+from psycopg.types.enum import EnumInfo, register_enum
+
+from ..fix_crdb import crdb_encoding
+
+
+class PureTestEnum(Enum):
+ FOO = auto()
+ BAR = auto()
+ BAZ = auto()
+
+
+class StrTestEnum(str, Enum):
+ ONE = "ONE"
+ TWO = "TWO"
+ THREE = "THREE"
+
+
+NonAsciiEnum = Enum(
+ "NonAsciiEnum",
+ {"X\xe0": "x\xe0", "X\xe1": "x\xe1", "COMMA": "foo,bar"},
+ type=str,
+)
+
+
+class IntTestEnum(int, Enum):
+ ONE = 1
+ TWO = 2
+ THREE = 3
+
+
+enum_cases = [PureTestEnum, StrTestEnum, IntTestEnum]
+encodings = ["utf8", crdb_encoding("latin1")]
+
+
+@pytest.fixture(scope="session", autouse=True)
+def make_test_enums(request, svcconn):
+ for enum in enum_cases + [NonAsciiEnum]:
+ ensure_enum(enum, svcconn)
+
+
+def ensure_enum(enum, conn):
+ name = enum.__name__.lower()
+ labels = list(enum.__members__)
+ conn.execute(
+ sql.SQL(
+ """
+ drop type if exists {name};
+ create type {name} as enum ({labels});
+ """
+ ).format(name=sql.Identifier(name), labels=sql.SQL(",").join(labels))
+ )
+ return name, enum, labels
+
+
+def test_fetch_info(conn):
+ info = EnumInfo.fetch(conn, "StrTestEnum")
+ assert info.name == "strtestenum"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.labels) == len(StrTestEnum)
+ assert info.labels == list(StrTestEnum.__members__)
+
+
+@pytest.mark.asyncio
+async def test_fetch_info_async(aconn):
+ info = await EnumInfo.fetch(aconn, "PureTestEnum")
+ assert info.name == "puretestenum"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.labels) == len(PureTestEnum)
+ assert info.labels == list(PureTestEnum.__members__)
+
+
+def test_register_makes_a_type(conn):
+ info = EnumInfo.fetch(conn, "IntTestEnum")
+ assert info
+ assert info.enum is None
+ register_enum(info, context=conn)
+ assert info.enum is not None
+ assert [e.name for e in info.enum] == list(IntTestEnum.__members__)
+
+
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader(conn, enum, fmt_in, fmt_out):
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum=enum)
+
+ for label in info.labels:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{enum.__name__}", [label], binary=fmt_out
+ )
+ assert cur.fetchone()[0] == enum[label]
+
+
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out):
+ enum = NonAsciiEnum
+ conn.execute(f"set client_encoding to {encoding}")
+
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum=enum)
+
+ for label in info.labels:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{info.name}", [label], binary=fmt_out
+ )
+ assert cur.fetchone()[0] == enum[label]
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader_sqlascii(conn, enum, fmt_in, fmt_out):
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+ conn.execute("set client_encoding to sql_ascii")
+
+ for label in info.labels:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{info.name}", [label], binary=fmt_out
+ )
+ assert cur.fetchone()[0] == enum[label]
+
+
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_dumper(conn, enum, fmt_in, fmt_out):
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+
+ for item in enum:
+ cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == item
+
+
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out):
+ enum = NonAsciiEnum
+ conn.execute(f"set client_encoding to {encoding}")
+
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+
+ for item in enum:
+ cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == item
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_dumper_sqlascii(conn, enum, fmt_in, fmt_out):
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+ conn.execute("set client_encoding to sql_ascii")
+
+ for item in enum:
+ cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == item
+
+
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_dumper(conn, enum, fmt_in, fmt_out):
+ for item in enum:
+ if enum is PureTestEnum:
+ want = item.name
+ else:
+ want = item.value
+
+ cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == want
+
+
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
+ for item in NonAsciiEnum:
+ cur = conn.execute(f"select %{fmt_in.value}", [item.value], binary=fmt_out)
+ assert cur.fetchone()[0] == item.value
+
+
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_loader(conn, enum, fmt_in, fmt_out):
+ for label in enum.__members__:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{enum.__name__}", [label], binary=fmt_out
+ )
+ want = enum[label].name
+ if fmt_out == pq.Format.BINARY:
+ want = want.encode()
+ assert cur.fetchone()[0] == want
+
+
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
+
+ for label in NonAsciiEnum.__members__:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::nonasciienum", [label], binary=fmt_out
+ )
+ if fmt_out == pq.Format.TEXT:
+ assert cur.fetchone()[0] == label
+ else:
+ assert cur.fetchone()[0] == label.encode(encoding)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_loader(conn, fmt_in, fmt_out):
+ enum = PureTestEnum
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+
+ labels = list(enum.__members__)
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{info.name}[]", [labels], binary=fmt_out
+ )
+ assert cur.fetchone()[0] == list(enum)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_dumper(conn, fmt_in, fmt_out):
+ enum = StrTestEnum
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+
+ cur = conn.execute(f"select %{fmt_in.value}::text[]", [list(enum)], binary=fmt_out)
+ assert cur.fetchone()[0] == list(enum.__members__)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_array_loader(conn, fmt_in, fmt_out):
+ enum = IntTestEnum
+ info = TypeInfo.fetch(conn, enum.__name__)
+ info.register(conn)
+ labels = list(enum.__members__)
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{info.name}[]", [labels], binary=fmt_out
+ )
+ if fmt_out == pq.Format.TEXT:
+ assert cur.fetchone()[0] == labels
+ else:
+ assert cur.fetchone()[0] == [item.encode() for item in labels]
+
+
+def test_enum_error(conn):
+ conn.autocommit = True
+
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, StrTestEnum)
+
+ with pytest.raises(e.DataError):
+ conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone()
+ with pytest.raises(e.DataError):
+ conn.execute("select 'BAR'::puretestenum").fetchone()
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize(
+ "mapping",
+ [
+ {StrTestEnum.ONE: "FOO", StrTestEnum.TWO: "BAR", StrTestEnum.THREE: "BAZ"},
+ [
+ (StrTestEnum.ONE, "FOO"),
+ (StrTestEnum.TWO, "BAR"),
+ (StrTestEnum.THREE, "BAZ"),
+ ],
+ ],
+)
+def test_remap(conn, fmt_in, fmt_out, mapping):
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, StrTestEnum, mapping=mapping)
+
+ for member, label in [("ONE", "FOO"), ("TWO", "BAR"), ("THREE", "BAZ")]:
+ cur = conn.execute(f"select %{fmt_in.value}::text", [StrTestEnum[member]])
+ assert cur.fetchone()[0] == label
+ cur = conn.execute(f"select '{label}'::puretestenum", binary=fmt_out)
+ assert cur.fetchone()[0] is StrTestEnum[member]
+
+
+def test_remap_rename(conn):
+ enum = Enum("RenamedEnum", "FOO BAR QUX")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, enum, mapping={enum.QUX: "BAZ"})
+
+ for member, label in [("FOO", "FOO"), ("BAR", "BAR"), ("QUX", "BAZ")]:
+ cur = conn.execute("select %s::text", [enum[member]])
+ assert cur.fetchone()[0] == label
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[member]
+
+
+def test_remap_more_python(conn):
+ enum = Enum("LargerEnum", "FOO BAR BAZ QUX QUUX QUUUX")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ mapping = {enum[m]: "BAZ" for m in ["QUX", "QUUX", "QUUUX"]}
+ register_enum(info, conn, enum, mapping=mapping)
+
+ for member, label in [("FOO", "FOO"), ("BAZ", "BAZ"), ("QUUUX", "BAZ")]:
+ cur = conn.execute("select %s::text", [enum[member]])
+ assert cur.fetchone()[0] == label
+
+ for member, label in [("FOO", "FOO"), ("QUUUX", "BAZ")]:
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[member]
+
+
+def test_remap_more_postgres(conn):
+ enum = Enum("SmallerEnum", "FOO")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ mapping = [(enum.FOO, "BAR"), (enum.FOO, "BAZ")]
+ register_enum(info, conn, enum, mapping=mapping)
+
+ cur = conn.execute("select %s::text", [enum.FOO])
+ assert cur.fetchone()[0] == "BAZ"
+
+ for label in PureTestEnum.__members__:
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum.FOO
+
+
+def test_remap_by_value(conn):
+ enum = Enum( # type: ignore
+ "ByValue",
+ {m.lower(): m for m in PureTestEnum.__members__},
+ )
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, enum, mapping={m: m.value for m in enum})
+
+ for label in PureTestEnum.__members__:
+ cur = conn.execute("select %s::text", [enum[label.lower()]])
+ assert cur.fetchone()[0] == label
+
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[label.lower()]
diff --git a/tests/types/test_hstore.py b/tests/types/test_hstore.py
new file mode 100644
index 0000000..5142d58
--- /dev/null
+++ b/tests/types/test_hstore.py
@@ -0,0 +1,107 @@
+import pytest
+
+import psycopg
+from psycopg.types import TypeInfo
+from psycopg.types.hstore import HstoreLoader, register_hstore
+
+pytestmark = pytest.mark.crdb_skip("hstore")
+
+
+@pytest.mark.parametrize(
+ "s, d",
+ [
+ ("", {}),
+ ('"a"=>"1", "b"=>"2"', {"a": "1", "b": "2"}),
+ ('"a" => "1" , "b" => "2"', {"a": "1", "b": "2"}),
+ ('"a"=>NULL, "b"=>"2"', {"a": None, "b": "2"}),
+ (r'"a"=>"\"", "\""=>"2"', {"a": '"', '"': "2"}),
+ ('"a"=>"\'", "\'"=>"2"', {"a": "'", "'": "2"}),
+ ('"a"=>"1", "b"=>NULL', {"a": "1", "b": None}),
+ (r'"a\\"=>"1"', {"a\\": "1"}),
+ (r'"a\""=>"1"', {'a"': "1"}),
+ (r'"a\\\""=>"1"', {r"a\"": "1"}),
+ (r'"a\\\\\""=>"1"', {r'a\\"': "1"}),
+ ('"\xe8"=>"\xe0"', {"\xe8": "\xe0"}),
+ ],
+)
+def test_parse_ok(s, d):
+ loader = HstoreLoader(0, None)
+ assert loader.load(s.encode()) == d
+
+
+@pytest.mark.parametrize(
+ "s",
+ [
+ "a",
+ '"a"',
+ r'"a\\""=>"1"',
+ r'"a\\\\""=>"1"',
+ '"a=>"1"',
+ '"a"=>"1", "b"=>NUL',
+ ],
+)
+def test_parse_bad(s):
+ with pytest.raises(psycopg.DataError):
+ loader = HstoreLoader(0, None)
+ loader.load(s.encode())
+
+
+def test_register_conn(hstore, conn):
+ info = TypeInfo.fetch(conn, "hstore")
+ register_hstore(info, conn)
+ assert conn.adapters.types[info.oid].name == "hstore"
+
+ cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+def test_register_curs(hstore, conn):
+ info = TypeInfo.fetch(conn, "hstore")
+ cur = conn.cursor()
+ register_hstore(info, cur)
+ assert conn.adapters.types.get(info.oid) is None
+ assert cur.adapters.types[info.oid].name == "hstore"
+
+ cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+def test_register_globally(conn_cls, hstore, dsn, svcconn, global_adapters):
+ info = TypeInfo.fetch(svcconn, "hstore")
+ register_hstore(info)
+ assert psycopg.adapters.types[info.oid].name == "hstore"
+
+ assert svcconn.adapters.types.get(info.oid) is None
+ conn = conn_cls.connect(dsn)
+ assert conn.adapters.types[info.oid].name == "hstore"
+
+ cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+ conn.close()
+
+
+ab = list(map(chr, range(32, 128)))
+samp = [
+ {},
+ {"a": "b", "c": None},
+ dict(zip(ab, ab)),
+ {"".join(ab): "".join(ab)},
+]
+
+
+@pytest.mark.parametrize("d", samp)
+def test_roundtrip(hstore, conn, d):
+ register_hstore(TypeInfo.fetch(conn, "hstore"), conn)
+ d1 = conn.execute("select %s", [d]).fetchone()[0]
+ assert d == d1
+
+
+def test_roundtrip_array(hstore, conn):
+ register_hstore(TypeInfo.fetch(conn, "hstore"), conn)
+ samp1 = conn.execute("select %s", (samp,)).fetchone()[0]
+ assert samp1 == samp
+
+
+def test_no_info_error(conn):
+ with pytest.raises(TypeError, match="hstore.*extension"):
+ register_hstore(None, conn) # type: ignore[arg-type]
diff --git a/tests/types/test_json.py b/tests/types/test_json.py
new file mode 100644
index 0000000..50e8ce3
--- /dev/null
+++ b/tests/types/test_json.py
@@ -0,0 +1,182 @@
+import json
+from copy import deepcopy
+
+import pytest
+
+import psycopg.types
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import PyFormat
+from psycopg.types.json import set_json_dumps, set_json_loads
+
+samples = [
+ "null",
+ "true",
+ '"te\'xt"',
+ '"\\u00e0\\u20ac"',
+ "123",
+ "123.45",
+ '["a", 100]',
+ '{"a": 100}',
+]
+
+
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_wrapper_regtype(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ cur = conn.cursor()
+ cur.execute(
+ f"select pg_typeof(%{fmt_in.value})::regtype = %s::regtype",
+ (wrapper([]), wrapper.__name__.lower()),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump(conn, val, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = json.loads(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"select %{fmt_in.value}::text = %s::{wrapper.__name__.lower()}::text",
+ (wrapper(obj), val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.crdb_skip("json array")
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_array_dump(conn, val, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = json.loads(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"select %{fmt_in.value}::text = array[%s::{wrapper.__name__.lower()}]::text",
+ ([wrapper(obj)], val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("jtype", ["json", "jsonb"])
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load(conn, val, jtype, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select %s::{jtype}", (val,))
+ assert cur.fetchone()[0] == json.loads(val)
+
+
+@pytest.mark.crdb_skip("json array")
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("jtype", ["json", "jsonb"])
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_array(conn, val, jtype, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select array[%s::{jtype}]", (val,))
+ assert cur.fetchone()[0] == [json.loads(val)]
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("jtype", ["json", "jsonb"])
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_copy(conn, val, jtype, fmt_out):
+ cur = conn.cursor()
+ stmt = sql.SQL("copy (select {}::{}) to stdout (format {})").format(
+ val, sql.Identifier(jtype), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([jtype])
+ (got,) = copy.read_row()
+
+ assert got == json.loads(val)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+def test_dump_customise(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = {"foo": "bar"}
+ cur = conn.cursor()
+
+ set_json_dumps(my_dumps)
+ try:
+ cur.execute(f"select %{fmt_in.value}->>'baz' = 'qux'", (wrapper(obj),))
+ assert cur.fetchone()[0] is True
+ finally:
+ set_json_dumps(json.dumps)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+def test_dump_customise_context(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = {"foo": "bar"}
+ cur1 = conn.cursor()
+ cur2 = conn.cursor()
+
+ set_json_dumps(my_dumps, cur2)
+ cur1.execute(f"select %{fmt_in.value}->>'baz'", (wrapper(obj),))
+ assert cur1.fetchone()[0] is None
+ cur2.execute(f"select %{fmt_in.value}->>'baz'", (wrapper(obj),))
+ assert cur2.fetchone()[0] == "qux"
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+def test_dump_customise_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = {"foo": "bar"}
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}->>'baz' = 'qux'", (wrapper(obj, my_dumps),))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("binary", [True, False])
+@pytest.mark.parametrize("pgtype", ["json", "jsonb"])
+def test_load_customise(conn, binary, pgtype):
+ cur = conn.cursor(binary=binary)
+
+ set_json_loads(my_loads)
+ try:
+ cur.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
+ obj = cur.fetchone()[0]
+ assert obj["foo"] == "bar"
+ assert obj["answer"] == 42
+ finally:
+ set_json_loads(json.loads)
+
+
+@pytest.mark.parametrize("binary", [True, False])
+@pytest.mark.parametrize("pgtype", ["json", "jsonb"])
+def test_load_customise_context(conn, binary, pgtype):
+ cur1 = conn.cursor(binary=binary)
+ cur2 = conn.cursor(binary=binary)
+
+ set_json_loads(my_loads, cur2)
+ cur1.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
+ got = cur1.fetchone()[0]
+ assert got["foo"] == "bar"
+ assert "answer" not in got
+
+ cur2.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
+ got = cur2.fetchone()[0]
+ assert got["foo"] == "bar"
+ assert got["answer"] == 42
+
+
+def my_dumps(obj):
+ obj = deepcopy(obj)
+ obj["baz"] = "qux"
+ return json.dumps(obj)
+
+
+def my_loads(data):
+ obj = json.loads(data)
+ obj["answer"] = 42
+ return obj
diff --git a/tests/types/test_multirange.py b/tests/types/test_multirange.py
new file mode 100644
index 0000000..2ab5152
--- /dev/null
+++ b/tests/types/test_multirange.py
@@ -0,0 +1,434 @@
+import pickle
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg import pq, sql
+from psycopg import errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types.range import Range
+from psycopg.types import multirange
+from psycopg.types.multirange import Multirange, MultirangeInfo
+from psycopg.types.multirange import register_multirange
+
+from ..utils import eur
+from .test_range import create_test_range
+
+pytestmark = [
+ pytest.mark.pg(">= 14"),
+ pytest.mark.crdb_skip("range"),
+]
+
+
+class TestMultirangeObject:
+ def test_empty(self):
+ mr = Multirange[int]()
+ assert not mr
+ assert len(mr) == 0
+
+ mr = Multirange([])
+ assert not mr
+ assert len(mr) == 0
+
+ def test_sequence(self):
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ assert mr
+ assert len(mr) == 3
+ assert mr[2] == Range(50, 60)
+ assert mr[-2] == Range(30, 40)
+
+ def test_bad_type(self):
+ with pytest.raises(TypeError):
+ Multirange(Range(10, 20)) # type: ignore[arg-type]
+
+ with pytest.raises(TypeError):
+ Multirange([10]) # type: ignore[arg-type]
+
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+
+ with pytest.raises(TypeError):
+ mr[0] = "foo" # type: ignore[call-overload]
+
+ with pytest.raises(TypeError):
+ mr[0:1] = "foo" # type: ignore[assignment]
+
+ with pytest.raises(TypeError):
+ mr[0:1] = ["foo"] # type: ignore[list-item]
+
+ with pytest.raises(TypeError):
+ mr.insert(0, "foo") # type: ignore[arg-type]
+
+ def test_setitem(self):
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ mr[1] = Range(31, 41)
+ assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)])
+
+ def test_setitem_slice(self):
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ mr[1:3] = [Range(31, 41), Range(51, 61)]
+ assert mr == Multirange([Range(10, 20), Range(31, 41), Range(51, 61)])
+
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ with pytest.raises(TypeError, match="can only assign an iterable"):
+ mr[1:3] = Range(31, 41) # type: ignore[call-overload]
+
+ mr[1:3] = [Range(31, 41)]
+ assert mr == Multirange([Range(10, 20), Range(31, 41)])
+
+ def test_delitem(self):
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ del mr[1]
+ assert mr == Multirange([Range(10, 20), Range(50, 60)])
+
+ del mr[-2]
+ assert mr == Multirange([Range(50, 60)])
+
+ def test_insert(self):
+ mr = Multirange([Range(10, 20), Range(50, 60)])
+ mr.insert(1, Range(31, 41))
+ assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)])
+
+ def test_relations(self):
+ mr1 = Multirange([Range(10, 20), Range(30, 40)])
+ mr2 = Multirange([Range(11, 20), Range(30, 40)])
+ mr3 = Multirange([Range(9, 20), Range(30, 40)])
+ assert mr1 <= mr1
+ assert not mr1 < mr1
+ assert mr1 >= mr1
+ assert not mr1 > mr1
+ assert mr1 < mr2
+ assert mr1 <= mr2
+ assert mr1 > mr3
+ assert mr1 >= mr3
+ assert mr1 != mr2
+ assert not mr1 == mr2
+
+ def test_pickling(self):
+ r = Multirange([Range(0, 4)])
+ assert pickle.loads(pickle.dumps(r)) == r
+
+ def test_str(self):
+ mr = Multirange([Range(10, 20), Range(30, 40)])
+ assert str(mr) == "{[10, 20), [30, 40)}"
+
+ def test_repr(self):
+ mr = Multirange([Range(10, 20), Range(30, 40)])
+ expected = "Multirange([Range(10, 20, '[)'), Range(30, 40, '[)')])"
+ assert repr(mr) == expected
+
+
+tzinfo = dt.timezone(dt.timedelta(hours=2))
+
+samples = [
+ ("int4multirange", [Range(None, None, "()")]),
+ ("int4multirange", [Range(10, 20), Range(30, 40)]),
+ ("int8multirange", [Range(None, None, "()")]),
+ ("int8multirange", [Range(10, 20), Range(30, 40)]),
+ (
+ "nummultirange",
+ [
+ Range(None, Decimal(-100)),
+ Range(Decimal(100), Decimal("100.123")),
+ ],
+ ),
+ (
+ "datemultirange",
+ [Range(dt.date(2000, 1, 1), dt.date(2020, 1, 1))],
+ ),
+ (
+ "tsmultirange",
+ [
+ Range(
+ dt.datetime(2000, 1, 1, 00, 00),
+ dt.datetime(2020, 1, 1, 23, 59, 59, 999999),
+ )
+ ],
+ ),
+ (
+ "tstzmultirange",
+ [
+ Range(
+ dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo),
+ dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+ ),
+ Range(
+ dt.datetime(2030, 1, 1, 00, 00, tzinfo=tzinfo),
+ dt.datetime(2040, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+ ),
+ ],
+ ),
+]
+
+mr_names = """
+ int4multirange int8multirange nummultirange
+ datemultirange tsmultirange tstzmultirange""".split()
+
+mr_classes = """
+ Int4Multirange Int8Multirange NumericMultirange
+ DateMultirange TimestampMultirange TimestamptzMultirange""".split()
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty(conn, pgtype, fmt_in):
+ mr = Multirange() # type: ignore[var-annotated]
+ cur = conn.execute(f"select '{{}}'::{pgtype} = %{fmt_in.value}", (mr,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in):
+ dumper = getattr(multirange, wrapper + "Dumper")
+ wrapper = getattr(multirange, wrapper)
+ mr = wrapper()
+ rec = conn.execute(
+ f"""
+ select '{{}}' = %(mr){fmt_in.value},
+ %(mr){fmt_in.value}::text,
+ pg_typeof(%(mr){fmt_in.value})::oid
+ """,
+ {"mr": mr},
+ ).fetchone()
+ assert rec[0] is True, rec[1]
+ assert rec[2] == dumper.oid
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize(
+ "fmt_in",
+ [
+ PyFormat.AUTO,
+ PyFormat.TEXT,
+ # There are many ways to work around this (use text, use a cast on the
+ # placeholder, use specific Range subclasses).
+ pytest.param(
+ PyFormat.BINARY,
+ marks=pytest.mark.xfail(
+ reason="can't dump array of untypes binary multirange without cast"
+ ),
+ ),
+ ],
+)
+def test_dump_builtin_array(conn, pgtype, fmt_in):
+ mr1 = Multirange() # type: ignore[var-annotated]
+ mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}] = %{fmt_in.value}",
+ ([mr1, mr2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in):
+ mr1 = Multirange() # type: ignore[var-annotated]
+ mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"""
+ select array['{{}}'::{pgtype},
+ '{{(,)}}'::{pgtype}] = %{fmt_in.value}::{pgtype}[]
+ """,
+ ([mr1, mr2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(multirange, wrapper)
+ mr1 = Multirange() # type: ignore[var-annotated]
+ mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in.value}""", ([mr1, mr2],)
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype, ranges", samples)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_multirange(conn, pgtype, ranges, fmt_in):
+ mr = Multirange(ranges)
+ rname = pgtype.replace("multi", "")
+ phs = ", ".join([f"%s::{rname}"] * len(ranges))
+ cur = conn.execute(f"select {pgtype}({phs}) = %{fmt_in.value}", ranges + [mr])
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_empty(conn, pgtype, fmt_out):
+ mr = Multirange() # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(f"select '{{}}'::{pgtype}").fetchone()
+ assert type(got) is Multirange
+ assert got == mr
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_array(conn, pgtype, fmt_out):
+ mr1 = Multirange() # type: ignore[var-annotated]
+ mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(
+ f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}]"
+ ).fetchone()
+ assert got == [mr1, mr2]
+
+
+@pytest.mark.parametrize("pgtype, ranges", samples)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_range(conn, pgtype, ranges, fmt_out):
+ mr = Multirange(ranges)
+ rname = pgtype.replace("multi", "")
+ phs = ", ".join([f"%s::{rname}"] * len(ranges))
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select {pgtype}({phs})", ranges)
+ assert cur.fetchone()[0] == mr
+
+
+@pytest.mark.parametrize(
+ "min, max, bounds",
+ [
+ ("2000,1,1", "2001,1,1", "[)"),
+ ("2000,1,1", None, "[)"),
+ (None, "2001,1,1", "()"),
+ (None, None, "()"),
+ (None, None, "empty"),
+ ],
+)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in(conn, min, max, bounds, format):
+ cur = conn.cursor()
+ cur.execute("create table copymr (id serial primary key, mr datemultirange)")
+
+ if bounds != "empty":
+ min = dt.date(*map(int, min.split(","))) if min else None
+ max = dt.date(*map(int, max.split(","))) if max else None
+ r = Range[dt.date](min, max, bounds)
+ else:
+ r = Range(empty=True)
+
+ mr = Multirange([r])
+ try:
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
+ copy.write_row([mr])
+ except e.InternalError_:
+ if not min and not max and format == pq.Format.BINARY:
+ pytest.xfail("TODO: add annotation to dump multirange with no type info")
+ else:
+ raise
+
+ rec = cur.execute("select mr from copymr order by id").fetchone()
+ if not r.isempty:
+ assert rec[0] == mr
+ else:
+ assert rec[0] == Multirange()
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_wrappers(conn, wrapper, format):
+ cur = conn.cursor()
+ cur.execute("create table copymr (id serial primary key, mr datemultirange)")
+
+ mr = getattr(multirange, wrapper)()
+
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
+ copy.write_row([mr])
+
+ rec = cur.execute("select mr from copymr order by id").fetchone()
+ assert rec[0] == mr
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_set_type(conn, pgtype, format):
+ cur = conn.cursor()
+ cur.execute(f"create table copymr (id serial primary key, mr {pgtype})")
+
+ mr = Multirange() # type: ignore[var-annotated]
+
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
+ copy.set_types([pgtype])
+ copy.write_row([mr])
+
+ rec = cur.execute("select mr from copymr order by id").fetchone()
+ assert rec[0] == mr
+
+
+@pytest.fixture(scope="session")
+def testmr(svcconn):
+ create_test_range(svcconn)
+
+
+fetch_cases = [
+ ("testmultirange", "text"),
+ ("testschema.testmultirange", "float8"),
+ (sql.Identifier("testmultirange"), "text"),
+ (sql.Identifier("testschema", "testmultirange"), "float8"),
+]
+
+
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+def test_fetch_info(conn, testmr, name, subtype):
+ info = MultirangeInfo.fetch(conn, name)
+ assert info.name == "testmultirange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == conn.adapters.types[subtype].oid
+
+
+def test_fetch_info_not_found(conn):
+ assert MultirangeInfo.fetch(conn, "nosuchrange") is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+async def test_fetch_info_async(aconn, testmr, name, subtype): # noqa: F811
+ info = await MultirangeInfo.fetch(aconn, name)
+ assert info.name == "testmultirange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == aconn.adapters.types[subtype].oid
+
+
+@pytest.mark.asyncio
+async def test_fetch_info_not_found_async(aconn):
+ assert await MultirangeInfo.fetch(aconn, "nosuchrange") is None
+
+
+def test_dump_custom_empty(conn, testmr):
+ info = MultirangeInfo.fetch(conn, "testmultirange")
+ register_multirange(info, conn)
+
+ r = Multirange() # type: ignore[var-annotated]
+ cur = conn.execute("select '{}'::testmultirange = %s", (r,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_custom_empty(conn, testmr, fmt_out):
+ info = MultirangeInfo.fetch(conn, "testmultirange")
+ register_multirange(info, conn)
+
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute("select '{}'::testmultirange").fetchone()
+ assert isinstance(got, Multirange)
+ assert not got
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}"])
+def test_literal_invalid_name(conn, name):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(f'create type "{name}" as range (subtype = text)')
+ info = MultirangeInfo.fetch(conn, f'"{name}_multirange"')
+ register_multirange(info, conn)
+ obj = Multirange([Range("a", "z", "[]")])
+ assert sql.Literal(obj).as_string(conn) == f"'{{[a,z]}}'::\"{name}_multirange\""
+ cur = conn.execute(sql.SQL("select {}").format(obj))
+ assert cur.fetchone()[0] == obj
diff --git a/tests/types/test_net.py b/tests/types/test_net.py
new file mode 100644
index 0000000..8739398
--- /dev/null
+++ b/tests/types/test_net.py
@@ -0,0 +1,135 @@
+import ipaddress
+
+import pytest
+
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import PyFormat
+
+crdb_skip_cidr = pytest.mark.crdb_skip("cidr")
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("val", ["192.168.0.1", "2001:db8::"])
+def test_address_dump(conn, fmt_in, val):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value} = %s::inet", (ipaddress.ip_address(val), val))
+ assert cur.fetchone()[0] is True
+ cur.execute(
+ f"select %{fmt_in.value} = array[null, %s]::inet[]",
+ ([None, ipaddress.ip_interface(val)], val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/128"])
+def test_interface_dump(conn, fmt_in, val):
+ cur = conn.cursor()
+ rec = cur.execute(
+ f"select %(val){fmt_in.value} = %(repr)s::inet,"
+ f" %(val){fmt_in.value}, %(repr)s::inet",
+ {"val": ipaddress.ip_interface(val), "repr": val},
+ ).fetchone()
+ assert rec[0] is True, f"{rec[1]} != {rec[2]}"
+ cur.execute(
+ f"select %{fmt_in.value} = array[null, %s]::inet[]",
+ ([None, ipaddress.ip_interface(val)], val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@crdb_skip_cidr
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
+def test_network_dump(conn, fmt_in, val):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value} = %s::cidr", (ipaddress.ip_network(val), val))
+ assert cur.fetchone()[0] is True
+ cur.execute(
+ f"select %{fmt_in.value} = array[NULL, %s]::cidr[]",
+ ([None, ipaddress.ip_network(val)], val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@crdb_skip_cidr
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_network_mixed_size_array(conn, fmt_in):
+ val = [
+ ipaddress.IPv4Network("192.168.0.1/32"),
+ ipaddress.IPv6Network("::1/128"),
+ ]
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}", (val,))
+ got = cur.fetchone()[0]
+ assert val == got
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"])
+def test_inet_load_address(conn, fmt_out, val):
+ addr = ipaddress.ip_address(val.split("/", 1)[0])
+ cur = conn.cursor(binary=fmt_out)
+
+ cur.execute("select %s::inet", (val,))
+ assert cur.fetchone()[0] == addr
+
+ cur.execute("select array[null, %s::inet]", (val,))
+ assert cur.fetchone()[0] == [None, addr]
+
+ stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["inet"])
+ (got,) = copy.read_row()
+
+ assert got == addr
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/127"])
+def test_inet_load_network(conn, fmt_out, val):
+ pyval = ipaddress.ip_interface(val)
+ cur = conn.cursor(binary=fmt_out)
+
+ cur.execute("select %s::inet", (val,))
+ assert cur.fetchone()[0] == pyval
+
+ cur.execute("select array[null, %s::inet]", (val,))
+ assert cur.fetchone()[0] == [None, pyval]
+
+ stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["inet"])
+ (got,) = copy.read_row()
+
+ assert got == pyval
+
+
+@crdb_skip_cidr
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
+def test_cidr_load(conn, fmt_out, val):
+ pyval = ipaddress.ip_network(val)
+ cur = conn.cursor(binary=fmt_out)
+
+ cur.execute("select %s::cidr", (val,))
+ assert cur.fetchone()[0] == pyval
+
+ cur.execute("select array[null, %s::cidr]", (val,))
+ assert cur.fetchone()[0] == [None, pyval]
+
+ stmt = sql.SQL("copy (select {}::cidr) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["cidr"])
+ (got,) = copy.read_row()
+
+ assert got == pyval
diff --git a/tests/types/test_none.py b/tests/types/test_none.py
new file mode 100644
index 0000000..4c008fd
--- /dev/null
+++ b/tests/types/test_none.py
@@ -0,0 +1,12 @@
+from psycopg import sql
+from psycopg.adapt import Transformer, PyFormat
+
+
+def test_quote_none(conn):
+
+ tx = Transformer()
+ assert tx.get_dumper(None, PyFormat.TEXT).quote(None) == b"NULL"
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None)))
+ assert cur.fetchone()[0] is None
diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py
new file mode 100644
index 0000000..a27bc84
--- /dev/null
+++ b/tests/types/test_numeric.py
@@ -0,0 +1,625 @@
+import enum
+from decimal import Decimal
+from math import isnan, isinf, exp
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import Transformer, PyFormat
+from psycopg.types.numeric import FloatLoader
+
+from ..fix_crdb import is_crdb
+
+#
+# Tests with int
+#
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::int"),
+ (1, "'1'::int"),
+ (-1, "'-1'::int"),
+ (42, "'42'::smallint"),
+ (-42, "'-42'::smallint"),
+ (int(2**63 - 1), "'9223372036854775807'::bigint"),
+ (int(-(2**63)), "'-9223372036854775808'::bigint"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_int(conn, val, expr, fmt_in):
+ assert isinstance(val, int)
+ cur = conn.cursor()
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::smallint"),
+ (1, "'1'::smallint"),
+ (-1, "'-1'::smallint"),
+ (42, "'42'::smallint"),
+ (-42, "'-42'::smallint"),
+ (int(2**15 - 1), f"'{2 ** 15 - 1}'::smallint"),
+ (int(-(2**15)), f"'{-2 ** 15}'::smallint"),
+ (int(2**15), f"'{2 ** 15}'::integer"),
+ (int(-(2**15) - 1), f"'{-2 ** 15 - 1}'::integer"),
+ (int(2**31 - 1), f"'{2 ** 31 - 1}'::integer"),
+ (int(-(2**31)), f"'{-2 ** 31}'::integer"),
+ (int(2**31), f"'{2 ** 31}'::bigint"),
+ (int(-(2**31) - 1), f"'{-2 ** 31 - 1}'::bigint"),
+ (int(2**63 - 1), f"'{2 ** 63 - 1}'::bigint"),
+ (int(-(2**63)), f"'{-2 ** 63}'::bigint"),
+ (int(2**63), f"'{2 ** 63}'::numeric"),
+ (int(-(2**63) - 1), f"'{-2 ** 63 - 1}'::numeric"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_int_subtypes(conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+ cur.execute(
+ f"select {expr} = %(v){fmt_in.value}, {expr}::text, %(v){fmt_in.value}::text",
+ {"v": val},
+ )
+ ok, want, got = cur.fetchone()
+ assert got == want
+ assert ok
+
+
+class MyEnum(enum.IntEnum):
+ foo = 42
+
+
+class MyMixinEnum(enum.IntEnum):
+ foo = 42000000
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("enum", [MyEnum, MyMixinEnum])
+def test_dump_enum(conn, fmt_in, enum):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}", (enum.foo,))
+ (res,) = cur.fetchone()
+ assert res == enum.foo.value
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, b"0"),
+ (1, b"1"),
+ (-1, b" -1"),
+ (42, b"42"),
+ (-42, b" -42"),
+ (int(2**63 - 1), b"9223372036854775807"),
+ (int(-(2**63)), b" -9223372036854775808"),
+ (int(2**63), b"9223372036854775808"),
+ (int(-(2**63 + 1)), b" -9223372036854775809"),
+ (int(2**100), b"1267650600228229401496703205376"),
+ (int(-(2**100)), b" -1267650600228229401496703205376"),
+ ],
+)
+def test_quote_int(conn, val, expr):
+ tx = Transformer()
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ assert cur.fetchone() == (val, -val)
+
+
+@pytest.mark.parametrize(
+ "val, pgtype, want",
+ [
+ ("0", "integer", 0),
+ ("1", "integer", 1),
+ ("-1", "integer", -1),
+ ("0", "int2", 0),
+ ("0", "int4", 0),
+ ("0", "int8", 0),
+ ("0", "integer", 0),
+ ("0", "oid", 0),
+ # bounds
+ ("-32768", "smallint", -32768),
+ ("+32767", "smallint", 32767),
+ ("-2147483648", "integer", -2147483648),
+ ("+2147483647", "integer", 2147483647),
+ ("-9223372036854775808", "bigint", -9223372036854775808),
+ ("9223372036854775807", "bigint", 9223372036854775807),
+ ("4294967295", "oid", 4294967295),
+ ],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_int(conn, val, pgtype, want, fmt_out):
+ if pgtype == "integer" and is_crdb(conn):
+ pgtype = "int4" # "integer" is "int8" on crdb
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select %s::{pgtype}", (val,))
+ assert cur.pgresult.fformat(0) == fmt_out
+ assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].oid
+ result = cur.fetchone()[0]
+ assert result == want
+ assert type(result) is type(want)
+
+ # arrays work too
+ cur.execute(f"select array[%s::{pgtype}]", (val,))
+ assert cur.pgresult.fformat(0) == fmt_out
+ assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].array_oid
+ result = cur.fetchone()[0]
+ assert result == [want]
+ assert type(result[0]) is type(want)
+
+
+#
+# Tests with float
+#
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0.0, "'0'"),
+ (1.0, "'1'"),
+ (-1.0, "'-1'"),
+ (float("nan"), "'NaN'"),
+ (float("inf"), "'Infinity'"),
+ (float("-inf"), "'-Infinity'"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_float(conn, val, expr, fmt_in):
+ assert isinstance(val, float)
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value} = {expr}::float8", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0.0, b"0.0"),
+ (1.0, b"1.0"),
+ (10000000000000000.0, b"1e+16"),
+ (1000000.1, b"1000000.1"),
+ (-100000.000001, b" -100000.000001"),
+ (-1.0, b" -1.0"),
+ (float("nan"), b"'NaN'::float8"),
+ (float("inf"), b"'Infinity'::float8"),
+ (float("-inf"), b"'-Infinity'::float8"),
+ ],
+)
+def test_quote_float(conn, val, expr):
+ tx = Transformer()
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ r = cur.fetchone()
+ if isnan(val):
+ assert isnan(r[0]) and isnan(r[1])
+ else:
+ if isinstance(r[0], Decimal):
+ r = tuple(map(float, r))
+
+ assert r == (val, -val)
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (exp(1), "exp(1.0)"),
+ (-exp(1), "-exp(1.0)"),
+ (1e30, "'1e30'"),
+ (1e-30, "1e-30"),
+ (-1e30, "'-1e30'"),
+ (-1e-30, "-1e-30"),
+ ],
+)
+def test_dump_float_approx(conn, val, expr):
+ assert isinstance(val, float)
+ cur = conn.cursor()
+ cur.execute(f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, pgtype, want",
+ [
+ ("0", "float4", 0.0),
+ ("0.0", "float4", 0.0),
+ ("42", "float4", 42.0),
+ ("-42", "float4", -42.0),
+ ("0.0", "float8", 0.0),
+ ("0.0", "real", 0.0),
+ ("0.0", "double precision", 0.0),
+ ("0.0", "float4", 0.0),
+ ("nan", "float4", float("nan")),
+ ("inf", "float4", float("inf")),
+ ("-inf", "float4", -float("inf")),
+ ("nan", "float8", float("nan")),
+ ("inf", "float8", float("inf")),
+ ("-inf", "float8", -float("inf")),
+ ],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_float(conn, val, pgtype, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select %s::{pgtype}", (val,))
+ assert cur.pgresult.fformat(0) == fmt_out
+ assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].oid
+ result = cur.fetchone()[0]
+
+ def check(result, want):
+ assert type(result) is type(want)
+ if isnan(want):
+ assert isnan(result)
+ elif isinf(want):
+ assert isinf(result)
+ assert (result < 0) is (want < 0)
+ else:
+ assert result == want
+
+ check(result, want)
+
+ cur.execute(f"select array[%s::{pgtype}]", (val,))
+ assert cur.pgresult.fformat(0) == fmt_out
+ assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].array_oid
+ result = cur.fetchone()[0]
+ assert isinstance(result, list)
+ check(result[0], want)
+
+
+@pytest.mark.parametrize(
+ "expr, pgtype, want",
+ [
+ ("exp(1.0)", "float4", 2.71828),
+ ("-exp(1.0)", "float4", -2.71828),
+ ("exp(1.0)", "float8", 2.71828182845905),
+ ("-exp(1.0)", "float8", -2.71828182845905),
+ ("1.42e10", "float4", 1.42e10),
+ ("-1.42e10", "float4", -1.42e10),
+ ("1.42e40", "float8", 1.42e40),
+ ("-1.42e40", "float8", -1.42e40),
+ ],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_float_approx(conn, expr, pgtype, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::%s" % (expr, pgtype))
+ assert cur.pgresult.fformat(0) == fmt_out
+ result = cur.fetchone()[0]
+ assert result == pytest.approx(want)
+
+
+@pytest.mark.crdb_skip("copy")
+def test_load_float_copy(conn):
+ cur = conn.cursor(binary=False)
+ with cur.copy("copy (select 3.14::float8, 'hi'::text) to stdout;") as copy:
+ copy.set_types(["float8", "text"])
+ rec = copy.read_row()
+
+ assert rec[0] == pytest.approx(3.14)
+ assert rec[1] == "hi"
+
+
+#
+# Tests with decimal
+#
+
+
+@pytest.mark.parametrize(
+ "val",
+ [
+ "0",
+ "-0",
+ "0.0",
+ "0.000000000000000000001",
+ "-0.000000000000000000001",
+ "nan",
+ "snan",
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_roundtrip_numeric(conn, val, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ val = Decimal(val)
+ cur.execute(f"select %{fmt_in.value}", (val,))
+ result = cur.fetchone()[0]
+ assert isinstance(result, Decimal)
+ if val.is_nan():
+ assert result.is_nan()
+ else:
+ assert result == val
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("0", b"0"),
+ ("0.0", b"0.0"),
+ ("0.00000000000000001", b"1E-17"),
+ ("-0.00000000000000001", b" -1E-17"),
+ ("nan", b"'NaN'::numeric"),
+ ("snan", b"'NaN'::numeric"),
+ ],
+)
+def test_quote_numeric(conn, val, expr):
+ val = Decimal(val)
+ tx = Transformer()
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ r = cur.fetchone()
+
+ if val.is_nan():
+ assert isnan(r[0]) and isnan(r[1])
+ else:
+ assert r == (val, -val)
+
+
+@pytest.mark.crdb_skip("binary decimal")
+@pytest.mark.parametrize(
+ "expr",
+ ["NaN", "1", "1.0", "-1", "0.0", "0.01", "11", "1.1", "1.01", "0", "0.00"]
+ + [
+ "0.0000000",
+ "0.00001",
+ "1.00001",
+ "-1.00000000000000",
+ "-2.00000000000000",
+ "1000000000.12345",
+ "100.123456790000000000000000",
+ "1.0e-1000",
+ "1e1000",
+ "0.000000000000000000000000001",
+ "1.0000000000000000000000001",
+ "1000000000000000000000000.001",
+ "1000000000000000000000000000.001",
+ "9999999999999999999999999999.9",
+ ],
+)
+def test_dump_numeric_binary(conn, expr):
+ cur = conn.cursor()
+ val = Decimal(expr)
+ cur.execute("select %b::text, %s::decimal::text", [val, expr])
+ want, got = cur.fetchone()
+ assert got == want
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt_in",
+ [
+ f
+ if f != PyFormat.BINARY
+ else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal"))
+ for f in PyFormat
+ ],
+)
+def test_dump_numeric_exhaustive(conn, fmt_in):
+ cur = conn.cursor()
+
+ funcs = [
+ (lambda i: "1" + "0" * i),
+ (lambda i: "1" + "0" * i + "." + "0" * i),
+ (lambda i: "-1" + "0" * i),
+ (lambda i: "0." + "0" * i + "1"),
+ (lambda i: "-0." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "10"),
+ (lambda i: "1" + "0" * i + ".001"),
+ (lambda i: "9" + "9" * i),
+ (lambda i: "9" + "." + "9" * i),
+ (lambda i: "9" + "9" * i + ".9"),
+ (lambda i: "9" + "9" * i + "." + "9" * i),
+ (lambda i: "1.1e%s" % i),
+ (lambda i: "1.1e-%s" % i),
+ ]
+
+ for i in range(100):
+ for f in funcs:
+ expr = f(i)
+ val = Decimal(expr)
+ cur.execute(f"select %{fmt_in.value}::text, %s::decimal::text", [val, expr])
+ got, want = cur.fetchone()
+ assert got == want
+
+
+@pytest.mark.pg(">= 14")
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("inf", "Infinity"),
+ ("-inf", "-Infinity"),
+ ],
+)
+def test_dump_numeric_binary_inf(conn, val, expr):
+ cur = conn.cursor()
+ val = Decimal(val)
+ cur.execute("select %b", [val])
+
+
+@pytest.mark.parametrize(
+ "expr",
+ ["nan", "0", "1", "-1", "0.0", "0.01"]
+ + [
+ "0.0000000",
+ "-1.00000000000000",
+ "-2.00000000000000",
+ "1000000000.12345",
+ "100.123456790000000000000000",
+ "1.0e-1000",
+ "1e1000",
+ "0.000000000000000000000000001",
+ "1.0000000000000000000000001",
+ "1000000000000000000000000.001",
+ "1000000000000000000000000000.001",
+ "9999999999999999999999999999.9",
+ ],
+)
+def test_load_numeric_binary(conn, expr):
+ cur = conn.cursor(binary=1)
+ res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
+ val = Decimal(expr)
+ if val.is_nan():
+ assert res.is_nan()
+ else:
+ assert res == val
+ if "e" not in expr:
+ assert str(res) == str(val)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_numeric_exhaustive(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+
+ funcs = [
+ (lambda i: "1" + "0" * i),
+ (lambda i: "1" + "0" * i + "." + "0" * i),
+ (lambda i: "-1" + "0" * i),
+ (lambda i: "0." + "0" * i + "1"),
+ (lambda i: "-0." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "10"),
+ (lambda i: "1" + "0" * i + ".001"),
+ (lambda i: "9" + "9" * i),
+ (lambda i: "9" + "." + "9" * i),
+ (lambda i: "9" + "9" * i + ".9"),
+ (lambda i: "9" + "9" * i + "." + "9" * i),
+ ]
+
+ for i in range(100):
+ for f in funcs:
+ snum = f(i)
+ want = Decimal(snum)
+ got = cur.execute(f"select '{snum}'::decimal").fetchone()[0]
+ assert want == got
+ assert str(want) == str(got)
+
+
+@pytest.mark.pg(">= 14")
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("inf", "Infinity"),
+ ("-inf", "-Infinity"),
+ ],
+)
+def test_load_numeric_binary_inf(conn, val, expr):
+ cur = conn.cursor(binary=1)
+ res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
+ val = Decimal(val)
+ assert res == val
+
+
+@pytest.mark.parametrize(
+ "val",
+ [
+ "0",
+ "0.0",
+ "0.000000000000000000001",
+ "-0.000000000000000000001",
+ "nan",
+ ],
+)
+def test_numeric_as_float(conn, val):
+ cur = conn.cursor()
+ cur.adapters.register_loader("numeric", FloatLoader)
+
+ val = Decimal(val)
+ cur.execute("select %s as val", (val,))
+ result = cur.fetchone()[0]
+ assert isinstance(result, float)
+ if val.is_nan():
+ assert isnan(result)
+ else:
+ assert result == pytest.approx(float(val))
+
+ # the customization works with arrays too
+ cur.execute("select %s as arr", ([val],))
+ result = cur.fetchone()[0]
+ assert isinstance(result, list)
+ assert isinstance(result[0], float)
+ if val.is_nan():
+ assert isnan(result[0])
+ else:
+ assert result[0] == pytest.approx(float(val))
+
+
+#
+# Mixed tests
+#
+
+
+@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"])
+def test_minus_minus(conn, pgtype):
+ cur = conn.cursor()
+ cast = f"::{pgtype}" if pgtype is not None else ""
+ cur.execute(f"select -%s{cast}", [-1])
+ result = cur.fetchone()[0]
+ assert result == 1
+
+
+@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"])
+def test_minus_minus_quote(conn, pgtype):
+ cur = conn.cursor()
+ cast = f"::{pgtype}" if pgtype is not None else ""
+ cur.execute(sql.SQL("select -{}{}").format(sql.Literal(-1), sql.SQL(cast)))
+ result = cur.fetchone()[0]
+ assert result == 1
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ obj = wrapper(1)
+ cur = conn.execute(
+ f"select %(obj){fmt_in.value} = 1, %(obj){fmt_in.value}", {"obj": obj}
+ )
+ rec = cur.fetchone()
+ assert rec[0], rec[1]
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+def test_dump_wrapper_oid(wrapper):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ base = wrapper.__mro__[1]
+ assert base in (int, float)
+ n = base(3.14)
+ assert str(wrapper(n)) == str(n)
+ assert repr(wrapper(n)) == f"{wrapper.__name__}({n})"
+
+
+@pytest.mark.crdb("skip", reason="all types returned as bigint? TODOCRDB")
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_repr_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ cur = conn.execute(f"select pg_typeof(%{fmt_in.value})::oid", [wrapper(0)])
+ oid = cur.fetchone()[0]
+ assert oid == psycopg.postgres.types[wrapper.__name__.lower()].oid
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize(
+ "typename",
+ "integer int2 int4 int8 float4 float8 numeric".split() + ["double precision"],
+)
+def test_oid_lookup(conn, typename, fmt_out):
+ dumper = conn.adapters.get_dumper_by_oid(conn.adapters.types[typename].oid, fmt_out)
+ assert dumper.oid == conn.adapters.types[typename].oid
+ assert dumper.format == fmt_out
diff --git a/tests/types/test_range.py b/tests/types/test_range.py
new file mode 100644
index 0000000..1efd398
--- /dev/null
+++ b/tests/types/test_range.py
@@ -0,0 +1,677 @@
+import pickle
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg import pq, sql
+from psycopg import errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types import range as range_module
+from psycopg.types.range import Range, RangeInfo, register_range
+
+from ..utils import eur
+from ..fix_crdb import is_crdb, crdb_skip_message
+
+pytestmark = pytest.mark.crdb_skip("range")
+
+type2sub = {
+ "int4range": "int4",
+ "int8range": "int8",
+ "numrange": "numeric",
+ "daterange": "date",
+ "tsrange": "timestamp",
+ "tstzrange": "timestamptz",
+}
+
+tzinfo = dt.timezone(dt.timedelta(hours=2))
+
+samples = [
+ ("int4range", None, None, "()"),
+ ("int4range", 10, 20, "[]"),
+ ("int4range", -(2**31), (2**31) - 1, "[)"),
+ ("int8range", None, None, "()"),
+ ("int8range", 10, 20, "[)"),
+ ("int8range", -(2**63), (2**63) - 1, "[)"),
+ ("numrange", Decimal(-100), Decimal("100.123"), "(]"),
+ ("numrange", Decimal(100), None, "()"),
+ ("numrange", None, Decimal(100), "()"),
+ ("daterange", dt.date(2000, 1, 1), dt.date(2020, 1, 1), "[)"),
+ (
+ "tsrange",
+ dt.datetime(2000, 1, 1, 00, 00),
+ dt.datetime(2020, 1, 1, 23, 59, 59, 999999),
+ "[]",
+ ),
+ (
+ "tstzrange",
+ dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo),
+ dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+ "()",
+ ),
+]
+
+range_names = """
+ int4range int8range numrange daterange tsrange tstzrange
+ """.split()
+
+range_classes = """
+ Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+ """.split()
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty(conn, pgtype, fmt_in):
+ r = Range(empty=True) # type: ignore[var-annotated]
+ cur = conn.execute(f"select 'empty'::{pgtype} = %{fmt_in.value}", (r,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", range_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(range_module, wrapper)
+ r = wrapper(empty=True)
+ cur = conn.execute(f"select 'empty' = %{fmt_in.value}", (r,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize(
+ "fmt_in",
+ [
+ PyFormat.AUTO,
+ PyFormat.TEXT,
+ # There are many ways to work around this (use text, use a cast on the
+ # placeholder, use specific Range subclasses).
+ pytest.param(
+ PyFormat.BINARY,
+ marks=pytest.mark.xfail(
+ reason="can't dump an array of untypes binary range without cast"
+ ),
+ ),
+ ],
+)
+def test_dump_builtin_array(conn, pgtype, fmt_in):
+ r1 = Range(empty=True) # type: ignore[var-annotated]
+ r2 = Range(bounds="()") # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in.value}",
+ ([r1, r2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in):
+ r1 = Range(empty=True) # type: ignore[var-annotated]
+ r2 = Range(bounds="()") # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"select array['empty'::{pgtype}, '(,)'::{pgtype}] "
+ f"= %{fmt_in.value}::{pgtype}[]",
+ ([r1, r2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", range_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(range_module, wrapper)
+ r1 = wrapper(empty=True)
+ r2 = wrapper(bounds="()")
+ cur = conn.execute(f"""select '{{empty,"(,)"}}' = %{fmt_in.value}""", ([r1, r2],))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in):
+ r = Range(min, max, bounds) # type: ignore[var-annotated]
+ sub = type2sub[pgtype]
+ cur = conn.execute(
+ f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %{fmt_in.value}",
+ (min, max, bounds, r),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_empty(conn, pgtype, fmt_out):
+ r = Range(empty=True) # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
+ assert type(got) is Range
+ assert got == r
+ assert not got
+ assert got.isempty
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_inf(conn, pgtype, fmt_out):
+ r = Range(bounds="()") # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(f"select '(,)'::{pgtype}").fetchone()
+ assert type(got) is Range
+ assert got == r
+ assert got
+ assert not got.isempty
+ assert got.lower_inf
+ assert got.upper_inf
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_array(conn, pgtype, fmt_out):
+ r1 = Range(empty=True) # type: ignore[var-annotated]
+ r2 = Range(bounds="()") # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(f"select array['empty'::{pgtype}, '(,)'::{pgtype}]").fetchone()
+ assert got == [r1, r2]
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out):
+ r = Range(min, max, bounds) # type: ignore[var-annotated]
+ sub = type2sub[pgtype]
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds))
+ # normalise discrete ranges
+ if r.upper_inc and isinstance(r.upper, int):
+ bounds = "[)" if r.lower_inc else "()"
+ r = type(r)(r.lower, r.upper + 1, bounds)
+ assert cur.fetchone()[0] == r
+
+
+@pytest.mark.parametrize(
+ "min, max, bounds",
+ [
+ ("2000,1,1", "2001,1,1", "[)"),
+ ("2000,1,1", None, "[)"),
+ (None, "2001,1,1", "()"),
+ (None, None, "()"),
+ (None, None, "empty"),
+ ],
+)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in(conn, min, max, bounds, format):
+ cur = conn.cursor()
+ cur.execute("create table copyrange (id serial primary key, r daterange)")
+
+ if bounds != "empty":
+ min = dt.date(*map(int, min.split(","))) if min else None
+ max = dt.date(*map(int, max.split(","))) if max else None
+ r = Range[dt.date](min, max, bounds)
+ else:
+ r = Range(empty=True)
+
+ try:
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
+ copy.write_row([r])
+ except e.ProtocolViolation:
+ if not min and not max and format == pq.Format.BINARY:
+ pytest.xfail("TODO: add annotation to dump ranges with no type info")
+ else:
+ raise
+
+ rec = cur.execute("select r from copyrange order by id").fetchone()
+ assert rec[0] == r
+
+
+@pytest.mark.parametrize("bounds", "() empty".split())
+@pytest.mark.parametrize("wrapper", range_classes)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_wrappers(conn, bounds, wrapper, format):
+ cur = conn.cursor()
+ cur.execute("create table copyrange (id serial primary key, r daterange)")
+
+ cls = getattr(range_module, wrapper)
+ r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds)
+
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
+ copy.write_row([r])
+
+ rec = cur.execute("select r from copyrange order by id").fetchone()
+ assert rec[0] == r
+
+
+@pytest.mark.parametrize("bounds", "() empty".split())
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_set_type(conn, bounds, pgtype, format):
+ cur = conn.cursor()
+ cur.execute(f"create table copyrange (id serial primary key, r {pgtype})")
+
+ if bounds == "empty":
+ r = Range(empty=True) # type: ignore[var-annotated]
+ else:
+ r = Range(None, None, bounds)
+
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
+ copy.set_types([pgtype])
+ copy.write_row([r])
+
+ rec = cur.execute("select r from copyrange order by id").fetchone()
+ assert rec[0] == r
+
+
+@pytest.fixture(scope="session")
+def testrange(svcconn):
+ create_test_range(svcconn)
+
+
+def create_test_range(conn):
+ if is_crdb(conn):
+ pytest.skip(crdb_skip_message("range"))
+
+ conn.execute(
+ """
+ create schema if not exists testschema;
+
+ drop type if exists testrange cascade;
+ drop type if exists testschema.testrange cascade;
+
+ create type testrange as range (subtype = text, collation = "C");
+ create type testschema.testrange as range (subtype = float8);
+ """
+ )
+
+
+fetch_cases = [
+ ("testrange", "text"),
+ ("testschema.testrange", "float8"),
+ (sql.Identifier("testrange"), "text"),
+ (sql.Identifier("testschema", "testrange"), "float8"),
+]
+
+
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+def test_fetch_info(conn, testrange, name, subtype):
+ info = RangeInfo.fetch(conn, name)
+ assert info.name == "testrange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == conn.adapters.types[subtype].oid
+
+
+def test_fetch_info_not_found(conn):
+ assert RangeInfo.fetch(conn, "nosuchrange") is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+async def test_fetch_info_async(aconn, testrange, name, subtype):
+ info = await RangeInfo.fetch(aconn, name)
+ assert info.name == "testrange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == aconn.adapters.types[subtype].oid
+
+
+@pytest.mark.asyncio
+async def test_fetch_info_not_found_async(aconn):
+ assert await RangeInfo.fetch(aconn, "nosuchrange") is None
+
+
+def test_dump_custom_empty(conn, testrange):
+ info = RangeInfo.fetch(conn, "testrange")
+ register_range(info, conn)
+
+ r = Range[str](empty=True)
+ cur = conn.execute("select 'empty'::testrange = %s", (r,))
+ assert cur.fetchone()[0] is True
+
+
+def test_dump_quoting(conn, testrange):
+ info = RangeInfo.fetch(conn, "testrange")
+ register_range(info, conn)
+ cur = conn.cursor()
+ for i in range(1, 254):
+ cur.execute(
+ """
+ select ascii(lower(%(r)s)) = %(low)s
+ and ascii(upper(%(r)s)) = %(up)s
+ """,
+ {"r": Range(chr(i), chr(i + 1)), "low": i, "up": i + 1},
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_custom_empty(conn, testrange, fmt_out):
+ info = RangeInfo.fetch(conn, "testrange")
+ register_range(info, conn)
+
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute("select 'empty'::testrange").fetchone()
+ assert isinstance(got, Range)
+ assert got.isempty
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_quoting(conn, testrange, fmt_out):
+ info = RangeInfo.fetch(conn, "testrange")
+ register_range(info, conn)
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 254):
+ cur.execute(
+ "select testrange(chr(%(low)s::int), chr(%(up)s::int))",
+ {"low": i, "up": i + 1},
+ )
+ got: Range[str] = cur.fetchone()[0]
+ assert isinstance(got, Range)
+ assert got.lower and ord(got.lower) == i
+ assert got.upper and ord(got.upper) == i + 1
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_mixed_array_types(conn, fmt_out):
+ conn.execute("create table testmix (a daterange[], b tstzrange[])")
+ r1 = Range(dt.date(2000, 1, 1), dt.date(2001, 1, 1), "[)")
+ r2 = Range(
+ dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc),
+ dt.datetime(2001, 1, 1, tzinfo=dt.timezone.utc),
+ "[)",
+ )
+ conn.execute("insert into testmix values (%s, %s)", [[r1], [r2]])
+ got = conn.execute("select * from testmix").fetchone()
+ assert got == ([r1], [r2])
+
+
+class TestRangeObject:
+ def test_noparam(self):
+ r = Range() # type: ignore[var-annotated]
+
+ assert not r.isempty
+ assert r.lower is None
+ assert r.upper is None
+ assert r.lower_inf
+ assert r.upper_inf
+ assert not r.lower_inc
+ assert not r.upper_inc
+
+ def test_empty(self):
+ r = Range(empty=True) # type: ignore[var-annotated]
+
+ assert r.isempty
+ assert r.lower is None
+ assert r.upper is None
+ assert not r.lower_inf
+ assert not r.upper_inf
+ assert not r.lower_inc
+ assert not r.upper_inc
+
+ def test_nobounds(self):
+ r = Range(10, 20)
+ assert r.lower == 10
+ assert r.upper == 20
+ assert not r.isempty
+ assert not r.lower_inf
+ assert not r.upper_inf
+ assert r.lower_inc
+ assert not r.upper_inc
+
+ def test_bounds(self):
+ for bounds, lower_inc, upper_inc in [
+ ("[)", True, False),
+ ("(]", False, True),
+ ("()", False, False),
+ ("[]", True, True),
+ ]:
+ r = Range(10, 20, bounds)
+ assert r.bounds == bounds
+ assert r.lower == 10
+ assert r.upper == 20
+ assert not r.isempty
+ assert not r.lower_inf
+ assert not r.upper_inf
+ assert r.lower_inc == lower_inc
+ assert r.upper_inc == upper_inc
+
+ def test_keywords(self):
+ r = Range(upper=20)
+ r.lower is None
+ r.upper == 20
+ assert not r.isempty
+ assert r.lower_inf
+ assert not r.upper_inf
+ assert not r.lower_inc
+ assert not r.upper_inc
+
+ r = Range(lower=10, bounds="(]")
+ r.lower == 10
+ r.upper is None
+ assert not r.isempty
+ assert not r.lower_inf
+ assert r.upper_inf
+ assert not r.lower_inc
+ assert not r.upper_inc
+
+ def test_bad_bounds(self):
+ with pytest.raises(ValueError):
+ Range(bounds="(")
+ with pytest.raises(ValueError):
+ Range(bounds="[}")
+
+ def test_in(self):
+ r = Range[int](empty=True)
+ assert 10 not in r
+ assert "x" not in r # type: ignore[operator]
+
+ r = Range()
+ assert 10 in r
+
+ r = Range(lower=10, bounds="[)")
+ assert 9 not in r
+ assert 10 in r
+ assert 11 in r
+
+ r = Range(lower=10, bounds="()")
+ assert 9 not in r
+ assert 10 not in r
+ assert 11 in r
+
+ r = Range(upper=20, bounds="()")
+ assert 19 in r
+ assert 20 not in r
+ assert 21 not in r
+
+ r = Range(upper=20, bounds="(]")
+ assert 19 in r
+ assert 20 in r
+ assert 21 not in r
+
+ r = Range(10, 20)
+ assert 9 not in r
+ assert 10 in r
+ assert 11 in r
+ assert 19 in r
+ assert 20 not in r
+ assert 21 not in r
+
+ r = Range(10, 20, "(]")
+ assert 9 not in r
+ assert 10 not in r
+ assert 11 in r
+ assert 19 in r
+ assert 20 in r
+ assert 21 not in r
+
+ r = Range(20, 10)
+ assert 9 not in r
+ assert 10 not in r
+ assert 11 not in r
+ assert 19 not in r
+ assert 20 not in r
+ assert 21 not in r
+
+ def test_nonzero(self):
+ assert Range()
+ assert Range(10, 20)
+ assert not Range(empty=True)
+
+ def test_eq_hash(self):
+ def assert_equal(r1, r2):
+ assert r1 == r2
+ assert hash(r1) == hash(r2)
+
+ assert_equal(Range(empty=True), Range(empty=True))
+ assert_equal(Range(), Range())
+ assert_equal(Range(10, None), Range(10, None))
+ assert_equal(Range(10, 20), Range(10, 20))
+ assert_equal(Range(10, 20), Range(10, 20, "[)"))
+ assert_equal(Range(10, 20, "[]"), Range(10, 20, "[]"))
+
+ def assert_not_equal(r1, r2):
+ assert r1 != r2
+ assert hash(r1) != hash(r2)
+
+ assert_not_equal(Range(10, 20), Range(10, 21))
+ assert_not_equal(Range(10, 20), Range(11, 20))
+ assert_not_equal(Range(10, 20, "[)"), Range(10, 20, "[]"))
+
+ def test_eq_wrong_type(self):
+ assert Range(10, 20) != ()
+
+ # as the postgres docs describe for the server-side stuff,
+ # ordering is rather arbitrary, but will remain stable
+ # and consistent.
+
+ def test_lt_ordering(self):
+ assert Range(empty=True) < Range(0, 4)
+ assert not Range(1, 2) < Range(0, 4)
+ assert Range(0, 4) < Range(1, 2)
+ assert not Range(1, 2) < Range()
+ assert Range() < Range(1, 2)
+ assert not Range(1) < Range(upper=1)
+ assert not Range() < Range()
+ assert not Range(empty=True) < Range(empty=True)
+ assert not Range(1, 2) < Range(1, 2)
+ with pytest.raises(TypeError):
+ assert 1 < Range(1, 2)
+ with pytest.raises(TypeError):
+ assert not Range(1, 2) < 1
+
+ def test_gt_ordering(self):
+ assert not Range(empty=True) > Range(0, 4)
+ assert Range(1, 2) > Range(0, 4)
+ assert not Range(0, 4) > Range(1, 2)
+ assert Range(1, 2) > Range()
+ assert not Range() > Range(1, 2)
+ assert Range(1) > Range(upper=1)
+ assert not Range() > Range()
+ assert not Range(empty=True) > Range(empty=True)
+ assert not Range(1, 2) > Range(1, 2)
+ with pytest.raises(TypeError):
+ assert not 1 > Range(1, 2)
+ with pytest.raises(TypeError):
+ assert Range(1, 2) > 1
+
+ def test_le_ordering(self):
+ assert Range(empty=True) <= Range(0, 4)
+ assert not Range(1, 2) <= Range(0, 4)
+ assert Range(0, 4) <= Range(1, 2)
+ assert not Range(1, 2) <= Range()
+ assert Range() <= Range(1, 2)
+ assert not Range(1) <= Range(upper=1)
+ assert Range() <= Range()
+ assert Range(empty=True) <= Range(empty=True)
+ assert Range(1, 2) <= Range(1, 2)
+ with pytest.raises(TypeError):
+ assert 1 <= Range(1, 2)
+ with pytest.raises(TypeError):
+ assert not Range(1, 2) <= 1
+
+ def test_ge_ordering(self):
+ assert not Range(empty=True) >= Range(0, 4)
+ assert Range(1, 2) >= Range(0, 4)
+ assert not Range(0, 4) >= Range(1, 2)
+ assert Range(1, 2) >= Range()
+ assert not Range() >= Range(1, 2)
+ assert Range(1) >= Range(upper=1)
+ assert Range() >= Range()
+ assert Range(empty=True) >= Range(empty=True)
+ assert Range(1, 2) >= Range(1, 2)
+ with pytest.raises(TypeError):
+ assert not 1 >= Range(1, 2)
+ with pytest.raises(TypeError):
+ (Range(1, 2) >= 1)
+
+ def test_pickling(self):
+ r = Range(0, 4)
+ assert pickle.loads(pickle.dumps(r)) == r
+
+ def test_str(self):
+ """
+ Range types should have a short and readable ``str`` implementation.
+ """
+ expected = [
+ "(0, 4)",
+ "[0, 4]",
+ "(0, 4]",
+ "[0, 4)",
+ "empty",
+ ]
+ results = []
+
+ for bounds in ("()", "[]", "(]", "[)"):
+ r = Range(0, 4, bounds=bounds)
+ results.append(str(r))
+
+ r = Range(empty=True)
+ results.append(str(r))
+ assert results == expected
+
+ def test_str_datetime(self):
+ """
+ Date-Time ranges should return a human-readable string as well on
+ string conversion.
+ """
+ tz = dt.timezone(dt.timedelta(hours=-5))
+ r = Range(
+ dt.datetime(2010, 1, 1, tzinfo=tz),
+ dt.datetime(2011, 1, 1, tzinfo=tz),
+ )
+ expected = "[2010-01-01 00:00:00-05:00, 2011-01-01 00:00:00-05:00)"
+ result = str(r)
+ assert result == expected
+
+ def test_exclude_inf_bounds(self):
+ r = Range(None, 10, "[]")
+ assert r.lower is None
+ assert not r.lower_inc
+ assert r.bounds == "(]"
+
+ r = Range(10, None, "[]")
+ assert r.upper is None
+ assert not r.upper_inc
+ assert r.bounds == "[)"
+
+ r = Range(None, None, "[]")
+ assert r.lower is None
+ assert not r.lower_inc
+ assert r.upper is None
+ assert not r.upper_inc
+ assert r.bounds == "()"
+
+
+def test_no_info_error(conn):
+ with pytest.raises(TypeError, match="range"):
+ register_range(None, conn) # type: ignore[arg-type]
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"])
+def test_literal_invalid_name(conn, name):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(f'create type "{name}" as range (subtype = text)')
+ info = RangeInfo.fetch(conn, f'"{name}"')
+ register_range(info, conn)
+ obj = Range("a", "z", "[]")
+ assert sql.Literal(obj).as_string(conn) == f"'[a,z]'::\"{name}\""
+ cur = conn.execute(sql.SQL("select {}").format(obj))
+ assert cur.fetchone()[0] == obj
diff --git a/tests/types/test_shapely.py b/tests/types/test_shapely.py
new file mode 100644
index 0000000..0f7007e
--- /dev/null
+++ b/tests/types/test_shapely.py
@@ -0,0 +1,152 @@
+import pytest
+
+import psycopg
+from psycopg.pq import Format
+from psycopg.types import TypeInfo
+from psycopg.adapt import PyFormat
+
+pytest.importorskip("shapely")
+
+from shapely.geometry import Point, Polygon, MultiPolygon # noqa: E402
+from psycopg.types.shapely import register_shapely # noqa: E402
+
+pytestmark = [
+ pytest.mark.postgis,
+ pytest.mark.crdb("skip"),
+]
+
+# real example, with CRS and "holes"
+MULTIPOLYGON_GEOJSON = """
+{
+ "type":"MultiPolygon",
+ "crs":{
+ "type":"name",
+ "properties":{
+ "name":"EPSG:3857"
+ }
+ },
+ "coordinates":[
+ [
+ [
+ [89574.61111389, 6894228.638802719],
+ [89576.815239808, 6894208.60747024],
+ [89576.904295401, 6894207.820852726],
+ [89577.99522641, 6894208.022080451],
+ [89577.961830563, 6894209.229446936],
+ [89589.227363031, 6894210.601454523],
+ [89594.615226386, 6894161.849595264],
+ [89600.314784314, 6894111.37846976],
+ [89651.187791607, 6894116.774968589],
+ [89648.49385993, 6894140.226914071],
+ [89642.92788539, 6894193.423936413],
+ [89639.721884055, 6894224.08372821],
+ [89589.283022777, 6894218.431048969],
+ [89588.192091767, 6894230.248628867],
+ [89574.61111389, 6894228.638802719]
+ ],
+ [
+ [89610.344670435, 6894182.466199101],
+ [89625.985058891, 6894184.258949757],
+ [89629.547282597, 6894153.270030369],
+ [89613.918026089, 6894151.458993318],
+ [89610.344670435, 6894182.466199101]
+ ]
+ ]
+ ]
+}"""
+
+SAMPLE_POINT_GEOJSON = '{"type":"Point","coordinates":[1.2, 3.4]}'
+
+
+@pytest.fixture
+def shapely_conn(conn, svcconn):
+ try:
+ with svcconn.transaction():
+ svcconn.execute("create extension if not exists postgis")
+ except psycopg.Error as e:
+ pytest.skip(f"can't create extension postgis: {e}")
+
+ info = TypeInfo.fetch(conn, "geometry")
+ assert info
+ register_shapely(info, conn)
+ return conn
+
+
+def test_no_adapter(conn):
+ point = Point(1.2, 3.4)
+ with pytest.raises(psycopg.ProgrammingError, match="cannot adapt type 'Point'"):
+ conn.execute("SELECT pg_typeof(%s)", [point]).fetchone()[0]
+
+
+def test_no_info_error(conn):
+ from psycopg.types.shapely import register_shapely
+
+ with pytest.raises(TypeError, match="postgis.*extension"):
+ register_shapely(None, conn) # type: ignore[arg-type]
+
+
+def test_with_adapter(shapely_conn):
+ SAMPLE_POINT = Point(1.2, 3.4)
+ SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)])
+
+ assert (
+ shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POINT]).fetchone()[0]
+ == "geometry"
+ )
+
+ assert (
+ shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POLYGON]).fetchone()[0]
+ == "geometry"
+ )
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", Format)
+def test_write_read_shape(shapely_conn, fmt_in, fmt_out):
+ SAMPLE_POINT = Point(1.2, 3.4)
+ SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)])
+
+ with shapely_conn.cursor(binary=fmt_out) as cur:
+ cur.execute(
+ """
+ create table sample_geoms(
+ id INTEGER PRIMARY KEY,
+ geom geometry
+ )
+ """
+ )
+ cur.execute(
+ f"insert into sample_geoms(id, geom) VALUES(1, %{fmt_in})",
+ (SAMPLE_POINT,),
+ )
+ cur.execute(
+ f"insert into sample_geoms(id, geom) VALUES(2, %{fmt_in})",
+ (SAMPLE_POLYGON,),
+ )
+
+ cur.execute("select geom from sample_geoms where id=1")
+ result = cur.fetchone()[0]
+ assert result == SAMPLE_POINT
+
+ cur.execute("select geom from sample_geoms where id=2")
+ result = cur.fetchone()[0]
+ assert result == SAMPLE_POLYGON
+
+
+@pytest.mark.parametrize("fmt_out", Format)
+def test_match_geojson(shapely_conn, fmt_out):
+ SAMPLE_POINT = Point(1.2, 3.4)
+ with shapely_conn.cursor(binary=fmt_out) as cur:
+ cur.execute(
+ """
+ select ST_GeomFromGeoJSON(%s)
+ """,
+ (SAMPLE_POINT_GEOJSON,),
+ )
+ result = cur.fetchone()[0]
+ # clone the coordinates to have a list instead of a shapely wrapper
+ assert result.coords[:] == SAMPLE_POINT.coords[:]
+ #
+ cur.execute("select ST_GeomFromGeoJSON(%s)", (MULTIPOLYGON_GEOJSON,))
+ result = cur.fetchone()[0]
+ assert isinstance(result, MultiPolygon)
diff --git a/tests/types/test_string.py b/tests/types/test_string.py
new file mode 100644
index 0000000..d23e5e0
--- /dev/null
+++ b/tests/types/test_string.py
@@ -0,0 +1,307 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg import errors as e
+from psycopg.adapt import PyFormat
+from psycopg import Binary
+
+from ..utils import eur
+from ..fix_crdb import crdb_encoding, crdb_scs_off
+
+#
+# tests with text
+#
+
+
+def crdb_bpchar(*args):
+ return pytest.param(*args, marks=pytest.mark.crdb("skip", reason="bpchar"))
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_1char(conn, fmt_in):
+ cur = conn.cursor()
+ for i in range(1, 256):
+ cur.execute(f"select %{fmt_in.value} = chr(%s)", (chr(i), i))
+ assert cur.fetchone()[0] is True, chr(i)
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_quote_1char(conn, scs):
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute(f"set standard_conforming_strings to {scs}")
+ conn.execute("set escape_string_warning to on")
+
+ cur = conn.cursor()
+ query = sql.SQL("select {ch} = chr(%s)")
+ for i in range(1, 256):
+ if chr(i) == "%":
+ continue
+ cur.execute(query.format(ch=sql.Literal(chr(i))), (i,))
+ assert cur.fetchone()[0] is True, chr(i)
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages
+
+
+@pytest.mark.crdb("skip", reason="can deal with 0 strings")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_zero(conn, fmt_in):
+ cur = conn.cursor()
+ s = "foo\x00bar"
+ with pytest.raises(psycopg.DataError):
+ cur.execute(f"select %{fmt_in.value}::text", (s,))
+
+
+def test_quote_zero(conn):
+ cur = conn.cursor()
+ s = "foo\x00bar"
+ with pytest.raises(psycopg.DataError):
+ cur.execute(sql.SQL("select {}").format(sql.Literal(s)))
+
+
+# the only way to make this pass is to reduce %% -> % every time
+# not only when there are query arguments
+# see https://github.com/psycopg/psycopg2/issues/825
+@pytest.mark.xfail
+def test_quote_percent(conn):
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {ch}").format(ch=sql.Literal("%")))
+ assert cur.fetchone()[0] == "%"
+
+ cur.execute(
+ sql.SQL("select {ch} = chr(%s)").format(ch=sql.Literal("%")),
+ (ord("%"),),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "typename", ["text", "varchar", "name", crdb_bpchar("bpchar"), '"char"']
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_1char(conn, typename, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 256):
+ if typename == '"char"' and i > 127:
+ # for char > 128 the client receives only 194 or 195.
+ continue
+
+ cur.execute(f"select chr(%s)::{typename}", (i,))
+ res = cur.fetchone()[0]
+ assert res == chr(i)
+
+ assert cur.pgresult.fformat(0) == fmt_out
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize(
+ "encoding", ["utf8", crdb_encoding("latin9"), crdb_encoding("sql_ascii")]
+)
+def test_dump_enc(conn, fmt_in, encoding):
+ cur = conn.cursor()
+
+ conn.execute(f"set client_encoding to {encoding}")
+ (res,) = cur.execute(f"select ascii(%{fmt_in.value})", (eur,)).fetchone()
+ assert res == ord(eur)
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_badenc(conn, fmt_in):
+ cur = conn.cursor()
+
+ conn.execute("set client_encoding to latin1")
+ with pytest.raises(UnicodeEncodeError):
+ cur.execute(f"select %{fmt_in.value}::bytea", (eur,))
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_utf8_badenc(conn, fmt_in):
+ cur = conn.cursor()
+
+ conn.execute("set client_encoding to utf8")
+ with pytest.raises(UnicodeEncodeError):
+ cur.execute(f"select %{fmt_in.value}", ("\uddf8",))
+
+
+@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT])
+def test_dump_enum(conn, fmt_in):
+ from enum import Enum
+
+ class MyEnum(str, Enum):
+ foo = "foo"
+ bar = "bar"
+
+ cur = conn.cursor()
+ cur.execute("create type myenum as enum ('foo', 'bar')")
+ cur.execute("create table with_enum (e myenum)")
+ cur.execute(f"insert into with_enum (e) values (%{fmt_in.value})", (MyEnum.foo,))
+ (res,) = cur.execute("select e from with_enum").fetchone()
+ assert res == "foo"
+
+
+@pytest.mark.crdb("skip")
+@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT])
+def test_dump_text_oid(conn, fmt_in):
+ conn.autocommit = True
+
+ with pytest.raises(e.IndeterminateDatatype):
+ conn.execute(f"select concat(%{fmt_in.value}, %{fmt_in.value})", ["foo", "bar"])
+ conn.adapters.register_dumper(str, psycopg.types.string.StrDumper)
+ cur = conn.execute(
+ f"select concat(%{fmt_in.value}, %{fmt_in.value})", ["foo", "bar"]
+ )
+ assert cur.fetchone()[0] == "foobar"
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_load_enc(conn, typename, encoding, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+
+ conn.execute(f"set client_encoding to {encoding}")
+ (res,) = cur.execute(f"select chr(%s)::{typename}", [ord(eur)]).fetchone()
+ assert res == eur
+
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ (res,) = copy.read_row()
+
+ assert res == eur
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_load_badenc(conn, typename, fmt_out):
+ conn.autocommit = True
+ cur = conn.cursor(binary=fmt_out)
+
+ conn.execute("set client_encoding to latin1")
+ with pytest.raises(psycopg.DataError):
+ cur.execute(f"select chr(%s)::{typename}", [ord(eur)])
+
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ with pytest.raises(psycopg.DataError):
+ copy.read_row()
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_load_ascii(conn, typename, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+
+ conn.execute("set client_encoding to sql_ascii")
+ cur.execute(f"select chr(%s)::{typename}", [ord(eur)])
+ assert cur.fetchone()[0] == eur.encode()
+
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ (res,) = copy.read_row()
+
+ assert res == eur.encode()
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", crdb_bpchar("bpchar")])
+def test_text_array(conn, typename, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ a = list(map(chr, range(1, 256))) + [eur]
+
+ (res,) = cur.execute(f"select %{fmt_in.value}::{typename}[]", (a,)).fetchone()
+ assert res == a
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_text_array_ascii(conn, fmt_in, fmt_out):
+ conn.execute("set client_encoding to sql_ascii")
+ cur = conn.cursor(binary=fmt_out)
+ a = list(map(chr, range(1, 256))) + [eur]
+ exp = [s.encode() for s in a]
+ (res,) = cur.execute(f"select %{fmt_in.value}::text[]", (a,)).fetchone()
+ assert res == exp
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name"])
+def test_oid_lookup(conn, typename, fmt_out):
+ dumper = conn.adapters.get_dumper_by_oid(conn.adapters.types[typename].oid, fmt_out)
+ assert dumper.oid == conn.adapters.types[typename].oid
+ assert dumper.format == fmt_out
+
+
+#
+# tests with bytea
+#
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary])
+def test_dump_1byte(conn, fmt_in, pytype):
+ cur = conn.cursor()
+ for i in range(0, 256):
+ obj = pytype(bytes([i]))
+ cur.execute(f"select %{fmt_in.value} = set_byte('x', 0, %s)", (obj, i))
+ assert cur.fetchone()[0] is True, i
+
+ cur.execute(f"select %{fmt_in.value} = array[set_byte('x', 0, %s)]", ([obj], i))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary])
+def test_quote_1byte(conn, scs, pytype):
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute(f"set standard_conforming_strings to {scs}")
+ conn.execute("set escape_string_warning to on")
+
+ cur = conn.cursor()
+ query = sql.SQL("select {ch} = set_byte('x', 0, %s)")
+ for i in range(0, 256):
+ obj = pytype(bytes([i]))
+ cur.execute(query.format(ch=sql.Literal(obj)), (i,))
+ assert cur.fetchone()[0] is True, i
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_1byte(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(0, 256):
+ cur.execute("select set_byte('x', 0, %s)", (i,))
+ val = cur.fetchone()[0]
+ assert val == bytes([i])
+
+ assert isinstance(val, bytes)
+ assert cur.pgresult.fformat(0) == fmt_out
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_bytea_array(conn, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ a = [bytes(range(0, 256))]
+ (res,) = cur.execute(f"select %{fmt_in.value}::bytea[]", (a,)).fetchone()
+ assert res == a
diff --git a/tests/types/test_uuid.py b/tests/types/test_uuid.py
new file mode 100644
index 0000000..f86f066
--- /dev/null
+++ b/tests/types/test_uuid.py
@@ -0,0 +1,56 @@
+import sys
+from uuid import UUID
+import subprocess as sp
+
+import pytest
+
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import PyFormat
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_uuid_dump(conn, fmt_in):
+ val = "12345678123456781234567812345679"
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value} = %s::uuid", (UUID(val), val))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_uuid_load(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ val = "12345678123456781234567812345679"
+ cur.execute("select %s::uuid", (val,))
+ assert cur.fetchone()[0] == UUID(val)
+
+ stmt = sql.SQL("copy (select {}::uuid) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["uuid"])
+ (res,) = copy.read_row()
+
+ assert res == UUID(val)
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+def test_lazy_load(dsn):
+ script = f"""\
+import sys
+import psycopg
+
+assert 'uuid' not in sys.modules
+
+conn = psycopg.connect({dsn!r})
+with conn.cursor() as cur:
+ cur.execute("select repeat('1', 32)::uuid")
+ cur.fetchone()
+
+conn.close()
+assert 'uuid' in sys.modules
+"""
+
+ sp.check_call([sys.executable, "-c", script])
diff --git a/tests/typing_example.py b/tests/typing_example.py
new file mode 100644
index 0000000..a26ca49
--- /dev/null
+++ b/tests/typing_example.py
@@ -0,0 +1,176 @@
+# flake8: builtins=reveal_type
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+
+from psycopg import Connection, Cursor, ServerCursor, connect, rows
+from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor
+
+
+def int_row_factory(
+ cursor: Union[Cursor[Any], AsyncCursor[Any]]
+) -> Callable[[Sequence[int]], int]:
+ return lambda values: values[0] if values else 42
+
+
+@dataclass
+class Person:
+ name: str
+ address: str
+
+ @classmethod
+ def row_factory(
+ cls, cursor: Union[Cursor[Any], AsyncCursor[Any]]
+ ) -> Callable[[Sequence[str]], Person]:
+ def mkrow(values: Sequence[str]) -> Person:
+ name, address = values
+ return cls(name, address)
+
+ return mkrow
+
+
+def kwargsf(*, foo: int, bar: int, baz: int) -> int:
+ return 42
+
+
+def argsf(foo: int, bar: int, baz: int) -> float:
+ return 42.0
+
+
+def check_row_factory_cursor() -> None:
+ """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case."""
+ conn = connect()
+
+ cur1: Cursor[Any]
+ cur1 = conn.cursor()
+ r1: Optional[Any]
+ r1 = cur1.fetchone()
+ r1 is not None
+
+ cur2: Cursor[int]
+ r2: Optional[int]
+ with conn.cursor(row_factory=int_row_factory) as cur2:
+ cur2.execute("select 1")
+ r2 = cur2.fetchone()
+ r2 and r2 > 0
+
+ cur3: ServerCursor[Person]
+ persons: Sequence[Person]
+ with conn.cursor(name="s", row_factory=Person.row_factory) as cur3:
+ cur3.execute("select * from persons where name like 'al%'")
+ persons = cur3.fetchall()
+ persons[0].address
+
+
+async def async_check_row_factory_cursor() -> None:
+ """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case."""
+ conn = await AsyncConnection.connect()
+
+ cur1: AsyncCursor[Any]
+ cur1 = conn.cursor()
+ r1: Optional[Any]
+ r1 = await cur1.fetchone()
+ r1 is not None
+
+ cur2: AsyncCursor[int]
+ r2: Optional[int]
+ async with conn.cursor(row_factory=int_row_factory) as cur2:
+ await cur2.execute("select 1")
+ r2 = await cur2.fetchone()
+ r2 and r2 > 0
+
+ cur3: AsyncServerCursor[Person]
+ persons: Sequence[Person]
+ async with conn.cursor(name="s", row_factory=Person.row_factory) as cur3:
+ await cur3.execute("select * from persons where name like 'al%'")
+ persons = await cur3.fetchall()
+ persons[0].address
+
+
+def check_row_factory_connection() -> None:
+ """Type-check connect(..., row_factory=<MyRowFactory>) or
+ Connection.row_factory cases.
+ """
+ conn1: Connection[int]
+ cur1: Cursor[int]
+ r1: Optional[int]
+ conn1 = connect(row_factory=int_row_factory)
+ cur1 = conn1.execute("select 1")
+ r1 = cur1.fetchone()
+ r1 != 0
+ with conn1.cursor() as cur1:
+ cur1.execute("select 2")
+
+ conn2: Connection[Person]
+ cur2: Cursor[Person]
+ r2: Optional[Person]
+ conn2 = connect(row_factory=Person.row_factory)
+ cur2 = conn2.execute("select * from persons")
+ r2 = cur2.fetchone()
+ r2 and r2.name
+ with conn2.cursor() as cur2:
+ cur2.execute("select 2")
+
+ cur3: Cursor[Tuple[Any, ...]]
+ r3: Optional[Tuple[Any, ...]]
+ conn3 = connect()
+ cur3 = conn3.execute("select 3")
+ with conn3.cursor() as cur3:
+ cur3.execute("select 42")
+ r3 = cur3.fetchone()
+ r3 and len(r3)
+
+
+async def async_check_row_factory_connection() -> None:
+ """Type-check connect(..., row_factory=<MyRowFactory>) or
+ Connection.row_factory cases.
+ """
+ conn1: AsyncConnection[int]
+ cur1: AsyncCursor[int]
+ r1: Optional[int]
+ conn1 = await AsyncConnection.connect(row_factory=int_row_factory)
+ cur1 = await conn1.execute("select 1")
+ r1 = await cur1.fetchone()
+ r1 != 0
+ async with conn1.cursor() as cur1:
+ await cur1.execute("select 2")
+
+ conn2: AsyncConnection[Person]
+ cur2: AsyncCursor[Person]
+ r2: Optional[Person]
+ conn2 = await AsyncConnection.connect(row_factory=Person.row_factory)
+ cur2 = await conn2.execute("select * from persons")
+ r2 = await cur2.fetchone()
+ r2 and r2.name
+ async with conn2.cursor() as cur2:
+ await cur2.execute("select 2")
+
+ cur3: AsyncCursor[Tuple[Any, ...]]
+ r3: Optional[Tuple[Any, ...]]
+ conn3 = await AsyncConnection.connect()
+ cur3 = await conn3.execute("select 3")
+ async with conn3.cursor() as cur3:
+ await cur3.execute("select 42")
+ r3 = await cur3.fetchone()
+ r3 and len(r3)
+
+
+def check_row_factories() -> None:
+ conn1 = connect(row_factory=rows.tuple_row)
+ v1: Tuple[Any, ...] = conn1.execute("").fetchall()[0]
+
+ conn2 = connect(row_factory=rows.dict_row)
+ v2: Dict[str, Any] = conn2.execute("").fetchall()[0]
+
+ conn3 = connect(row_factory=rows.class_row(Person))
+ v3: Person = conn3.execute("").fetchall()[0]
+
+ conn4 = connect(row_factory=rows.args_row(argsf))
+ v4: float = conn4.execute("").fetchall()[0]
+
+ conn5 = connect(row_factory=rows.kwargs_row(kwargsf))
+ v5: int = conn5.execute("").fetchall()[0]
+
+ v1, v2, v3, v4, v5
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..871f65d
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,179 @@
+import gc
+import re
+import sys
+import operator
+from typing import Callable, Optional, Tuple
+
+import pytest
+
+eur = "\u20ac"
+
+
+def check_libpq_version(got, want):
+ """
+ Verify if the libpq version is a version accepted.
+
+ This function is called on the tests marked with something like::
+
+ @pytest.mark.libpq(">= 12")
+
+ and skips the test if the requested version doesn't match what's loaded.
+ """
+ return check_version(got, want, "libpq", postgres_rule=True)
+
+
+def check_postgres_version(got, want):
+ """
+ Verify if the server version is a version accepted.
+
+ This function is called on the tests marked with something like::
+
+ @pytest.mark.pg(">= 12")
+
+ and skips the test if the server version doesn't match what expected.
+ """
+ return check_version(got, want, "PostgreSQL", postgres_rule=True)
+
+
+def check_version(got, want, whose_version, postgres_rule=True):
+ pred = VersionCheck.parse(want, postgres_rule=postgres_rule)
+ pred.whose = whose_version
+ return pred.get_skip_message(got)
+
+
+class VersionCheck:
+ """
+ Helper to compare a version number with a test spec.
+ """
+
+ def __init__(
+ self,
+ *,
+ skip: bool = False,
+ op: Optional[str] = None,
+ version_tuple: Tuple[int, ...] = (),
+ whose: str = "(wanted)",
+ postgres_rule: bool = False,
+ ):
+ self.skip = skip
+ self.op = op or "=="
+ self.version_tuple = version_tuple
+ self.whose = whose
+ # Treat 10.1 as 10.0.1
+ self.postgres_rule = postgres_rule
+
+ @classmethod
+ def parse(cls, spec: str, *, postgres_rule: bool = False) -> "VersionCheck":
+ # Parse a spec like "> 9.6", "skip < 21.2.0"
+ m = re.match(
+ r"""(?ix)
+ ^\s* (skip|only)?
+ \s* (==|!=|>=|<=|>|<)?
+ \s* (?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?
+ \s* $
+ """,
+ spec,
+ )
+ if m is None:
+ pytest.fail(f"bad wanted version spec: {spec}")
+
+ skip = (m.group(1) or "only").lower() == "skip"
+ op = m.group(2)
+ version_tuple = tuple(int(n) for n in m.groups()[2:] if n)
+
+ return cls(
+ skip=skip, op=op, version_tuple=version_tuple, postgres_rule=postgres_rule
+ )
+
+ def get_skip_message(self, version: Optional[int]) -> Optional[str]:
+ got_tuple = self._parse_int_version(version)
+
+ msg: Optional[str] = None
+ if self.skip:
+ if got_tuple:
+ if not self.version_tuple:
+ msg = f"skip on {self.whose}"
+ elif self._match_version(got_tuple):
+ msg = (
+ f"skip on {self.whose} {self.op}"
+ f" {'.'.join(map(str, self.version_tuple))}"
+ )
+ else:
+ if not got_tuple:
+ msg = f"only for {self.whose}"
+ elif not self._match_version(got_tuple):
+ if self.version_tuple:
+ msg = (
+ f"only for {self.whose} {self.op}"
+ f" {'.'.join(map(str, self.version_tuple))}"
+ )
+ else:
+ msg = f"only for {self.whose}"
+
+ return msg
+
+ _OP_NAMES = {">=": "ge", "<=": "le", ">": "gt", "<": "lt", "==": "eq", "!=": "ne"}
+
+ def _match_version(self, got_tuple: Tuple[int, ...]) -> bool:
+ if not self.version_tuple:
+ return True
+
+ version_tuple = self.version_tuple
+ if self.postgres_rule and version_tuple and version_tuple[0] >= 10:
+ assert len(version_tuple) <= 2
+ version_tuple = version_tuple[:1] + (0,) + version_tuple[1:]
+
+ op: Callable[[Tuple[int, ...], Tuple[int, ...]], bool]
+ op = getattr(operator, self._OP_NAMES[self.op])
+ return op(got_tuple, version_tuple)
+
+ def _parse_int_version(self, version: Optional[int]) -> Tuple[int, ...]:
+ if version is None:
+ return ()
+ version, ver_fix = divmod(version, 100)
+ ver_maj, ver_min = divmod(version, 100)
+ return (ver_maj, ver_min, ver_fix)
+
+
+def gc_collect():
+ """
+ gc.collect(), but more insisting.
+ """
+ for i in range(3):
+ gc.collect()
+
+
+NO_COUNT_TYPES: Tuple[type, ...] = ()
+
+if sys.version_info[:2] == (3, 10):
+ # On my laptop there are occasional creations of a single one of these objects
+ # with empty content, which might be some Decimal caching.
+ # Keeping the guard as strict as possible, to be extended if other types
+ # or versions are necessary.
+ try:
+ from _contextvars import Context # type: ignore
+ except ImportError:
+ pass
+ else:
+ NO_COUNT_TYPES += (Context,)
+
+
+def gc_count() -> int:
+ """
+ len(gc.get_objects()), with subtleties.
+ """
+ if not NO_COUNT_TYPES:
+ return len(gc.get_objects())
+
+ # Note: not using a list comprehension because it pollutes the objects list.
+ rv = 0
+ for obj in gc.get_objects():
+ if isinstance(obj, NO_COUNT_TYPES):
+ continue
+ rv += 1
+
+ return rv
+
+
+async def alist(it):
+ return [i async for i in it]
diff --git a/tools/build/build_libpq.sh b/tools/build/build_libpq.sh
new file mode 100755
index 0000000..4cc79af
--- /dev/null
+++ b/tools/build/build_libpq.sh
@@ -0,0 +1,173 @@
+#!/bin/bash
+
+# Build a modern version of libpq and depending libs from source on Centos 5
+
+set -euo pipefail
+set -x
+
+# Last release: https://www.postgresql.org/ftp/source/
+# IMPORTANT! Change the cache key in packages.yml when upgrading libraries
+postgres_version="${LIBPQ_VERSION:-15.0}"
+
+# last release: https://www.openssl.org/source/
+openssl_version="${OPENSSL_VERSION:-1.1.1r}"
+
+# last release: https://openldap.org/software/download/
+ldap_version="2.6.3"
+
+# last release: https://github.com/cyrusimap/cyrus-sasl/releases
+sasl_version="2.1.28"
+
+export LIBPQ_BUILD_PREFIX=${LIBPQ_BUILD_PREFIX:-/tmp/libpq.build}
+
+if [[ -f "${LIBPQ_BUILD_PREFIX}/lib/libpq.so" ]]; then
+ echo "libpq already available: build skipped" >&2
+ exit 0
+fi
+
+source /etc/os-release
+
+case "$ID" in
+ centos)
+ yum update -y
+ yum install -y zlib-devel krb5-devel pam-devel
+ ;;
+
+ alpine)
+ apk upgrade
+ apk add --no-cache zlib-dev krb5-dev linux-pam-dev openldap-dev
+ ;;
+
+ *)
+ echo "$0: unexpected Linux distribution: '$ID'" >&2
+ exit 1
+ ;;
+esac
+
+if [ "$ID" == "centos" ]; then
+
+ # Build openssl if needed
+ openssl_tag="OpenSSL_${openssl_version//./_}"
+ openssl_dir="openssl-${openssl_tag}"
+ if [ ! -d "${openssl_dir}" ]; then curl -sL \
+ https://github.com/openssl/openssl/archive/${openssl_tag}.tar.gz \
+ | tar xzf -
+
+ cd "${openssl_dir}"
+
+ ./config --prefix=${LIBPQ_BUILD_PREFIX} --openssldir=${LIBPQ_BUILD_PREFIX} \
+ zlib -fPIC shared
+ make depend
+ make
+ else
+ cd "${openssl_dir}"
+ fi
+
+ # Install openssl
+ make install_sw
+ cd ..
+
+fi
+
+
+if [ "$ID" == "centos" ]; then
+
+ # Build libsasl2 if needed
+ # The system package (cyrus-sasl-devel) causes an amazing error on i686:
+ # "unsupported version 0 of Verneed record"
+ # https://github.com/pypa/manylinux/issues/376
+ sasl_tag="cyrus-sasl-${sasl_version}"
+ sasl_dir="cyrus-sasl-${sasl_tag}"
+ if [ ! -d "${sasl_dir}" ]; then
+ curl -sL \
+ https://github.com/cyrusimap/cyrus-sasl/archive/${sasl_tag}.tar.gz \
+ | tar xzf -
+
+ cd "${sasl_dir}"
+
+ autoreconf -i
+ ./configure --prefix=${LIBPQ_BUILD_PREFIX} \
+ CPPFLAGS=-I${LIBPQ_BUILD_PREFIX}/include/ LDFLAGS=-L${LIBPQ_BUILD_PREFIX}/lib
+ make
+ else
+ cd "${sasl_dir}"
+ fi
+
+ # Install libsasl2
+ # requires missing nroff to build
+ touch saslauthd/saslauthd.8
+ make install
+ cd ..
+
+fi
+
+
+if [ "$ID" == "centos" ]; then
+
+ # Build openldap if needed
+ ldap_tag="${ldap_version}"
+ ldap_dir="openldap-${ldap_tag}"
+ if [ ! -d "${ldap_dir}" ]; then
+ curl -sL \
+ https://www.openldap.org/software/download/OpenLDAP/openldap-release/openldap-${ldap_tag}.tgz \
+ | tar xzf -
+
+ cd "${ldap_dir}"
+
+ ./configure --prefix=${LIBPQ_BUILD_PREFIX} --enable-backends=no --enable-null \
+ CPPFLAGS=-I${LIBPQ_BUILD_PREFIX}/include/ LDFLAGS=-L${LIBPQ_BUILD_PREFIX}/lib
+
+ make depend
+ make -C libraries/liblutil/
+ make -C libraries/liblber/
+ make -C libraries/libldap/
+ else
+ cd "${ldap_dir}"
+ fi
+
+ # Install openldap
+ make -C libraries/liblber/ install
+ make -C libraries/libldap/ install
+ make -C include/ install
+ chmod +x ${LIBPQ_BUILD_PREFIX}/lib/{libldap,liblber}*.so*
+ cd ..
+
+fi
+
+
+# Build libpq if needed
+postgres_tag="REL_${postgres_version//./_}"
+postgres_dir="postgres-${postgres_tag}"
+if [ ! -d "${postgres_dir}" ]; then
+ curl -sL \
+ https://github.com/postgres/postgres/archive/${postgres_tag}.tar.gz \
+ | tar xzf -
+
+ cd "${postgres_dir}"
+
+ # Match the default unix socket dir default with what defined on Ubuntu and
+ # Red Hat, which seems the most common location
+ sed -i 's|#define DEFAULT_PGSOCKET_DIR .*'\
+'|#define DEFAULT_PGSOCKET_DIR "/var/run/postgresql"|' \
+ src/include/pg_config_manual.h
+
+ # Often needed, but currently set by the workflow
+ # export LD_LIBRARY_PATH="${LIBPQ_BUILD_PREFIX}/lib"
+
+ ./configure --prefix=${LIBPQ_BUILD_PREFIX} --sysconfdir=/etc/postgresql-common \
+ --without-readline --with-gssapi --with-openssl --with-pam --with-ldap \
+ CPPFLAGS=-I${LIBPQ_BUILD_PREFIX}/include/ LDFLAGS=-L${LIBPQ_BUILD_PREFIX}/lib
+ make -C src/interfaces/libpq
+ make -C src/bin/pg_config
+ make -C src/include
+else
+ cd "${postgres_dir}"
+fi
+
+# Install libpq
+make -C src/interfaces/libpq install
+make -C src/bin/pg_config install
+make -C src/include install
+cd ..
+
+find ${LIBPQ_BUILD_PREFIX} -name \*.so.\* -type f -exec strip --strip-unneeded {} \;
diff --git a/tools/build/build_macos_arm64.sh b/tools/build/build_macos_arm64.sh
new file mode 100755
index 0000000..f8c2fd7
--- /dev/null
+++ b/tools/build/build_macos_arm64.sh
@@ -0,0 +1,93 @@
+#!/bin/bash
+
+# Build psycopg-binary wheel packages for Apple M1 (cpNNN-macosx_arm64)
+#
+# This script is designed to run on Scaleway Apple Silicon machines.
+#
+# The script cannot be run as sudo (installing brew fails), but requires sudo,
+# so it can pretty much only be executed by a sudo user as it is.
+
+set -euo pipefail
+set -x
+
+python_versions="3.8.10 3.9.13 3.10.5 3.11.0"
+pg_version=15
+
+# Move to the root of the project
+dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+cd "${dir}/../../"
+
+# Add /usr/local/bin to the path. It seems it's not, in non-interactive sessions
+if ! (echo $PATH | grep -q '/usr/local/bin'); then
+ export PATH=/usr/local/bin:$PATH
+fi
+
+# Install brew, if necessary. Otherwise just make sure it's in the path
+if [[ -x /opt/homebrew/bin/brew ]]; then
+ eval "$(/opt/homebrew/bin/brew shellenv)"
+else
+ command -v brew > /dev/null || (
+ # Not necessary: already installed
+ # xcode-select --install
+ NONINTERACTIVE=1 /bin/bash -c "$(curl -fsSL \
+ https://raw.githubusercontent.com/Homebrew/install/master/install.sh)"
+ )
+ eval "$(/opt/homebrew/bin/brew shellenv)"
+fi
+
+export PGDATA=/opt/homebrew/var/postgresql@${pg_version}
+
+# Install PostgreSQL, if necessary
+command -v pg_config > /dev/null || (
+ brew install postgresql@${pg_version}
+)
+
+# After PostgreSQL 15, the bin path is not in the path.
+export PATH=$(ls -d1 /opt/homebrew/Cellar/postgresql@${pg_version}/*/bin):$PATH
+
+# Make sure the server is running
+
+# Currently not working
+# brew services start postgresql@${pg_version}
+
+if ! pg_ctl status; then
+ pg_ctl -l /opt/homebrew/var/log/postgresql@${pg_version}.log start
+fi
+
+
+# Install the Python versions we want to build
+for ver3 in $python_versions; do
+ ver2=$(echo $ver3 | sed 's/\([^\.]*\)\(\.[^\.]*\)\(.*\)/\1\2/')
+ command -v python${ver2} > /dev/null || (
+ (cd /tmp &&
+ curl -fsSl -O \
+ https://www.python.org/ftp/python/${ver3}/python-${ver3}-macos11.pkg)
+ sudo installer -pkg /tmp/python-${ver3}-macos11.pkg -target /
+ )
+done
+
+# Create a virtualenv where to work
+if [[ ! -x .venv/bin/python ]]; then
+ python3 -m venv .venv
+fi
+
+source .venv/bin/activate
+pip install cibuildwheel
+
+# Create the psycopg_binary source package
+rm -rf psycopg_binary
+python tools/build/copy_to_binary.py
+
+# Build the binary packages
+export CIBW_PLATFORM=macos
+export CIBW_ARCHS=arm64
+export CIBW_BUILD='cp{38,39,310,311}-*'
+export CIBW_TEST_REQUIRES="./psycopg[test] ./psycopg_pool"
+export CIBW_TEST_COMMAND="pytest {project}/tests -m 'not slow and not flakey' --color yes"
+
+export PSYCOPG_IMPL=binary
+export PSYCOPG_TEST_DSN="dbname=postgres"
+export PSYCOPG_TEST_WANT_LIBPQ_BUILD=">= ${pg_version}"
+export PSYCOPG_TEST_WANT_LIBPQ_IMPORT=">= ${pg_version}"
+
+cibuildwheel psycopg_binary
diff --git a/tools/build/ci_test.sh b/tools/build/ci_test.sh
new file mode 100755
index 0000000..d1d2ee4
--- /dev/null
+++ b/tools/build/ci_test.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+# Run the tests in Github Action
+#
+# Failed tests run up to three times, to take into account flakey tests.
+# Of course the random generator is not re-seeded between runs, in order to
+# repeat the same result.
+
+set -euo pipefail
+set -x
+
+# Assemble a markers expression from the MARKERS and NOT_MARKERS env vars
+markers=""
+for m in ${MARKERS:-}; do
+ [[ "$markers" != "" ]] && markers="$markers and"
+ markers="$markers $m"
+done
+for m in ${NOT_MARKERS:-}; do
+ [[ "$markers" != "" ]] && markers="$markers and"
+ markers="$markers not $m"
+done
+
+pytest="python -bb -m pytest --color=yes"
+
+$pytest -m "$markers" "$@" && exit 0
+
+$pytest -m "$markers" --lf --randomly-seed=last "$@" && exit 0
+
+$pytest -m "$markers" --lf --randomly-seed=last "$@"
diff --git a/tools/build/copy_to_binary.py b/tools/build/copy_to_binary.py
new file mode 100755
index 0000000..7cab25c
--- /dev/null
+++ b/tools/build/copy_to_binary.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+
+# Create the psycopg-binary package by renaming and patching psycopg-c
+
+import os
+import re
+import shutil
+from pathlib import Path
+from typing import Union
+
+curdir = Path(__file__).parent
+pdir = curdir / "../.."
+target = pdir / "psycopg_binary"
+
+if target.exists():
+ raise Exception(f"path {target} already exists")
+
+
+def sed_i(pattern: str, repl: str, filename: Union[str, Path]) -> None:
+ with open(filename, "rb") as f:
+ data = f.read()
+ newdata = re.sub(pattern.encode("utf8"), repl.encode("utf8"), data)
+ if newdata != data:
+ with open(filename, "wb") as f:
+ f.write(newdata)
+
+
+shutil.copytree(pdir / "psycopg_c", target)
+shutil.move(str(target / "psycopg_c"), str(target / "psycopg_binary"))
+shutil.move(str(target / "README-binary.rst"), str(target / "README.rst"))
+sed_i("psycopg-c", "psycopg-binary", target / "setup.cfg")
+sed_i(
+ r"__impl__\s*=.*", '__impl__ = "binary"', target / "psycopg_binary/pq.pyx"
+)
+for dirpath, dirnames, filenames in os.walk(target):
+ for filename in filenames:
+ if os.path.splitext(filename)[1] not in (".pyx", ".pxd", ".py"):
+ continue
+ sed_i(r"\bpsycopg_c\b", "psycopg_binary", Path(dirpath) / filename)
diff --git a/tools/build/print_so_versions.sh b/tools/build/print_so_versions.sh
new file mode 100755
index 0000000..a3c4ecd
--- /dev/null
+++ b/tools/build/print_so_versions.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+# Take a .so file as input and print the Debian packages and versions of the
+# libraries it links.
+
+set -euo pipefail
+# set -x
+
+source /etc/os-release
+
+sofile="$1"
+
+case "$ID" in
+ alpine)
+ depfiles=$( (ldd "$sofile" 2>/dev/null || true) | grep '=>' | sed 's/.*=> \(.*\) (.*)/\1/')
+ (for depfile in $depfiles; do
+ echo "$(basename "$depfile") => $(apk info --who-owns "${depfile}" | awk '{print $(NF)}')"
+ done) | sort | uniq
+ ;;
+
+ debian)
+ depfiles=$(ldd "$sofile" | grep '=>' | sed 's/.*=> \(.*\) (.*)/\1/')
+ (for depfile in $depfiles; do
+ pkgname=$(dpkg -S "${depfile}" | sed 's/\(\): .*/\1/')
+ dpkg -l "${pkgname}" | grep '^ii' | awk '{print $2 " => " $3}'
+ done) | sort | uniq
+ ;;
+
+ centos)
+ echo "TODO!"
+ ;;
+
+ *)
+ echo "$0: unexpected Linux distribution: '$ID'" >&2
+ exit 1
+ ;;
+esac
diff --git a/tools/build/run_build_macos_arm64.sh b/tools/build/run_build_macos_arm64.sh
new file mode 100755
index 0000000..f5ae617
--- /dev/null
+++ b/tools/build/run_build_macos_arm64.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# Build psycopg-binary wheel packages for Apple M1 (cpNNN-macosx_arm64)
+#
+# This script is designed to run on a local machine: it will clone the repos
+# remotely and execute the `build_macos_arm64.sh` script remotely, then will
+# download the built packages. A tag to build must be specified.
+#
+# In order to run the script, the `m1` host must be specified in
+# `~/.ssh/config`; for instance:
+#
+# Host m1
+# User m1
+# HostName 1.2.3.4
+
+set -euo pipefail
+# set -x
+
+tag=${1:-}
+
+if [[ ! "${tag}" ]]; then
+ echo "Usage: $0 TAG" >&2
+ exit 2
+fi
+
+rdir=psycobuild
+
+# Clone the repos
+ssh m1 rm -rf "${rdir}"
+ssh m1 git clone https://github.com/psycopg/psycopg.git --branch ${tag} "${rdir}"
+
+# Allow sudoing without password, to allow brew to install
+ssh -t m1 bash -c \
+ 'test -f /etc/sudoers.d/m1 || echo "m1 ALL=(ALL) NOPASSWD:ALL" | sudo tee /etc/sudoers.d/m1'
+
+# Build the wheel packages
+ssh m1 "${rdir}/tools/build/build_macos_arm64.sh"
+
+# Transfer the packages locally
+scp -r "m1:${rdir}/wheelhouse" .
diff --git a/tools/build/strip_wheel.sh b/tools/build/strip_wheel.sh
new file mode 100755
index 0000000..bfcd302
--- /dev/null
+++ b/tools/build/strip_wheel.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# Strip symbols inplace from the libraries in a zip archive.
+#
+# Stripping symbols is beneficial (reduction of 30% of the final package, >
+# %90% of the installed libraries. However just running `auditwheel repair
+# --strip` breaks some of the libraries included from the system, which fail at
+# import with errors such as "ELF load command address/offset not properly
+# aligned".
+#
+# System libraries are already pretty stripped. Ours go around 24Mb -> 1.5Mb...
+#
+# This script is designed to run on a wheel archive before auditwheel.
+
+set -euo pipefail
+# set -x
+
+source /etc/os-release
+dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+
+wheel=$(realpath "$1")
+shift
+
+tmpdir=$(mktemp -d)
+trap "rm -r ${tmpdir}" EXIT
+
+cd "${tmpdir}"
+python -m zipfile -e "${wheel}" .
+
+echo "
+Libs before:"
+# Busybox doesn't have "find -ls"
+find . -name \*.so | xargs ls -l
+
+# On Debian, print the package versions libraries come from
+echo "
+Dependencies versions of '_psycopg.so' library:"
+"${dir}/print_so_versions.sh" "$(find . -name \*_psycopg\*.so)"
+
+find . -name \*.so -exec strip "$@" {} \;
+
+echo "
+Libs after:"
+find . -name \*.so | xargs ls -l
+
+python -m zipfile -c ${wheel} *
+
+cd -
diff --git a/tools/build/wheel_linux_before_all.sh b/tools/build/wheel_linux_before_all.sh
new file mode 100755
index 0000000..663e3ef
--- /dev/null
+++ b/tools/build/wheel_linux_before_all.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# Configure the libraries needed to build wheel packages on linux.
+# This script is designed to be used by cibuildwheel as CIBW_BEFORE_ALL_LINUX
+
+set -euo pipefail
+set -x
+
+dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+
+source /etc/os-release
+
+# Install PostgreSQL development files.
+case "$ID" in
+ alpine)
+ # tzdata is required for datetime tests.
+ apk update
+ apk add --no-cache tzdata
+ "${dir}/build_libpq.sh" > /dev/null
+ ;;
+
+ debian)
+ # Note that the pgdg doesn't have an aarch64 repository so wheels are
+ # build with the libpq packaged with Debian 9, which is 9.6.
+ if [ "$AUDITWHEEL_ARCH" != 'aarch64' ]; then
+ echo "deb http://apt.postgresql.org/pub/repos/apt $VERSION_CODENAME-pgdg main" \
+ > /etc/apt/sources.list.d/pgdg.list
+ # TODO: On 2021-11-09 curl fails on 'ppc64le' with:
+ # curl: (60) SSL certificate problem: certificate has expired
+ # Test again later if -k can be removed.
+ curl -skf https://www.postgresql.org/media/keys/ACCC4CF8.asc \
+ > /etc/apt/trusted.gpg.d/postgresql.asc
+ fi
+
+ apt-get update
+ apt-get -y upgrade
+ apt-get -y install libpq-dev
+ ;;
+
+ centos)
+ "${dir}/build_libpq.sh" > /dev/null
+ ;;
+
+ *)
+ echo "$0: unexpected Linux distribution: '$ID'" >&2
+ exit 1
+ ;;
+esac
diff --git a/tools/build/wheel_macos_before_all.sh b/tools/build/wheel_macos_before_all.sh
new file mode 100755
index 0000000..285a063
--- /dev/null
+++ b/tools/build/wheel_macos_before_all.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Configure the environment needed to build wheel packages on Mac OS.
+# This script is designed to be used by cibuildwheel as CIBW_BEFORE_ALL_MACOS
+
+set -euo pipefail
+set -x
+
+dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+
+brew update
+brew install gnu-sed postgresql@14
+# Fetch 14.1 if 14.0 is still the default version
+brew reinstall postgresql
+
+# Start the database for testing
+brew services start postgresql
+
+# Wait for postgres to come up
+for i in $(seq 10 -1 0); do
+ eval pg_isready && break
+ if [ $i == 0 ]; then
+ echo "PostgreSQL service not ready, giving up"
+ exit 1
+ fi
+ echo "PostgreSQL service not ready, waiting a bit, attempts left: $i"
+ sleep 5
+done
diff --git a/tools/build/wheel_win32_before_build.bat b/tools/build/wheel_win32_before_build.bat
new file mode 100644
index 0000000..fd35f5d
--- /dev/null
+++ b/tools/build/wheel_win32_before_build.bat
@@ -0,0 +1,3 @@
+@echo on
+pip install delvewheel
+choco upgrade postgresql
diff --git a/tools/bump_version.py b/tools/bump_version.py
new file mode 100755
index 0000000..50dbe0b
--- /dev/null
+++ b/tools/bump_version.py
@@ -0,0 +1,310 @@
+#!/usr/bin/env python
+"""Bump the version number of the project.
+"""
+
+from __future__ import annotations
+
+import re
+import sys
+import logging
+import subprocess as sp
+from enum import Enum
+from pathlib import Path
+from argparse import ArgumentParser, Namespace
+from functools import cached_property
+from dataclasses import dataclass
+
+from packaging.version import parse as parse_version, Version
+
+PROJECT_DIR = Path(__file__).parent.parent
+
+logger = logging.getLogger()
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+
+
+@dataclass
+class Package:
+ name: str
+ version_files: list[Path]
+ history_file: Path
+ tag_format: str
+
+ def __post_init__(self) -> None:
+ packages[self.name] = self
+
+
+packages: dict[str, Package] = {}
+
+Package(
+ name="psycopg",
+ version_files=[
+ PROJECT_DIR / "psycopg/psycopg/version.py",
+ PROJECT_DIR / "psycopg_c/psycopg_c/version.py",
+ ],
+ history_file=PROJECT_DIR / "docs/news.rst",
+ tag_format="{version}",
+)
+
+Package(
+ name="psycopg_pool",
+ version_files=[PROJECT_DIR / "psycopg_pool/psycopg_pool/version.py"],
+ history_file=PROJECT_DIR / "docs/news_pool.rst",
+ tag_format="pool-{version}",
+)
+
+
+class Bumper:
+ def __init__(self, package: Package, *, bump_level: str | BumpLevel):
+ self.package = package
+ self.bump_level = BumpLevel(bump_level)
+
+ self._version_regex = re.compile(
+ r"""(?ix)
+ ^
+ (?P<pre>__version__\s*=\s*(?P<quote>["']))
+ (?P<ver>[^'"]+)
+ (?P<post>(?P=quote)\s*(?:\#.*)?)
+ $
+ """
+ )
+
+ @cached_property
+ def current_version(self) -> Version:
+ versions = set(
+ self._parse_version_from_file(f) for f in self.package.version_files
+ )
+ if len(versions) > 1:
+ raise ValueError(
+ f"inconsistent versions ({', '.join(map(str, sorted(versions)))})"
+ f" in {self.package.version_files}"
+ )
+
+ return versions.pop()
+
+ @cached_property
+ def want_version(self) -> Version:
+ current = self.current_version
+ parts = [current.major, current.minor, current.micro, current.dev or 0]
+
+ match self.bump_level:
+ case BumpLevel.MAJOR:
+ # 1.2.3 -> 2.0.0
+ parts[0] += 1
+ parts[1] = parts[2] = parts[3] = 0
+ case BumpLevel.MINOR:
+ # 1.2.3 -> 1.3.0
+ parts[1] += 1
+ parts[2] = parts[3] = 0
+ case BumpLevel.PATCH:
+ # 1.2.3 -> 1.2.4
+ # 1.2.3.dev4 -> 1.2.3
+ if parts[3] == 0:
+ parts[2] += 1
+ else:
+ parts[3] = 0
+ case BumpLevel.DEV:
+ # 1.2.3 -> 1.2.4.dev1
+ # 1.2.3.dev1 -> 1.2.3.dev2
+ if parts[3] == 0:
+ parts[2] += 1
+ parts[3] += 1
+
+ sparts = [str(part) for part in parts[:3]]
+ if parts[3]:
+ sparts.append(f"dev{parts[3]}")
+ return Version(".".join(sparts))
+
+ def update_files(self) -> None:
+ for f in self.package.version_files:
+ self._update_version_in_file(f, self.want_version)
+
+ if self.bump_level != BumpLevel.DEV:
+ self._update_history_file(self.package.history_file, self.want_version)
+
+ def commit(self) -> None:
+ logger.debug("committing version changes")
+ msg = f"""\
+chore: bump {self.package.name} package version to {self.want_version}
+"""
+ files = self.package.version_files + [self.package.history_file]
+ cmdline = ["git", "commit", "-m", msg] + list(map(str, files))
+ sp.check_call(cmdline)
+
+ def create_tag(self) -> None:
+ logger.debug("tagging version %s", self.want_version)
+ tag_name = self.package.tag_format.format(version=self.want_version)
+ changes = self._get_changes_lines(
+ self.package.history_file,
+ self.want_version,
+ )
+ msg = f"""\
+{self.package.name} {self.want_version} released
+
+{''.join(changes)}
+"""
+ cmdline = ["git", "tag", "-a", "-s", "-m", msg, tag_name]
+ sp.check_call(cmdline)
+
+ def _parse_version_from_file(self, fp: Path) -> Version:
+ logger.debug("looking for version in %s", fp)
+ matches = []
+ with fp.open() as f:
+ for line in f:
+ m = self._version_regex.match(line)
+ if m:
+ matches.append(m)
+
+ if not matches:
+ raise ValueError(f"no version found in {fp}")
+ elif len(matches) > 1:
+ raise ValueError(f"more than one version found in {fp}")
+
+ vs = parse_version(matches[0].group("ver"))
+ assert isinstance(vs, Version)
+ return vs
+
+ def _update_version_in_file(self, fp: Path, version: Version) -> None:
+ logger.debug("upgrading version to %s in %s", version, fp)
+ lines = []
+ with fp.open() as f:
+ for line in f:
+ if self._version_regex.match(line):
+ line = self._version_regex.sub(f"\\g<pre>{version}\\g<post>", line)
+ lines.append(line)
+
+ with fp.open("w") as f:
+ for line in lines:
+ f.write(line)
+
+ def _update_history_file(self, fp: Path, version: Version) -> None:
+ logger.debug("upgrading history file %s", fp)
+ with fp.open() as f:
+ lines = f.readlines()
+
+ vln: int = -1
+ lns = self._find_lines(
+ r"^[^\s]+ " + re.escape(str(version)) + r"\s*\(unreleased\)?$", lines
+ )
+ assert len(lns) <= 1
+ if len(lns) == 1:
+ vln = lns[0]
+ lines[vln] = lines[vln].rsplit(None, 1)[0]
+ lines[vln + 1] = lines[vln + 1][0] * len(lines[lns[0]])
+
+ lns = self._find_lines("^Future", lines)
+ assert len(lns) <= 1
+ if len(lns) == 1:
+ del lines[lns[0] : lns[0] + 3]
+ if vln > lns[0]:
+ vln -= 3
+
+ lns = self._find_lines("^Current", lines)
+ assert len(lns) <= 1
+ if len(lns) == 1 and vln >= 0:
+ clines = lines[lns[0] : lns[0] + 3]
+ del lines[lns[0] : lns[0] + 3]
+ if vln > lns[0]:
+ vln -= 3
+ lines[vln:vln] = clines
+
+ with fp.open("w") as f:
+ for line in lines:
+ f.write(line)
+ if not line.endswith("\n"):
+ f.write("\n")
+
+ def _get_changes_lines(self, fp: Path, version: Version) -> list[str]:
+ with fp.open() as f:
+ lines = f.readlines()
+
+ lns = self._find_lines(r"^[^\s]+ " + re.escape(str(version)), lines)
+ assert len(lns) == 1
+ start = end = lns[0] + 3
+ while lines[end].rstrip():
+ end += 1
+
+ return lines[start:end]
+
+ def _find_lines(self, pattern: str, lines: list[str]) -> list[int]:
+ rv = []
+ rex = re.compile(pattern)
+ for i, line in enumerate(lines):
+ if rex.match(line):
+ rv.append(i)
+
+ return rv
+
+
+def main() -> int | None:
+ opt = parse_cmdline()
+ logger.setLevel(opt.loglevel)
+ bumper = Bumper(packages[opt.package], bump_level=opt.level)
+ logger.info("current version: %s", bumper.current_version)
+ logger.info("bumping to version: %s", bumper.want_version)
+ if not opt.dry_run:
+ bumper.update_files()
+ bumper.commit()
+ if opt.level != BumpLevel.DEV:
+ bumper.create_tag()
+
+ return 0
+
+
+class BumpLevel(str, Enum):
+ MAJOR = "major"
+ MINOR = "minor"
+ PATCH = "patch"
+ DEV = "dev"
+
+
+def parse_cmdline() -> Namespace:
+ parser = ArgumentParser(description=__doc__)
+
+ parser.add_argument(
+ "--level",
+ choices=[level.value for level in BumpLevel],
+ default=BumpLevel.PATCH.value,
+ type=BumpLevel,
+ help="the level to bump [default: %(default)s]",
+ )
+
+ parser.add_argument(
+ "--package",
+ choices=list(packages.keys()),
+ default="psycopg",
+ help="the package to bump version [default: %(default)s]",
+ )
+
+ parser.add_argument(
+ "-n",
+ "--dry-run",
+ help="Just pretend",
+ action="store_true",
+ )
+
+ g = parser.add_mutually_exclusive_group()
+ g.add_argument(
+ "-q",
+ "--quiet",
+ help="Talk less",
+ dest="loglevel",
+ action="store_const",
+ const=logging.WARN,
+ default=logging.INFO,
+ )
+ g.add_argument(
+ "-v",
+ "--verbose",
+ help="Talk more",
+ dest="loglevel",
+ action="store_const",
+ const=logging.DEBUG,
+ default=logging.INFO,
+ )
+ opt = parser.parse_args()
+
+ return opt
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/tools/update_backer.py b/tools/update_backer.py
new file mode 100755
index 0000000..0088527
--- /dev/null
+++ b/tools/update_backer.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python3
+r"""Add or edit github users in the backers file
+"""
+
+import sys
+import logging
+import requests
+from pathlib import Path
+from ruamel.yaml import YAML # pip install ruamel.yaml
+
+logger = logging.getLogger()
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+
+
+def fetch_user(username):
+ logger.info("fetching %s", username)
+ resp = requests.get(
+ f"https://api.github.com/users/{username}",
+ headers={"Accept": "application/vnd.github.v3+json"},
+ )
+ resp.raise_for_status()
+ return resp.json()
+
+
+def get_user_data(data):
+ """
+ Get the data to save from the request data
+ """
+ out = {
+ "username": data["login"],
+ "avatar": data["avatar_url"],
+ "name": data["name"],
+ }
+ if data["blog"]:
+ website = data["blog"]
+ if not website.startswith("http"):
+ website = "http://" + website
+
+ out["website"] = website
+
+ return out
+
+
+def add_entry(opt, filedata, username):
+ userdata = get_user_data(fetch_user(username))
+ if opt.top:
+ userdata["tier"] = "top"
+
+ filedata.append(userdata)
+
+
+def update_entry(opt, filedata, entry):
+ # entry is an username or an user entry daat
+ if isinstance(entry, str):
+ username = entry
+ entry = [e for e in filedata if e["username"] == username]
+ if not entry:
+ raise Exception(f"{username} not found")
+ entry = entry[0]
+ else:
+ username = entry["username"]
+
+ userdata = get_user_data(fetch_user(username))
+ for k, v in userdata.items():
+ if entry.get("keep_" + k):
+ continue
+ entry[k] = v
+
+
+def main():
+ opt = parse_cmdline()
+ logger.info("reading %s", opt.file)
+ yaml = YAML(typ="rt")
+ filedata = yaml.load(opt.file)
+
+ for username in opt.add or ():
+ add_entry(opt, filedata, username)
+
+ for username in opt.update or ():
+ update_entry(opt, filedata, username)
+
+ if opt.update_all:
+ for entry in filedata:
+ update_entry(opt, filedata, entry)
+
+ # yamllint happy
+ yaml.explicit_start = True
+ logger.info("writing %s", opt.file)
+ yaml.dump(filedata, opt.file)
+
+
+def parse_cmdline():
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--file",
+ help="the file to update [default: %(default)s]",
+ default=Path(__file__).parent.parent / "BACKERS.yaml",
+ type=Path,
+ )
+ parser.add_argument(
+ "--add",
+ metavar="USERNAME",
+ nargs="+",
+ help="add USERNAME to the backers",
+ )
+
+ parser.add_argument(
+ "--top",
+ action="store_true",
+ help="add to the top tier",
+ )
+
+ parser.add_argument(
+ "--update",
+ metavar="USERNAME",
+ nargs="+",
+ help="update USERNAME data",
+ )
+
+ parser.add_argument(
+ "--update-all",
+ action="store_true",
+ help="update all the existing backers data",
+ )
+
+ opt = parser.parse_args()
+
+ return opt
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/tools/update_errors.py b/tools/update_errors.py
new file mode 100755
index 0000000..638d352
--- /dev/null
+++ b/tools/update_errors.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python
+# type: ignore
+"""
+Generate per-sqlstate errors from PostgreSQL source code.
+
+The script can be run at a new PostgreSQL release to refresh the module.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+import sys
+import logging
+from urllib.request import urlopen
+from collections import defaultdict, namedtuple
+
+from psycopg.errors import get_base_exception
+
+logger = logging.getLogger()
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+
+
+def main():
+ classes, errors = fetch_errors("9.6 10 11 12 13 14 15".split())
+
+ fn = os.path.dirname(__file__) + "/../psycopg/psycopg/errors.py"
+ update_file(fn, generate_module_data(classes, errors))
+
+ fn = os.path.dirname(__file__) + "/../docs/api/errors.rst"
+ update_file(fn, generate_docs_data(classes, errors))
+
+
+def parse_errors_txt(url):
+ classes = {}
+ errors = defaultdict(dict)
+
+ page = urlopen(url)
+ for line in page.read().decode("ascii").splitlines():
+ # Strip comments and skip blanks
+ line = line.split("#")[0].strip()
+ if not line:
+ continue
+
+ # Parse a section
+ m = re.match(r"Section: (Class (..) - .+)", line)
+ if m:
+ label, class_ = m.groups()
+ classes[class_] = label
+ continue
+
+ # Parse an error
+ m = re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line)
+ if m:
+ sqlstate, macro, spec = m.groups()
+ # skip sqlstates without specs as they are not publicly visible
+ if not spec:
+ continue
+ errlabel = spec.upper()
+ errors[class_][sqlstate] = errlabel
+ continue
+
+ # We don't expect anything else
+ raise ValueError("unexpected line:\n%s" % line)
+
+ return classes, errors
+
+
+errors_txt_url = (
+ "http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob_plain;"
+ "f=src/backend/utils/errcodes.txt;hb=%s"
+)
+
+
+Error = namedtuple("Error", "sqlstate errlabel clsname basename")
+
+
+def fetch_errors(versions):
+ classes = {}
+ errors = defaultdict(dict)
+
+ for version in versions:
+ logger.info("fetching errors from version %s", version)
+ tver = tuple(map(int, version.split()[0].split(".")))
+ tag = "%s%s_STABLE" % (
+ (tver[0] >= 10 and "REL_" or "REL"),
+ version.replace(".", "_"),
+ )
+ c1, e1 = parse_errors_txt(errors_txt_url % tag)
+ classes.update(c1)
+
+ for c, cerrs in e1.items():
+ errors[c].update(cerrs)
+
+ # clean up data
+
+ # success and warning - never raised
+ del classes["00"]
+ del classes["01"]
+ del errors["00"]
+ del errors["01"]
+
+ specific = {
+ "38002": "ModifyingSqlDataNotPermittedExt",
+ "38003": "ProhibitedSqlStatementAttemptedExt",
+ "38004": "ReadingSqlDataNotPermittedExt",
+ "39004": "NullValueNotAllowedExt",
+ "XX000": "InternalError_",
+ }
+
+ seen = set(
+ """
+ Error Warning InterfaceError DataError DatabaseError ProgrammingError
+ IntegrityError InternalError NotSupportedError OperationalError
+ """.split()
+ )
+
+ for c, cerrs in errors.items():
+ for sqstate, errlabel in list(cerrs.items()):
+ if sqstate in specific:
+ clsname = specific[sqstate]
+ else:
+ clsname = errlabel.title().replace("_", "")
+ if clsname in seen:
+ raise Exception("class already existing: %s" % clsname)
+ seen.add(clsname)
+
+ basename = get_base_exception(sqstate).__name__
+ cerrs[sqstate] = Error(sqstate, errlabel, clsname, basename)
+
+ return classes, errors
+
+
+def generate_module_data(classes, errors):
+ yield ""
+
+ for clscode, clslabel in sorted(classes.items()):
+ yield f"""
+# {clslabel}
+"""
+ for _, e in sorted(errors[clscode].items()):
+ yield f"""\
+class {e.clsname}({e.basename},
+ code={e.sqlstate!r}, name={e.errlabel!r}):
+ pass
+"""
+ yield ""
+
+
+def generate_docs_data(classes, errors):
+ Line = namedtuple("Line", "colstate colexc colbase, sqlstate")
+ lines = [Line("SQLSTATE", "Exception", "Base exception", None)]
+
+ for clscode in sorted(classes):
+ for _, error in sorted(errors[clscode].items()):
+ lines.append(
+ Line(
+ f"``{error.sqlstate}``",
+ f"`!{error.clsname}`",
+ f"`!{error.basename}`",
+ error.sqlstate,
+ )
+ )
+
+ widths = [max(len(line[c]) for line in lines) for c in range(3)]
+ h = Line(*(["=" * w for w in widths] + [None]))
+ lines.insert(0, h)
+ lines.insert(2, h)
+ lines.append(h)
+
+ h1 = "-" * (sum(widths) + len(widths) - 1)
+ sqlclass = None
+
+ yield ""
+ for line in lines:
+ cls = line.sqlstate[:2] if line.sqlstate else None
+ if cls and cls != sqlclass:
+ yield re.sub(r"(Class\s+[^\s]+)", r"**\1**", classes[cls])
+ yield h1
+ sqlclass = cls
+
+ yield (
+ "%-*s %-*s %-*s"
+ % (
+ widths[0],
+ line.colstate,
+ widths[1],
+ line.colexc,
+ widths[2],
+ line.colbase,
+ )
+ ).rstrip()
+
+ yield ""
+
+
+def update_file(fn, new_lines):
+ logger.info("updating %s", fn)
+
+ with open(fn, "r") as f:
+ lines = f.read().splitlines()
+
+ istart, iend = [
+ i
+ for i, line in enumerate(lines)
+ if re.match(r"\s*(#|\.\.)\s*autogenerated:\s+(start|end)", line)
+ ]
+
+ lines[istart + 1 : iend] = new_lines
+
+ with open(fn, "w") as f:
+ for line in lines:
+ f.write(line + "\n")
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/tools/update_oids.py b/tools/update_oids.py
new file mode 100755
index 0000000..df4f969
--- /dev/null
+++ b/tools/update_oids.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python
+"""
+Update the maps of builtin types and names.
+
+This script updates some of the files in psycopg source code with data read
+from a database catalog.
+
+Hint: use docker to upgrade types from a new version in isolation. Run:
+
+ docker run --rm -p 11111:5432 --name pg -e POSTGRES_PASSWORD=password postgres:TAG
+
+with a specified version tag, and then query it using:
+
+ %(prog)s "host=localhost port=11111 user=postgres password=password"
+"""
+
+import re
+import argparse
+import subprocess as sp
+from typing import List
+from pathlib import Path
+from typing_extensions import TypeAlias
+
+import psycopg
+from psycopg.rows import TupleRow
+from psycopg.crdb import CrdbConnection
+
+Connection: TypeAlias = psycopg.Connection[TupleRow]
+
+ROOT = Path(__file__).parent.parent
+
+
+def main() -> None:
+ opt = parse_cmdline()
+ conn = psycopg.connect(opt.dsn, autocommit=True)
+
+ if CrdbConnection.is_crdb(conn):
+ conn = CrdbConnection.connect(opt.dsn, autocommit=True)
+ update_crdb_python_oids(conn)
+ else:
+ update_python_oids(conn)
+ update_cython_oids(conn)
+
+
+def update_python_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg/psycopg/postgres.py"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_py_types(conn))
+ lines.extend(get_py_ranges(conn))
+ lines.extend(get_py_multiranges(conn))
+
+ update_file(fn, lines)
+ sp.check_call(["black", "-q", fn])
+
+
+def update_cython_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg_c/psycopg_c/_psycopg/oids.pxd"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_cython_oids(conn))
+
+ update_file(fn, lines)
+
+
+def update_crdb_python_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg/psycopg/crdb/_types.py"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_py_types(conn))
+
+ update_file(fn, lines)
+ sp.check_call(["black", "-q", fn])
+
+
+def get_version_comment(conn: Connection) -> List[str]:
+ if conn.info.vendor == "PostgreSQL":
+ # Assume PG > 10
+ num = conn.info.server_version
+ version = f"{num // 10000}.{num % 100}"
+ elif conn.info.vendor == "CockroachDB":
+ assert isinstance(conn, CrdbConnection)
+ num = conn.info.server_version
+ version = f"{num // 10000}.{num % 10000 // 100}.{num % 100}"
+ else:
+ raise NotImplementedError(f"unexpected vendor: {conn.info.vendor}")
+ return ["", f" # Generated from {conn.info.vendor} {version}", ""]
+
+
+def get_py_types(conn: Connection) -> List[str]:
+ # Note: "record" is a pseudotype but still a useful one to have.
+ # "pg_lsn" is a documented public type and useful in streaming replication
+ lines = []
+ for (typname, oid, typarray, regtype, typdelim) in conn.execute(
+ """
+select typname, oid, typarray,
+ -- CRDB might have quotes in the regtype representation
+ replace(typname::regtype::text, '''', '') as regtype,
+ typdelim
+from pg_type t
+where
+ oid < 10000
+ and oid != '"char"'::regtype
+ and (typtype = 'b' or typname = 'record')
+ and (typname !~ '^(_|pg_)' or typname = 'pg_lsn')
+order by typname
+"""
+ ):
+ # Weird legacy type in postgres catalog
+ if typname == "char":
+ typname = regtype = '"char"'
+
+ # https://github.com/cockroachdb/cockroach/issues/81645
+ if typname == "int4" and conn.info.vendor == "CockroachDB":
+ regtype = typname
+
+ params = [f"{typname!r}, {oid}, {typarray}"]
+ if regtype != typname:
+ params.append(f"regtype={regtype!r}")
+ if typdelim != ",":
+ params.append(f"delimiter={typdelim!r}")
+ lines.append(f"TypeInfo({','.join(params)}),")
+
+ return lines
+
+
+def get_py_ranges(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid, typarray, rngsubtype) in conn.execute(
+ """
+select typname, oid, typarray, rngsubtype
+from
+ pg_type t
+ join pg_range r on t.oid = rngtypid
+where
+ oid < 10000
+ and typtype = 'r'
+order by typname
+"""
+ ):
+ params = [f"{typname!r}, {oid}, {typarray}, subtype_oid={rngsubtype}"]
+ lines.append(f"RangeInfo({','.join(params)}),")
+
+ return lines
+
+
+def get_py_multiranges(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid, typarray, rngtypid, rngsubtype) in conn.execute(
+ """
+select typname, oid, typarray, rngtypid, rngsubtype
+from
+ pg_type t
+ join pg_range r on t.oid = rngmultitypid
+where
+ oid < 10000
+ and typtype = 'm'
+order by typname
+"""
+ ):
+ params = [
+ f"{typname!r}, {oid}, {typarray},"
+ f" range_oid={rngtypid}, subtype_oid={rngsubtype}"
+ ]
+ lines.append(f"MultirangeInfo({','.join(params)}),")
+
+ return lines
+
+
+def get_cython_oids(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid) in conn.execute(
+ """
+select typname, oid
+from pg_type
+where
+ oid < 10000
+ and (typtype = any('{b,r,m}') or typname = 'record')
+ and (typname !~ '^(_|pg_)' or typname = 'pg_lsn')
+order by typname
+"""
+ ):
+ const_name = typname.upper() + "_OID"
+ lines.append(f" {const_name} = {oid}")
+
+ return lines
+
+
+def update_file(fn: Path, new: List[str]) -> None:
+ with fn.open("r") as f:
+ lines = f.read().splitlines()
+ istart, iend = [
+ i
+ for i, line in enumerate(lines)
+ if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
+ ]
+ lines[istart + 1 : iend] = new
+
+ with fn.open("w") as f:
+ f.write("\n".join(lines))
+ f.write("\n")
+
+
+def parse_cmdline() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
+ )
+ parser.add_argument("dsn", help="where to connect to")
+ opt = parser.parse_args()
+ return opt
+
+
+if __name__ == "__main__":
+ main()