304 files changed, 72932 insertions, 0 deletions
diff --git a/.codespellrc b/.codespellrc new file mode 100644 index 0000000..33ec5a6 --- /dev/null +++ b/.codespellrc @@ -0,0 +1,3 @@ +[codespell] +ignore-words-list = alot,ans,ba,fo,te +skip = docs/_build,.tox,.mypy_cache,.venv,pq.c diff --git a/.flake8 b/.flake8 new file mode 100644 --- /dev/null +++ b/.flake8 @@ -0,0 +1,7 @@ +[flake8] +max-line-length = 88 +ignore = W503, E203 +extend-exclude = .venv build +per-file-ignores = + # Autogenerated section + psycopg/psycopg/errors.py: E125, E128, E302 diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..b648a1e --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,4 @@ +github: + - dvarrazzo +custom: + - "https://www.paypal.me/dvarrazzo" diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..1dd1e94 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,22 @@ +name: Build documentation + +on: + push: + branches: + # This should match the DOC3_BRANCH value in the psycopg-website Makefile + - master + +concurrency: + group: ${{ github.workflow }}-${{ github.ref_name }} + cancel-in-progress: true + +jobs: + docs: + runs-on: ubuntu-latest + steps: + - name: Trigger docs build + uses: peter-evans/repository-dispatch@v1 + with: + repository: psycopg/psycopg-website + event-type: psycopg3-commit + token: ${{ secrets.ACCESS_TOKEN }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..4527551 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,48 @@ +name: Lint + +on: + push: + # This should disable running the workflow on tags, according to the + # on.<push|pull_request>.<branches|tags> GitHub Actions docs. + branches: + - "*" + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref_name }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: install packages to tests + run: pip install ./psycopg[dev,test] codespell + + - name: Run black + run: black --check --diff .
+ + - name: Run flake8 + run: flake8 + + - name: Run mypy + run: mypy + + - name: Check spelling + run: codespell + + - name: Install requirements to generate docs + run: sudo apt-get install -y libgeos-dev + + - name: Install Python packages to generate docs + run: pip install ./psycopg[docs] ./psycopg_pool + + - name: Check documentation + run: sphinx-build -W -T -b html docs docs/_build/html diff --git a/.github/workflows/packages-pool.yml b/.github/workflows/packages-pool.yml new file mode 100644 index 0000000..e9624e7 --- /dev/null +++ b/.github/workflows/packages-pool.yml @@ -0,0 +1,66 @@ +name: Build pool packages + +on: + workflow_dispatch: + schedule: + - cron: '28 6 * * sun' + +jobs: + + sdist: + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + include: + - {package: psycopg_pool, format: sdist, impl: python} + - {package: psycopg_pool, format: wheel, impl: python} + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: 3.9 + + - name: Create the sdist packages + run: |- + python ${{ matrix.package }}/setup.py sdist -d `pwd`/dist/ + if: ${{ matrix.format == 'sdist' }} + + - name: Create the wheel packages + run: |- + pip install wheel + python ${{ matrix.package }}/setup.py bdist_wheel -d `pwd`/dist/ + if: ${{ matrix.format == 'wheel' }} + + - name: Install the Python pool package and test requirements + run: |- + pip install dist/* + pip install ./psycopg[test] + + - name: Test the sdist package + run: pytest -m 'not slow and not flakey' --color yes + env: + PSYCOPG_IMPL: ${{ matrix.impl }} + PSYCOPG_TEST_DSN: "host=127.0.0.1 user=postgres" + PGPASSWORD: password + + - uses: actions/upload-artifact@v3 + with: + path: ./dist/* + + services: + postgresql: + image: postgres:14 + env: + POSTGRES_PASSWORD: password + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 diff --git a/.github/workflows/packages.yml b/.github/workflows/packages.yml new file mode 100644 index 0000000..18a2817 --- /dev/null +++ b/.github/workflows/packages.yml @@ -0,0 +1,256 @@ +name: Build packages + +on: + workflow_dispatch: + schedule: + - cron: '28 7 * * sun' + +jobs: + + sdist: # {{{ + runs-on: ubuntu-latest + if: true + + strategy: + fail-fast: false + matrix: + include: + - {package: psycopg, format: sdist, impl: python} + - {package: psycopg, format: wheel, impl: python} + - {package: psycopg_c, format: sdist, impl: c} + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: 3.9 + + - name: Create the sdist packages + run: |- + python ${{ matrix.package }}/setup.py sdist -d `pwd`/dist/ + if: ${{ matrix.format == 'sdist' }} + + - name: Create the wheel packages + run: |- + pip install wheel + python ${{ matrix.package }}/setup.py bdist_wheel -d `pwd`/dist/ + if: ${{ matrix.format == 'wheel' }} + + - name: Install the Python package and test requirements + run: |- + pip install `ls dist/*`[test] + pip install ./psycopg_pool + if: ${{ matrix.package == 'psycopg' }} + + - name: Install the C package and test requirements + run: |- + pip install dist/* + pip install ./psycopg[test] + pip install ./psycopg_pool + if: ${{ matrix.package == 'psycopg_c' }} + + - name: Test the sdist package + run: pytest -m 'not slow and not flakey' --color yes + env: + PSYCOPG_IMPL: ${{ matrix.impl }} + PSYCOPG_TEST_DSN: "host=127.0.0.1 user=postgres" + PGPASSWORD: password + + - 
uses: actions/upload-artifact@v3 + with: + path: ./dist/* + + services: + postgresql: + image: postgres:14 + env: + POSTGRES_PASSWORD: password + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + + # }}} + + linux: # {{{ + runs-on: ubuntu-latest + if: true + + env: + LIBPQ_VERSION: "15.1" + OPENSSL_VERSION: "1.1.1s" + + strategy: + fail-fast: false + matrix: + arch: [x86_64, i686, ppc64le, aarch64] + pyver: [cp37, cp38, cp39, cp310, cp311] + platform: [manylinux, musllinux] + + steps: + - uses: actions/checkout@v3 + + - name: Set up QEMU for multi-arch build + # Check https://github.com/docker/setup-qemu-action for newer versions. + uses: docker/setup-qemu-action@v2 + with: + # Note: 6.2.0 is buggy: make sure to avoid it. + # See https://github.com/pypa/cibuildwheel/issues/1250 + image: tonistiigi/binfmt:qemu-v7.0.0 + + - name: Cache libpq build + uses: actions/cache@v3 + with: + path: /tmp/libpq.build + key: libpq-${{ env.LIBPQ_VERSION }}-${{ matrix.platform }}-${{ matrix.arch }}-2 + + - name: Create the binary package source tree + run: python3 ./tools/build/copy_to_binary.py + + - name: Build wheels + uses: pypa/cibuildwheel@v2.9.0 + with: + package-dir: psycopg_binary + env: + CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014 + CIBW_MANYLINUX_I686_IMAGE: manylinux2014 + CIBW_MANYLINUX_AARCH64_IMAGE: manylinux2014 + CIBW_MANYLINUX_PPC64LE_IMAGE: manylinux2014 + CIBW_BUILD: ${{matrix.pyver}}-${{matrix.platform}}_${{matrix.arch}} + CIBW_ARCHS_LINUX: auto aarch64 ppc64le + CIBW_BEFORE_ALL_LINUX: ./tools/build/wheel_linux_before_all.sh + CIBW_REPAIR_WHEEL_COMMAND: >- + ./tools/build/strip_wheel.sh {wheel} + && auditwheel repair -w {dest_dir} {wheel} + CIBW_TEST_REQUIRES: ./psycopg[test] ./psycopg_pool + CIBW_TEST_COMMAND: >- + pytest {project}/tests -m 'not slow and not flakey' --color yes + CIBW_ENVIRONMENT_PASS_LINUX: LIBPQ_VERSION OPENSSL_VERSION + CIBW_ENVIRONMENT: >- + PSYCOPG_IMPL=binary + PSYCOPG_TEST_DSN='host=172.17.0.1 user=postgres' + PGPASSWORD=password + LIBPQ_BUILD_PREFIX=/host/tmp/libpq.build + PATH="$LIBPQ_BUILD_PREFIX/bin:$PATH" + LD_LIBRARY_PATH="$LIBPQ_BUILD_PREFIX/lib" + PSYCOPG_TEST_WANT_LIBPQ_BUILD=${{ env.LIBPQ_VERSION }} + PSYCOPG_TEST_WANT_LIBPQ_IMPORT=${{ env.LIBPQ_VERSION }} + + - uses: actions/upload-artifact@v3 + with: + path: ./wheelhouse/*.whl + + services: + postgresql: + image: postgres:14 + env: + POSTGRES_PASSWORD: password + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + + # }}} + + macos: # {{{ + runs-on: macos-latest + if: true + + strategy: + fail-fast: false + matrix: + # These archs require an Apple M1 runner: [arm64, universal2] + arch: [x86_64] + pyver: [cp37, cp38, cp39, cp310, cp311] + + steps: + - uses: actions/checkout@v3 + + - name: Create the binary package source tree + run: python3 ./tools/build/copy_to_binary.py + + - name: Build wheels + uses: pypa/cibuildwheel@v2.9.0 + with: + package-dir: psycopg_binary + env: + CIBW_BUILD: ${{matrix.pyver}}-macosx_${{matrix.arch}} + CIBW_ARCHS_MACOS: x86_64 + CIBW_BEFORE_ALL_MACOS: ./tools/build/wheel_macos_before_all.sh + CIBW_TEST_REQUIRES: ./psycopg[test] ./psycopg_pool + CIBW_TEST_COMMAND: >- + pytest {project}/tests -m 'not slow and not flakey' --color yes + CIBW_ENVIRONMENT: >- + PSYCOPG_IMPL=binary + 
PSYCOPG_TEST_DSN='dbname=postgres' + PSYCOPG_TEST_WANT_LIBPQ_BUILD=">= 14" + PSYCOPG_TEST_WANT_LIBPQ_IMPORT=">= 14" + + - uses: actions/upload-artifact@v3 + with: + path: ./wheelhouse/*.whl + + + # }}} + + windows: # {{{ + runs-on: windows-latest + if: true + + strategy: + fail-fast: false + matrix: + # Might want to add win32, untested at the moment. + arch: [win_amd64] + pyver: [cp37, cp38, cp39, cp310, cp311] + + steps: + - uses: actions/checkout@v3 + + - name: Start PostgreSQL service for test + run: | + $PgSvc = Get-Service "postgresql*" + Set-Service $PgSvc.Name -StartupType manual + $PgSvc.Start() + + - name: Create the binary package source tree + run: python3 ./tools/build/copy_to_binary.py + + - name: Build wheels + uses: pypa/cibuildwheel@v2.9.0 + with: + package-dir: psycopg_binary + env: + CIBW_BUILD: ${{matrix.pyver}}-${{matrix.arch}} + CIBW_ARCHS_WINDOWS: AMD64 x86 + CIBW_BEFORE_BUILD_WINDOWS: '.\tools\build\wheel_win32_before_build.bat' + CIBW_REPAIR_WHEEL_COMMAND_WINDOWS: >- + delvewheel repair -w {dest_dir} + --no-mangle "libiconv-2.dll;libwinpthread-1.dll" {wheel} + CIBW_TEST_REQUIRES: ./psycopg[test] ./psycopg_pool + CIBW_TEST_COMMAND: >- + pytest {project}/tests -m "not slow and not flakey" --color yes + CIBW_ENVIRONMENT_WINDOWS: >- + PSYCOPG_IMPL=binary + PATH="C:\\Program Files\\PostgreSQL\\14\\bin;$PATH" + PSYCOPG_TEST_DSN="host=127.0.0.1 user=postgres" + PSYCOPG_TEST_WANT_LIBPQ_BUILD=">= 14" + PSYCOPG_TEST_WANT_LIBPQ_IMPORT=">= 14" + + - uses: actions/upload-artifact@v3 + with: + path: ./wheelhouse/*.whl + + + # }}} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..9f6a7f5 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,314 @@ +name: Tests + +on: + push: + # This should disable running the workflow on tags, according to the + # on.<push|pull_request>.<branches|tags> GitHub Actions docs. + branches: + - "*" + pull_request: + schedule: + - cron: '48 6 * * *' + +concurrency: + # Cancel older requests of the same workflow in the same branch. + group: ${{ github.workflow }}-${{ github.ref_name }} + cancel-in-progress: true + +jobs: + + linux: # {{{ + runs-on: ubuntu-latest + if: true + + strategy: + fail-fast: false + matrix: + include: + # Test different combinations of Python, Postgres, libpq. + - {impl: python, python: "3.7", postgres: "postgres:10", libpq: newest} + - {impl: python, python: "3.8", postgres: "postgres:12"} + - {impl: python, python: "3.9", postgres: "postgres:13"} + - {impl: python, python: "3.10", postgres: "postgres:14"} + - {impl: python, python: "3.11", postgres: "postgres:15", libpq: oldest} + + - {impl: c, python: "3.7", postgres: "postgres:15", libpq: newest} + - {impl: c, python: "3.8", postgres: "postgres:13"} + - {impl: c, python: "3.9", postgres: "postgres:14"} + - {impl: c, python: "3.10", postgres: "postgres:13", libpq: oldest} + - {impl: c, python: "3.11", postgres: "postgres:10", libpq: newest} + + - {impl: python, python: "3.9", ext: dns, postgres: "postgres:14"} + - {impl: python, python: "3.9", ext: postgis, postgres: "postgis/postgis"} + + env: + PSYCOPG_IMPL: ${{ matrix.impl }} + DEPS: ./psycopg[test] ./psycopg_pool + PSYCOPG_TEST_DSN: "host=127.0.0.1 user=postgres" + PGPASSWORD: password + MARKERS: "" + + # Enable to run tests using the minimum version of dependencies. 
+ # PIP_CONSTRAINT: ${{ github.workspace }}/tests/constraints.txt + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python }} + + - name: Install the newest libpq version available + if: ${{ matrix.libpq == 'newest' }} + run: | + set -x + + curl -sL https://www.postgresql.org/media/keys/ACCC4CF8.asc \ + | gpg --dearmor \ + | sudo tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg > /dev/null + + # NOTE: in order to test with a preview release, add its number to + # the deb entry. For instance, to test on preview Postgres 16, use: + # "deb http://apt.postgresql.org/pub/repos/apt ${rel}-pgdg main 16" + rel=$(lsb_release -c -s) + echo "deb http://apt.postgresql.org/pub/repos/apt ${rel}-pgdg main" \ + | sudo tee -a /etc/apt/sources.list.d/pgdg.list > /dev/null + sudo apt-get -qq update + + pqver=$(apt-cache show libpq5 | grep ^Version: | head -1 \ + | awk '{print $2}') + sudo apt-get -qq -y install "libpq-dev=${pqver}" "libpq5=${pqver}" + + - name: Install the oldest libpq version available + if: ${{ matrix.libpq == 'oldest' }} + run: | + set -x + pqver=$(apt-cache show libpq5 | grep ^Version: | tail -1 \ + | awk '{print $2}') + sudo apt-get -qq -y --allow-downgrades install \ + "libpq-dev=${pqver}" "libpq5=${pqver}" + + - if: ${{ matrix.ext == 'dns' }} + run: | + echo "DEPS=$DEPS dnspython" >> $GITHUB_ENV + echo "MARKERS=$MARKERS dns" >> $GITHUB_ENV + + - if: ${{ matrix.ext == 'postgis' }} + run: | + echo "DEPS=$DEPS shapely" >> $GITHUB_ENV + echo "MARKERS=$MARKERS postgis" >> $GITHUB_ENV + + - if: ${{ matrix.impl == 'c' }} + run: | + echo "DEPS=$DEPS ./psycopg_c" >> $GITHUB_ENV + + - name: Install Python dependencies + run: pip install $DEPS + + - name: Run tests + run: ./tools/build/ci_test.sh + + services: + postgresql: + image: ${{ matrix.postgres }} + env: + POSTGRES_PASSWORD: password + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + + # }}} + + macos: # {{{ + runs-on: macos-latest + if: true + + strategy: + fail-fast: false + matrix: + include: + - {impl: python, python: "3.7"} + - {impl: python, python: "3.8"} + - {impl: python, python: "3.9"} + - {impl: python, python: "3.10"} + - {impl: python, python: "3.11"} + - {impl: c, python: "3.7"} + - {impl: c, python: "3.8"} + - {impl: c, python: "3.9"} + - {impl: c, python: "3.10"} + - {impl: c, python: "3.11"} + + env: + PSYCOPG_IMPL: ${{ matrix.impl }} + DEPS: ./psycopg[test] ./psycopg_pool + PSYCOPG_TEST_DSN: "host=127.0.0.1 user=runner dbname=postgres" + # MacOS on GitHub Actions seems particularly slow. + # Don't run timing-based tests as they regularly fail. + # pproxy-based tests fail too, with the proxy not coming up in 2s. + NOT_MARKERS: "timing proxy mypy" + # PIP_CONSTRAINT: ${{ github.workspace }}/tests/constraints.txt + + steps: + - uses: actions/checkout@v3 + + - name: Install PostgreSQL on the runner + run: brew install postgresql@14 + + - name: Start PostgreSQL service for test + run: brew services start postgresql + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python }} + + - if: ${{ matrix.impl == 'c' }} + # skip tests failing on importing psycopg_c.pq on subprocess + # they only fail on Travis, work ok locally under tox too. 
+ # TODO: check the same on GitHub Actions + run: | + echo "DEPS=$DEPS ./psycopg_c" >> $GITHUB_ENV + + - name: Install Python dependencies + run: pip install $DEPS + + - name: Run tests + run: ./tools/build/ci_test.sh + + + # }}} + + windows: # {{{ + runs-on: windows-latest + if: true + + strategy: + fail-fast: false + matrix: + include: + - {impl: python, python: "3.7"} + - {impl: python, python: "3.8"} + - {impl: python, python: "3.9"} + - {impl: python, python: "3.10"} + - {impl: python, python: "3.11"} + - {impl: c, python: "3.7"} + - {impl: c, python: "3.8"} + - {impl: c, python: "3.9"} + - {impl: c, python: "3.10"} + - {impl: c, python: "3.11"} + + env: + PSYCOPG_IMPL: ${{ matrix.impl }} + DEPS: ./psycopg[test] ./psycopg_pool + PSYCOPG_TEST_DSN: "host=127.0.0.1 dbname=postgres" + # On windows pproxy doesn't seem very happy. Also a few timing test fail. + NOT_MARKERS: "timing proxy mypy" + # PIP_CONSTRAINT: ${{ github.workspace }}/tests/constraints.txt + + steps: + - uses: actions/checkout@v3 + + - name: Start PostgreSQL service for test + run: | + $PgSvc = Get-Service "postgresql*" + Set-Service $PgSvc.Name -StartupType manual + $PgSvc.Start() + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python }} + + # Build a wheel package of the C extensions. + # If the wheel is not delocated, import fails with some dll not found + # (but it won't tell which one). + - name: Build the C wheel + if: ${{ matrix.impl == 'c' }} + run: | + pip install delvewheel wheel + $env:Path = "C:\Program Files\PostgreSQL\14\bin\;$env:Path" + python ./psycopg_c/setup.py bdist_wheel + &"delvewheel" repair ` + --no-mangle "libiconv-2.dll;libwinpthread-1.dll" ` + @(Get-ChildItem psycopg_c\dist\*.whl) + &"pip" install @(Get-ChildItem wheelhouse\*.whl) + + - name: Run tests + run: | + pip install $DEPS + ./tools/build/ci_test.sh + shell: bash + + + # }}} + + crdb: # {{{ + runs-on: ubuntu-latest + if: true + + strategy: + fail-fast: false + matrix: + include: + - {impl: c, crdb: "latest-v22.1", python: "3.10", libpq: newest} + - {impl: python, crdb: "latest-v22.2", python: "3.11"} + env: + PSYCOPG_IMPL: ${{ matrix.impl }} + DEPS: ./psycopg[test] ./psycopg_pool + PSYCOPG_TEST_DSN: "host=127.0.0.1 port=26257 user=root dbname=defaultdb" + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python }} + + - name: Run CockroachDB + # Note: this would love to be a service, but I don't see a way to pass + # the args to the docker run command line. + run: | + docker pull cockroachdb/cockroach:${{ matrix.crdb }} + docker run --rm -d --name crdb -p 26257:26257 \ + cockroachdb/cockroach:${{ matrix.crdb }} start-single-node --insecure + + - name: Install the newest libpq version available + if: ${{ matrix.libpq == 'newest' }} + run: | + set -x + + curl -sL https://www.postgresql.org/media/keys/ACCC4CF8.asc \ + | gpg --dearmor \ + | sudo tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg > /dev/null + + # NOTE: in order to test with a preview release, add its number to + # the deb entry. 
For instance, to test on preview Postgres 16, use: + # "deb http://apt.postgresql.org/pub/repos/apt ${rel}-pgdg main 16" + rel=$(lsb_release -c -s) + echo "deb http://apt.postgresql.org/pub/repos/apt ${rel}-pgdg main" \ + | sudo tee -a /etc/apt/sources.list.d/pgdg.list > /dev/null + sudo apt-get -qq update + + pqver=$(apt-cache show libpq5 | grep ^Version: | head -1 \ + | awk '{print $2}') + sudo apt-get -qq -y install "libpq-dev=${pqver}" "libpq5=${pqver}" + + - if: ${{ matrix.impl == 'c' }} + run: | + echo "DEPS=$DEPS ./psycopg_c" >> $GITHUB_ENV + + - name: Install Python dependencies + run: pip install $DEPS + + - name: Run tests + run: ./tools/build/ci_test.sh + + - name: Stop CockroachDB + run: docker kill crdb + + + # }}} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2d8c58c --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +*.egg-info/ +.tox +*.pstats +*.swp +.mypy_cache +__pycache__/ +/docs/_build/ +*.html +/psycopg_binary/ +.vscode +.venv +.coverage +htmlcov + +.eggs/ +dist/ +wheelhouse/ +# Spelling these explicitly because we have /scripts/build/ to not ignore +# but I still want 'ag' to avoid looking here. +/build/ +/psycopg/build/ +/psycopg_c/build/ +/psycopg_pool/build/ diff --git a/.yamllint.yaml b/.yamllint.yaml new file mode 100644 index 0000000..bcfb4d2 --- /dev/null +++ b/.yamllint.yaml @@ -0,0 +1,8 @@ +extends: default + +rules: + truthy: + check-keys: false + document-start: disable + line-length: + max: 85 diff --git a/BACKERS.yaml b/BACKERS.yaml new file mode 100644 index 0000000..b8bf830 --- /dev/null +++ b/BACKERS.yaml @@ -0,0 +1,140 @@ +--- +# You can find our sponsors at https://www.psycopg.org/sponsors/ Thank you! + +- username: postgrespro + tier: top + avatar: https://avatars.githubusercontent.com/u/12005770?v=4 + name: Postgres Professional + website: https://postgrespro.com/ + +- username: commandprompt + by: jdatcmd + tier: top + avatar: https://avatars.githubusercontent.com/u/339156?v=4 + name: Command Prompt, Inc. + website: https://www.commandprompt.com + +- username: bitdotioinc + tier: top + avatar: https://avatars.githubusercontent.com/u/56135630?v=4 + name: bit.io + website: https://bit.io/?utm_campaign=sponsorship&utm_source=psycopg2&utm_medium=web + keep_website: true + + +- username: yougov + tier: mid + avatar: https://avatars.githubusercontent.com/u/378494?v=4 + name: YouGov + website: https://www.yougov.com + +- username: phenopolis + by: pontikos + tier: mid + avatar: https://avatars.githubusercontent.com/u/20042742?v=4 + name: Phenopolis + website: http://www.phenopolis.co.uk + +- username: MaterializeInc + tier: mid + avatar: https://avatars.githubusercontent.com/u/47674186?v=4 + name: Materialize, Inc. 
+ website: http://materialize.com + +- username: getsentry + tier: mid + avatar: https://avatars.githubusercontent.com/u/1396951?v=4 + name: Sentry + website: https://sentry.io + +- username: 20tab + tier: mid + avatar: https://avatars.githubusercontent.com/u/1843159?v=4 + name: 20tab srl + website: http://www.20tab.com + +- username: genropy + by: gporcari + tier: mid + avatar: https://avatars.githubusercontent.com/u/7373189?v=4 + name: genropy + website: http://www.genropy.org + +- username: svennek + tier: mid + avatar: https://avatars.githubusercontent.com/u/37837?v=4 + name: Svenne Krap + website: http://www.svenne.dk + +- username: mailupinc + tier: mid + avatar: https://avatars.githubusercontent.com/u/72260631?v=4 + name: BEE + website: https://beefree.io + + +- username: taifu + avatar: https://avatars.githubusercontent.com/u/115712?v=4 + name: Marco Beri + website: http:/beri.it + +- username: la-mar + avatar: https://avatars.githubusercontent.com/u/16618300?v=4 + name: Brock Friedrich + +- username: xarg + avatar: https://avatars.githubusercontent.com/u/94721?v=4 + name: Alex Plugaru + website: https://plugaru.org + +- username: dalibo + avatar: https://avatars.githubusercontent.com/u/182275?v=4 + name: Dalibo + website: http://www.dalibo.com + +- username: rafmagns-skepa-dreag + avatar: https://avatars.githubusercontent.com/u/7447491?v=4 + name: Richard H + +- username: rustprooflabs + avatar: https://avatars.githubusercontent.com/u/3085224?v=4 + name: Ryan Lambert + website: https://www.rustprooflabs.com + +- username: logilab + avatar: https://avatars.githubusercontent.com/u/446566?v=4 + name: Logilab + website: http://www.logilab.org + +- username: asqui + avatar: https://avatars.githubusercontent.com/u/174182?v=4 + name: Daniel Fortunov + +- username: iqbalabd + avatar: https://avatars.githubusercontent.com/u/14254614?v=4 + name: Iqbal Abdullah + website: https://info.xoxzo.com/ + +- username: healthchecks + avatar: https://avatars.githubusercontent.com/u/13053880?v=4 + name: Healthchecks + website: https://healthchecks.io + +- username: c-rindi + avatar: https://avatars.githubusercontent.com/u/7826876?v=4 + name: C~+ + +- username: Intevation + by: bernhardreiter + avatar: https://avatars.githubusercontent.com/u/2050405?v=4 + name: Intevation + website: https://www.intevation.de/ + +- username: abegerho + avatar: https://avatars.githubusercontent.com/u/5734243?v=4 + name: Abhishek Begerhotta + +- username: ferpection + avatar: https://avatars.githubusercontent.com/u/6997008?v=4 + name: Ferpection + website: https://ferpection.com diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. 
+ + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. 
+ + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. 
diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..50934d8 --- /dev/null +++ b/README.rst @@ -0,0 +1,75 @@ +Psycopg 3 -- PostgreSQL database adapter for Python +=================================================== + +Psycopg 3 is a modern implementation of a PostgreSQL adapter for Python. + + +Installation +------------ + +Quick version:: + + pip install --upgrade pip # upgrade pip to at least 20.3 + pip install "psycopg[binary,pool]" # install binary dependencies + +For further information about installation please check `the documentation`__. + +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + + +Hacking +------- + +In order to work on the Psycopg source code you need to have the ``libpq`` +PostgreSQL client library installed in the system. For instance, on Debian +systems, you can obtain it by running:: + + sudo apt install libpq5 + +After which you can clone this repository:: + + git clone https://github.com/psycopg/psycopg.git + cd psycopg + +Please note that the repository contains the source code of several Python +packages: that's why you don't see a ``setup.py`` here. The packages may have +different requirements: + +- The ``psycopg`` directory contains the pure python implementation of + ``psycopg``. The package has only a runtime dependency on the ``libpq``, the + PostgreSQL client library, which should be installed in your system. + +- The ``psycopg_c`` directory contains an optimization module written in + C/Cython. In order to build it you will need a few development tools: please + look at `Local installation`__ in the docs for the details. + + .. __: https://www.psycopg.org/psycopg3/docs/basic/install.html#local-installation + +- The ``psycopg_pool`` directory contains the `connection pools`__ + implementations. This is kept as a separate package to allow a different + release cycle. + + .. __: https://www.psycopg.org/psycopg3/docs/advanced/pool.html + +You can create a local virtualenv and install there the packages `in +development mode`__, together with their development and testing +requirements:: + + python -m venv .venv + source .venv/bin/activate + pip install -e "./psycopg[dev,test]" # for the base Python package + pip install -e ./psycopg_pool # for the connection pool + pip install ./psycopg_c # for the C speedup module + +.. __: https://pip.pypa.io/en/stable/reference/pip_install/#install-editable + +Please add ``--config-settings editable_mode=strict`` to the ``pip install +-e`` above if you experience `editable mode broken`__. + +.. __: https://github.com/pypa/setuptools/issues/3557 + +Now hack away! You can run the tests using:: + + psql -c 'create database psycopg_test' + export PSYCOPG_TEST_DSN="dbname=psycopg_test" + pytest diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..e86cbd4 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,30 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build +PYTHON ?= python3 + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) || true + +serve: + PSYCOPG_IMPL=python sphinx-autobuild . 
_build/html/ + +.PHONY: help serve env Makefile + +env: .venv + +.venv: + $(PYTHON) -m venv .venv + ./.venv/bin/pip install -e "../psycopg[docs]" -e ../psycopg_pool + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/psycopg.css b/docs/_static/psycopg.css new file mode 100644 index 0000000..de9d779 --- /dev/null +++ b/docs/_static/psycopg.css @@ -0,0 +1,11 @@ +/* style rubric in furo (too small IMO) */ +p.rubric { + font-size: 1.2rem; + font-weight: bold; +} + +/* override a silly default */ +table.align-default td, +table.align-default th { + text-align: left; +} diff --git a/docs/_static/psycopg.svg b/docs/_static/psycopg.svg new file mode 100644 index 0000000..0e9ee32 --- /dev/null +++ b/docs/_static/psycopg.svg @@ -0,0 +1 @@ +<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 228 148"><path fill="#ffc836" stroke="#000" stroke-width="7.415493" d="M142.3 67.6c.6-4.7.4-5.4 4-4.6h.8c2.7.2 6.2-.4 8.3-1.3 4.4-2.1 7-5.5 2.7-4.6-10 2-10.7-1.4-10.7-1.4 10.5-15.6 15-35.5 11.1-40.3-10.3-13.3-28.3-7-28.6-6.8h-.1c-2-.4-4.2-.7-6.7-.7-4.5 0-8 1.2-10.5 3.1 0 0-32-13.2-30.6 16.6.4 6.4 9.1 48 19.6 35.4 3.8-4.6 7.5-8.4 7.5-8.4 1.8 1.2 4 1.8 6.3 1.6l.2-.2a7 7 0 000 1.8c-2.6 3-1.8 3.5-7.2 4.7-5.4 1-2.2 3-.2 3.6 2.6.6 8.4 1.5 12.4-4l-.2.6c1.1.9 1 6.1 1.2 9.8.1 3.8.4 7.3 1.1 9.3.8 2 1.7 7.4 8.8 5.9 5.9-1.3 10.4-3.1 10.8-20.1"/><path fill="#ff0" d="M105.4 54.2a2.4 2.4 0 114.8 0c0 1.3-1.1 2.4-2.4 2.4-1.3 0-2.4-1-2.4-2.4z"/><g fill="#336791"><path stroke="#000" stroke-width="7.415493" d="M85.7 80.4c-.6 4.7-.4 5.4-4 4.6H81c-2.7-.2-6.2.4-8.3 1.3-4.4 2.1-7 5.5-2.7 4.6 10-2 10.7 1.4 10.7 1.4-10.5 15.6-15 35.5-11.1 40.3 10.3 13.3 28.3 7 28.6 6.8h.1c2 .4 4.2.7 6.7.7 4.5 0 8-1.2 10.5-3.1 0 0 32 13.2 30.6-16.6-.4-6.4-9.1-48-19.6-35.4-3.8 4.6-7.5 8.4-7.5 8.4a9.7 9.7 0 00-6.3-1.6l-.2.2a7 7 0 000-1.8c2.6-3 1.8-3.5 7.2-4.7 5.4-1 2.2-3 .2-3.6-2.6-.6-8.4-1.5-12.4 4l.2-.6c-1.1-.9-1-6.1-1.2-9.8-.1-3.8-.4-7.3-1.1-9.3-.8-2-1.7-7.4-8.8-5.9-5.9 1.3-10.4 3.1-10.8 20.1"/><path d="M70 91c10-2.1 10.6 1.3 10.6 1.3-10.5 15.6-15 35.5-11.1 40.3 10.3 13.3 28.3 7 28.6 6.8h.1c2 .4 4.2.7 6.7.7 4.5 0 8-1.2 10.5-3.1 0 0 32 13.2 30.6-16.6-.4-6.4-9.1-48-19.6-35.4-3.8 4.6-7.5 8.4-7.5 8.4a9.7 9.7 0 00-6.3-1.6l-.2.2a7 7 0 000-1.8c2.6-3 1.8-3.5 7.2-4.7 5.5-1 2.2-3 .2-3.6-2.6-.6-8.4-1.5-12.4 4l.2-.6c-1.1-.9-1.8-5.5-1.7-9.7.1-4.3.2-7.2-.6-9.4s-1.7-7.4-8.8-5.9c-5.9 1.3-9 4.6-9.4 10-.3 4-1 3.4-1 6.9l-.6 1.6c-.6 5.3 0 7-3.7 6.2h-.9c-2.7-.2-6.2.4-8.3 1.3-4.4 2.1-7 5.5-2.7 4.6z"/><g stroke="#ffc836" stroke-linecap="round" stroke-width="15.6"><g stroke-linejoin="round"><path stroke-width="2.477124" d="M107 88c.2-10-.1-19.8-1-22.2-1-2.4-3-7.1-10.2-5.6-5.9 1.3-8 3.7-9 9.1-.7 4-2 15.1-2.2 17.4M115.5 137.2s32 13 30.5-16.7c-.3-6.4-9-48-19.5-35.4-3.8 4.6-7.3 8.2-7.3 8.2M98.1 139.6c1.2-.4-17.8 6.9-28.6-6.9-3.8-4.8.6-24.7 11.2-40.3"/></g><path stroke-linejoin="bevel" stroke-width="2.477124" d="M80.7 92.4S80 89 70 91c-4.4.9-1.7-2.6 2.7-4.6 3.7-1.7 11.8-2.2 12 .2.3 6-4.3 4.2-4 5.7.3 1.3 2.4 2.7 3.7 6 1.2 2.9 16.5 25.2-4.2 21.9-.7.1 5.4 19.6 24.7 20 19.4.3 18.7-23.8 18.7-23.8"/><path stroke-linejoin="round" stroke-width="2.477124" d="M112.4 90.3c2.7-3 1.9-3.5 7.3-4.6 5.4-1.2 2.2-3.2.1-3.7-2.5-.6-8.4-1.5-12.3 4-1.2 1.7 0 4.4 1.6 5.1.8.3 2 .8 3.3-.8z"/><path stroke-linejoin="round" stroke-width="2.477124" d="M112.6 90.4c.2 1.7-.6 3.8-1.5 6.3-1.4 3.7-4.6 7.4-2 19.1 1.8 
8.8 14.5 1.8 14.5.7 0-1.2-.5-6 .2-11.7 1-7.3-4.6-13.5-11.2-12.8"/></g></g><g stroke="#ffc836"><path fill="#ffc836" stroke-width=".825708" d="M115.6 116.6c0-.4-.8-1.4-1.8-1.6-1-.1-2 .7-2 1.1 0 .4.7.9 1.8 1 1 .2 2 0 2-.5z"/><path fill="#ffc836" stroke-width=".412854" d="M84 117.5c-.1-.4.7-1.5 1.7-1.7 1-.1 2 .7 2 1.1 0 .4-.7.9-1.8 1s-2 0-2-.4z"/><path fill="#336791" stroke-linecap="round" stroke-linejoin="round" stroke-width="2.477124" d="M80.2 120.3c-.2-3.2.7-5.4.8-8.7.2-5-2.3-10.6 1.4-16.2"/></g><g fill="#ffc836"><path d="M158 57c-10 2.1-10.6-1.3-10.6-1.3 10.5-15.6 15-35.5 11.1-40.3-10.3-13.3-28.3-7-28.6-6.8h-.1c-2-.4-4.2-.7-6.7-.7-4.5 0-8 1.2-10.5 3.1 0 0-32-13.2-30.6 16.6.4 6.4 9.1 48 19.6 35.4 3.8-4.6 7.5-8.5 7.5-8.5 1.8 1.3 4 1.9 6.3 1.7l.2-.2a7 7 0 000 1.8c-2.6 3-1.8 3.5-7.2 4.7-5.5 1-2.3 3-.2 3.6 2.6.6 8.4 1.5 12.4-4l-.2.6c1.1.9 1.8 5.5 1.7 9.7-.1 4.3-.2 7.2.6 9.4s1.7 7.4 8.8 5.9c5.9-1.3 9-4.6 9.4-10 .3-4 1-3.4 1-6.9l.6-1.6c.6-5.3 0-7 3.7-6.2h.9c2.7.2 6.2-.4 8.3-1.3 4.4-2.1 7-5.5 2.7-4.6z"/><path d="M142.3 67.6c.6-4.7.4-5.4 4-4.6h.8c2.7.2 6.2-.4 8.3-1.3 4.4-2.1 7-5.5 2.7-4.6-10 2-10.7-1.4-10.7-1.4 10.5-15.6 15-35.5 11.1-40.3-10.3-13.3-28.3-7-28.6-6.8h-.1c-2-.4-4.2-.7-6.7-.7-4.5 0-8 1.2-10.5 3.1 0 0-32-13.2-30.6 16.6.4 6.4 9.1 48 19.6 35.4 3.8-4.6 7.5-8.4 7.5-8.4 1.8 1.2 4 1.8 6.3 1.6l.2-.2a7 7 0 000 1.8c-2.6 3-1.8 3.5-7.2 4.7-5.4 1-2.2 3-.2 3.6 2.6.6 8.4 1.5 12.4-4l-.2.6c1.1.9 1 6.1 1.2 9.8.1 3.8.4 7.3 1.1 9.3.8 2 1.7 7.4 8.8 5.9 5.9-1.3 10.4-3.1 10.8-20.1"/><g stroke="#336791" stroke-linecap="round" stroke-width="15.6"><g stroke-linejoin="round"><path stroke-width="2.477124" d="M112.5 10.8s-32-13-30.5 16.7c.3 6.4 9 48 19.5 35.4 3.8-4.6 7.3-8.2 7.3-8.2M121 60c-.2 10 .1 19.8 1 22.2 1 2.4 3 7.1 10.2 5.6 5.9-1.3 8-3.7 9-9.2.7-4 2-15 2.1-17.3M129.9 8.4c-1.2.4 17.8-6.9 28.6 6.9 3.8 4.8-.6 24.7-11.2 40.3"/></g><path stroke-linejoin="bevel" stroke-width="2.477124" d="M147.3 55.6S148 59 158 57c4.4-.9 1.7 2.6-2.7 4.6-3.7 1.7-11.8 2.2-12-.2-.3-6 4.3-4.2 4-5.7-.3-1.3-2.4-2.7-3.7-6-1.2-3-16.5-25.2 4.2-21.9.7-.1-5.4-19.6-24.7-20-19.4-.3-18.7 23.8-18.7 23.8"/><path stroke-linejoin="round" stroke-width="2.477124" d="M115.6 57.7c-2.7 3-1.9 3.5-7.3 4.6-5.4 1.2-2.2 3.2-.1 3.7 2.5.6 8.4 1.5 12.3-4 1.2-1.7 0-4.4-1.6-5.1-.8-.3-2-.8-3.3.8z"/><path stroke-linejoin="round" stroke-width="2.477124" d="M115.4 57.6c-.2-1.7.6-3.8 1.5-6.3 1.4-3.7 4.6-7.4 2-19.1-1.8-8.8-14.5-1.9-14.5-.7s.5 6-.2 11.7c-1 7.3 4.6 13.5 11.2 12.8"/></g></g><g stroke="#336791"><path fill="#336791" stroke-width=".825708" d="M112.4 31.4c0 .4.8 1.4 1.8 1.6 1 .1 2-.7 2-1.1 0-.4-.7-.9-1.8-1-1-.2-2 0-2 .5z"/><path fill="#336791" stroke-width=".412854" d="M144 30.5c.1.4-.7 1.5-1.7 1.7-1 .1-2-.7-2-1.1 0-.4.7-.9 1.8-1s2 0 2 .4z"/><path fill="#ffc836" stroke-linecap="round" stroke-linejoin="round" stroke-width="2.477124" d="M147.8 27.7c.2 3.2-.7 5.4-.8 8.7-.2 5 2.3 10.6-1.4 16.2"/></g><path fill="#ffc836" stroke="#336791" stroke-linecap="round" stroke-linejoin="round" stroke-width=".6034019999999999" d="M103.8 51h6.6v6.4h-6.6z" color="#000"/><path fill="#336791" stroke="#ffc836" stroke-linecap="round" stroke-linejoin="round" stroke-width="2.477124" d="M107 88c.2-10-.1-19.8-1-22.2-1-2.4-3-7.1-10.2-5.6-5.9 1.3-8 3.7-9 9.1-.7 4-2 15.1-2.2 17.4"/><path fill="#336791" d="M111.7 82.9h22.1v14.4h-22.1z" color="#000"/><path fill="#ffc836" d="M95.8 56h20.1v9.2H95.8z" color="#000"/><g fill="none"><path stroke="#ffc836" stroke-width="1.5878999999999999" d="M113.7 47.6c-2.2 0-4.2.2-6 .5-5.3 1-6.3 3-6.3 6.5v4.8H114V61H96.7a7.8 7.8 0 00-7.8 6.3A23.4 23.4 0 0089 
80c1 3.7 3 6.4 6.7 6.4h4.3v-5.7a8 8 0 017.8-7.8h12.5c3.5 0 6.2-2.9 6.2-6.4V54.6c0-3.4-2.8-5.9-6.2-6.5a39 39 0 00-6.5-.5z"/><path stroke="#336791" stroke-width="2.38185" d="M128 61v5.5a8 8 0 01-7.8 8h-12.5c-3.4 0-6.3 2.9-6.3 6.3v12c0 3.3 3 5.3 6.3 6.3a21 21 0 0012.5 0c3.1-1 6.2-2.8 6.2-6.4V88H114v-1.6h18.8c3.6 0 5-2.6 6.2-6.4 1.3-3.9 1.3-7.6 0-12.7-.9-3.6-2.6-6.3-6.2-6.3z"/><path stroke="#ffc836" stroke-width="2.38185" d="M113.7 47.6c-2.2 0-4.2.2-6 .5-5.3 1-6.3 3-6.3 6.5v4.8H114V61H96.7a7.8 7.8 0 00-7.8 6.3A23.4 23.4 0 0089 80c1 3.7 3 6.4 6.7 6.4h4.3v-5.7a8 8 0 017.8-7.8h12.5c3.5 0 6.2-2.9 6.2-6.4V54.6c0-3.4-2.8-5.9-6.2-6.5a39 39 0 00-6.5-.5z"/><path stroke="#336791" stroke-width="1.5878999999999999" d="M128 61v5.5a8 8 0 01-7.8 8h-12.5c-3.4 0-6.3 2.9-6.3 6.3v12c0 3.3 3 5.3 6.3 6.3a21 21 0 0012.5 0c3.1-1 6.2-2.8 6.2-6.4V88H114v-1.6h18.8c3.6 0 5-2.6 6.2-6.4 1.3-3.9 1.3-7.6 0-12.7-.9-3.6-2.6-6.3-6.2-6.3z"/></g><g><path fill="#336791" d="M113.7 47.6c-2.2 0-4.2.2-6 .5-5.3 1-6.3 3-6.3 6.5v4.8H114V61H96.7a7.8 7.8 0 00-7.8 6.3A23.4 23.4 0 0089 80c1 3.7 3 6.4 6.7 6.4h4.3v-5.7a8 8 0 017.8-7.8h12.5c3.5 0 6.2-2.9 6.2-6.4V54.6c0-3.4-2.8-5.9-6.2-6.5a39 39 0 00-6.5-.5zm-6.8 3.9c1.3 0 2.3 1 2.3 2.4 0 1.3-1 2.3-2.3 2.3-1.3 0-2.3-1-2.3-2.3 0-1.4 1-2.4 2.3-2.4z"/><path fill="#ffc836" d="M128 61v5.5a8 8 0 01-7.8 8h-12.5c-3.4 0-6.3 2.9-6.3 6.3v12c0 3.3 3 5.3 6.3 6.3a21 21 0 0012.5 0c3.1-1 6.2-2.8 6.2-6.4V88H114v-1.6h18.8c3.6 0 5-2.6 6.2-6.4 1.3-3.9 1.3-7.6 0-12.7-.9-3.6-2.6-6.3-6.2-6.3zM121 91c1.3 0 2.3 1.1 2.3 2.4 0 1.3-1 2.4-2.3 2.4-1.3 0-2.4-1-2.4-2.4 0-1.3 1-2.4 2.4-2.4z" color="#000"/><path fill="#336791" d="M127.2 59.8h.7v.6h-.7z" color="#000"/></g></svg>
\ No newline at end of file diff --git a/docs/_templates/.keep b/docs/_templates/.keep new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/docs/_templates/.keep diff --git a/docs/advanced/adapt.rst b/docs/advanced/adapt.rst new file mode 100644 index 0000000..4323b07 --- /dev/null +++ b/docs/advanced/adapt.rst @@ -0,0 +1,269 @@ +.. currentmodule:: psycopg.adapt + +.. _adaptation: + +Data adaptation configuration +============================= + +The adaptation system is at the core of Psycopg and allows to customise the +way Python objects are converted to PostgreSQL when a query is performed and +how PostgreSQL values are converted to Python objects when query results are +returned. + +.. note:: + For a high-level view of the conversion of types between Python and + PostgreSQL please look at :ref:`query-parameters`. Using the objects + described in this page is useful if you intend to *customise* the + adaptation rules. + +- Adaptation configuration is performed by changing the + `~psycopg.abc.AdaptContext.adapters` object of objects implementing the + `~psycopg.abc.AdaptContext` protocol, for instance `~psycopg.Connection` + or `~psycopg.Cursor`. + +- Every context object derived from another context inherits its adapters + mapping: cursors created from a connection inherit the connection's + configuration. + + By default, connections obtain an adapters map from the global map + exposed as `psycopg.adapters`: changing the content of this object will + affect every connection created afterwards. You may specify a different + template adapters map using the `!context` parameter on + `~psycopg.Connection.connect()`. + + .. image:: ../pictures/adapt.svg + :align: center + +- The `!adapters` attributes are `AdaptersMap` instances, and contain the + mapping from Python types and `~psycopg.abc.Dumper` classes, and from + PostgreSQL OIDs to `~psycopg.abc.Loader` classes. Changing this mapping + (e.g. writing and registering your own adapters, or using a different + configuration of builtin adapters) affects how types are converted between + Python and PostgreSQL. + + - Dumpers (objects implementing the `~psycopg.abc.Dumper` protocol) are + the objects used to perform the conversion from a Python object to a bytes + sequence in a format understood by PostgreSQL. The string returned + *shouldn't be quoted*: the value will be passed to the database using + functions such as :pq:`PQexecParams()` so quoting and quotes escaping is + not necessary. The dumper usually also suggests to the server what type to + use, via its `~psycopg.abc.Dumper.oid` attribute. + + - Loaders (objects implementing the `~psycopg.abc.Loader` protocol) are + the objects used to perform the opposite operation: reading a bytes + sequence from PostgreSQL and creating a Python object out of it. + + - Dumpers and loaders are instantiated on demand by a `~Transformer` object + when a query is executed. + +.. note:: + Changing adapters in a context only affects that context and its children + objects created *afterwards*; the objects already created are not + affected. For instance, changing the global context will only change newly + created connections, not the ones already existing. + + +.. _adapt-example-xml: + +Writing a custom adapter: XML +----------------------------- + +Psycopg doesn't provide adapters for the XML data type, because there are just +too many ways of handling XML in Python. 
Creating a loader to parse the +`PostgreSQL xml type`__ to `~xml.etree.ElementTree` is very simple, using the +`psycopg.adapt.Loader` base class and implementing the +`~psycopg.abc.Loader.load()` method: + +.. __: https://www.postgresql.org/docs/current/datatype-xml.html + +.. code:: python + + >>> import xml.etree.ElementTree as ET + >>> from psycopg.adapt import Loader + + >>> # Create a class implementing the `load()` method. + >>> class XmlLoader(Loader): + ... def load(self, data): + ... return ET.fromstring(data) + + >>> # Register the loader on the adapters of a context. + >>> conn.adapters.register_loader("xml", XmlLoader) + + >>> # Now just query the database returning XML data. + >>> cur = conn.execute( + ... """select XMLPARSE (DOCUMENT '<?xml version="1.0"?> + ... <book><title>Manual</title><chapter>...</chapter></book>') + ... """) + + >>> elem = cur.fetchone()[0] + >>> elem + <Element 'book' at 0x7ffb55142ef0> + +The opposite operation, converting Python objects to PostgreSQL, is performed +by dumpers. The `psycopg.adapt.Dumper` base class makes it easy to implement one: +you only need to implement the `~psycopg.abc.Dumper.dump()` method:: + + >>> from psycopg.adapt import Dumper + + >>> class XmlDumper(Dumper): + ... # Setting an OID is not necessary but can be helpful + ... oid = psycopg.adapters.types["xml"].oid + ... + ... def dump(self, elem): + ... return ET.tostring(elem) + + >>> # Register the dumper on the adapters of a context + >>> conn.adapters.register_dumper(ET.Element, XmlDumper) + + >>> # Now, in that context, it is possible to use ET.Element objects as parameters + >>> conn.execute("SELECT xpath('//title/text()', %s)", [elem]).fetchone()[0] + ['Manual'] + +Note that it is possible to use a `~psycopg.types.TypesRegistry`, exposed by +any `~psycopg.abc.AdaptContext`, to obtain information on builtin types, or +extension types if they have been registered on that context using the +`~psycopg.types.TypeInfo`\.\ `~psycopg.types.TypeInfo.register()` method. + + +.. _adapt-example-float: + +Example: PostgreSQL numeric to Python float +------------------------------------------- + +Normally PostgreSQL :sql:`numeric` values are converted to Python +`~decimal.Decimal` instances, because both the types allow fixed-precision +arithmetic and are not subject to rounding. + +Sometimes, however, you may want to perform floating-point math on +:sql:`numeric` values, and `!Decimal` may get in the way (maybe because it is +slower, or maybe because mixing `!float` and `!Decimal` values causes Python +errors). + +If you are fine with the potential loss of precision and you simply want to +receive :sql:`numeric` values as Python `!float`, you can register on +:sql:`numeric` the same `Loader` class used to load +:sql:`float4`\/:sql:`float8` values. Because the PostgreSQL textual +representation of both floats and decimal is the same, the two loaders are +compatible. + +.. code:: python + + conn = psycopg.connect() + + conn.execute("SELECT 123.45").fetchone()[0] + # Decimal('123.45') + + conn.adapters.register_loader("numeric", psycopg.types.numeric.FloatLoader) + + conn.execute("SELECT 123.45").fetchone()[0] + # 123.45 + +In this example the customised adaptation takes effect only on the connection +`!conn` and on any cursor created from it, not on other connections. + + +.. 
_adapt-example-inf-date: + +Example: handling infinity date +------------------------------- + +Suppose you want to work with the "infinity" date which is available in +PostgreSQL but not handled by Python: + +.. code:: python + + >>> conn.execute("SELECT 'infinity'::date").fetchone() + Traceback (most recent call last): + ... + DataError: date too large (after year 10K): 'infinity' + +One possibility would be to store Python's `datetime.date.max` as PostgreSQL +infinity. For this, let's create a subclass for the dumper and the loader and +register them in the working scope (globally or just on a connection or +cursor): + +.. code:: python + + from datetime import date + + # Subclass existing adapters so that the base case is handled normally. + from psycopg.types.datetime import DateLoader, DateDumper + + class InfDateDumper(DateDumper): + def dump(self, obj): + if obj == date.max: + return b"infinity" + elif obj == date.min: + return b"-infinity" + else: + return super().dump(obj) + + class InfDateLoader(DateLoader): + def load(self, data): + if data == b"infinity": + return date.max + elif data == b"-infinity": + return date.min + else: + return super().load(data) + + # The new classes can be registered globally, on a connection, on a cursor + cur.adapters.register_dumper(date, InfDateDumper) + cur.adapters.register_loader("date", InfDateLoader) + + cur.execute("SELECT %s::text, %s::text", [date(2020, 12, 31), date.max]).fetchone() + # ('2020-12-31', 'infinity') + cur.execute("SELECT '2020-12-31'::date, 'infinity'::date").fetchone() + # (datetime.date(2020, 12, 31), datetime.date(9999, 12, 31)) + + +Dumpers and loaders life cycle +------------------------------ + +Registering dumpers and loaders will instruct Psycopg to use them +in the queries to follow, in the context where they have been registered. + +When a query is performed on a `~psycopg.Cursor`, a +`~psycopg.adapt.Transformer` object is created as a local context to manage +adaptation during the query, instantiating the required dumpers and loaders +and dispatching the values to perform the wanted conversions from Python to +Postgres and back. + +- The `!Transformer` copies the adapters configuration from the `!Cursor`, + thus inheriting all the changes made to the global `psycopg.adapters` + configuration, the current `!Connection`, the `!Cursor`. + +- For every Python type passed as query argument, the `!Transformer` will + instantiate a `!Dumper`. Usually all the objects of the same type will be + converted by the same dumper instance. + + - According to the placeholder used (``%s``, ``%b``, ``%t``), Psycopg may + pick a binary or a text dumper. When using the ``%s`` "`~PyFormat.AUTO`" + format, if the same type has both a text and a binary dumper registered, + the last one registered by `~AdaptersMap.register_dumper()` will be used. + + - Sometimes, just looking at the Python type is not enough to decide the + best PostgreSQL type to use (for instance the PostgreSQL type of a Python + list depends on the objects it contains, whether to use an :sql:`integer` + or :sql:`bigint` depends on the number size...) In these cases the + mechanism provided by `~psycopg.abc.Dumper.get_key()` and + `~psycopg.abc.Dumper.upgrade()` is used to create more specific dumpers. + +- The query is executed. Upon successful request, the result is received as a + `~psycopg.pq.PGresult`. + +- For every OID returned by the query, the `!Transformer` will instantiate a + `!Loader`. 
All the values with the same OID will be converted by the same + loader instance. + +- Recursive types (e.g. Python lists, PostgreSQL arrays and composite types) + will use the same adaptation rules. + +As a consequence it is possible to perform certain choices only once per query +(e.g. looking up the connection encoding) and then call a fast-path operation +for each value to convert. + +Querying will fail if a Python object for which there isn't a `!Dumper` +registered (for the right `~psycopg.pq.Format`) is used as query parameter. +If the query returns a data type whose OID doesn't have a `!Loader`, the +value will be returned as a string (or bytes string for binary types). diff --git a/docs/advanced/async.rst b/docs/advanced/async.rst new file mode 100644 index 0000000..3620ab6 --- /dev/null +++ b/docs/advanced/async.rst @@ -0,0 +1,360 @@ +.. currentmodule:: psycopg + +.. index:: asyncio + +.. _async: + +Asynchronous operations +======================= + +Psycopg `~Connection` and `~Cursor` have counterparts `~AsyncConnection` and +`~AsyncCursor` supporting an `asyncio` interface. + +The design of the asynchronous objects is pretty much the same of the sync +ones: in order to use them you will only have to scatter the `!await` keyword +here and there. + +.. code:: python + + async with await psycopg.AsyncConnection.connect( + "dbname=test user=postgres") as aconn: + async with aconn.cursor() as acur: + await acur.execute( + "INSERT INTO test (num, data) VALUES (%s, %s)", + (100, "abc'def")) + await acur.execute("SELECT * FROM test") + await acur.fetchone() + # will return (1, 100, "abc'def") + async for record in acur: + print(record) + +.. versionchanged:: 3.1 + + `AsyncConnection.connect()` performs DNS name resolution in a non-blocking + way. + + .. warning:: + + Before version 3.1, `AsyncConnection.connect()` may still block on DNS + name resolution. To avoid that you should `set the hostaddr connection + parameter`__, or use the `~psycopg._dns.resolve_hostaddr_async()` to + do it automatically. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-PARAMKEYWORDS + +.. warning:: + + On Windows, Psycopg is not compatible with the default + `~asyncio.ProactorEventLoop`. Please use a different loop, for instance + the `~asyncio.SelectorEventLoop`. + + For instance, you can use, early in your program: + + .. parsed-literal:: + + `asyncio.set_event_loop_policy`\ ( + `asyncio.WindowsSelectorEventLoopPolicy`\ () + ) + + + +.. index:: with + +.. _async-with: + +`!with` async connections +------------------------- + +As seen in :ref:`the basic usage <usage>`, connections and cursors can act as +context managers, so you can run: + +.. code:: python + + with psycopg.connect("dbname=test user=postgres") as conn: + with conn.cursor() as cur: + cur.execute(...) + # the cursor is closed upon leaving the context + # the transaction is committed, the connection closed + +For asynchronous connections it's *almost* what you'd expect, but +not quite. Please note that `~Connection.connect()` and `~Connection.cursor()` +*don't return a context*: they are both factory methods which return *an +object which can be used as a context*. That's because there are several use +cases where it's useful to handle the objects manually and only `!close()` them +when required. + +As a consequence you cannot use `!async with connect()`: you have to do it in +two steps instead, as in + +.. 
code:: python
+
+    aconn = await psycopg.AsyncConnection.connect()
+    async with aconn:
+        async with aconn.cursor() as cur:
+            await cur.execute(...)
+
+which can be condensed into `!async with await`:
+
+.. code:: python
+
+    async with await psycopg.AsyncConnection.connect() as aconn:
+        async with aconn.cursor() as cur:
+            await cur.execute(...)
+
+...but no less than that: you still need to do the double async thing.
+
+Note that the `AsyncConnection.cursor()` function is not an `!async` function
+(it never performs I/O), so you don't need an `!await` on it; as a consequence
+you can use the normal `async with` context manager.
+
+
+.. index:: Ctrl-C
+
+.. _async-ctrl-c:
+
+Interrupting async operations using Ctrl-C
+------------------------------------------
+
+If a long-running operation is interrupted by a Ctrl-C on a normal connection
+running in the main thread, the operation will be cancelled and the connection
+will be put in error state, from which it can be recovered with a normal
+`~Connection.rollback()`.
+
+If the query is running in an async connection, a Ctrl-C will likely be
+intercepted by the async loop and interrupt the whole program. In order to
+emulate what normally happens with blocking connections, you can use
+`asyncio's add_signal_handler()`__ to call `Connection.cancel()`:
+
+.. code:: python
+
+    import asyncio
+    import signal
+
+    async with await psycopg.AsyncConnection.connect() as conn:
+        # Register a handler on the running loop so that Ctrl-C cancels the
+        # current query instead of killing the whole program.
+        loop = asyncio.get_running_loop()
+        loop.add_signal_handler(signal.SIGINT, conn.cancel)
+        ...
+
+
+.. __: https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.add_signal_handler
+
+
+.. index::
+    pair: Asynchronous; Notifications
+    pair: LISTEN; SQL command
+    pair: NOTIFY; SQL command
+
+.. _async-messages:
+
+Server messages
+---------------
+
+PostgreSQL can send, together with the query results, `informative messages`__
+about the operation just performed, such as warnings or debug information.
+Notices may be raised even if the operations are successful and don't indicate
+an error. You are probably familiar with some of them, because they are
+reported by :program:`psql`::
+
+    $ psql
+    =# ROLLBACK;
+    WARNING: there is no transaction in progress
+    ROLLBACK
+
+.. __: https://www.postgresql.org/docs/current/runtime-config-logging.html
+       #RUNTIME-CONFIG-SEVERITY-LEVELS
+
+Messages can also be sent by the `PL/pgSQL 'RAISE' statement`__ (at a level
+lower than EXCEPTION, otherwise the appropriate `DatabaseError` will be
+raised). The level of the messages received can be controlled using the
+client_min_messages__ setting.
+
+.. __: https://www.postgresql.org/docs/current/plpgsql-errors-and-messages.html
+.. __: https://www.postgresql.org/docs/current/runtime-config-client.html
+       #GUC-CLIENT-MIN-MESSAGES
+
+
+By default, the messages received are ignored. If you want to process them on
+the client you can use the `Connection.add_notice_handler()` function to
+register a function that will be invoked whenever a message is received. The
+message is passed to the callback as a `~errors.Diagnostic` instance,
+containing all the information passed by the server, such as the message text
+and the severity. The object is the same as the one found on the
+`~psycopg.Error.diag` attribute of the errors raised by the server:
+
+.. code:: python
+
+    >>> import psycopg
+
+    >>> def log_notice(diag):
+    ...     
print(f"The server says: {diag.severity} - {diag.message_primary}") + + >>> conn = psycopg.connect(autocommit=True) + >>> conn.add_notice_handler(log_notice) + + >>> cur = conn.execute("ROLLBACK") + The server says: WARNING - there is no transaction in progress + >>> print(cur.statusmessage) + ROLLBACK + +.. warning:: + + The `!Diagnostic` object received by the callback should not be used after + the callback function terminates, because its data is deallocated after + the callbacks have been processed. If you need to use the information + later please extract the attributes requested and forward them instead of + forwarding the whole `!Diagnostic` object. + + +.. index:: + pair: Asynchronous; Notifications + pair: LISTEN; SQL command + pair: NOTIFY; SQL command + +.. _async-notify: + +Asynchronous notifications +-------------------------- + +Psycopg allows asynchronous interaction with other database sessions using the +facilities offered by PostgreSQL commands |LISTEN|_ and |NOTIFY|_. Please +refer to the PostgreSQL documentation for examples about how to use this form +of communication. + +.. |LISTEN| replace:: :sql:`LISTEN` +.. _LISTEN: https://www.postgresql.org/docs/current/sql-listen.html +.. |NOTIFY| replace:: :sql:`NOTIFY` +.. _NOTIFY: https://www.postgresql.org/docs/current/sql-notify.html + +Because of the way sessions interact with notifications (see |NOTIFY|_ +documentation), you should keep the connection in `~Connection.autocommit` +mode if you wish to receive or send notifications in a timely manner. + +Notifications are received as instances of `Notify`. If you are reserving a +connection only to receive notifications, the simplest way is to consume the +`Connection.notifies` generator. The generator can be stopped using +`!close()`. + +.. note:: + + You don't need an `AsyncConnection` to handle notifications: a normal + blocking `Connection` is perfectly valid. + +The following example will print notifications and stop when one containing +the ``"stop"`` message is received. + +.. code:: python + + import psycopg + conn = psycopg.connect("", autocommit=True) + conn.execute("LISTEN mychan") + gen = conn.notifies() + for notify in gen: + print(notify) + if notify.payload == "stop": + gen.close() + print("there, I stopped") + +If you run some :sql:`NOTIFY` in a :program:`psql` session: + +.. code:: psql + + =# NOTIFY mychan, 'hello'; + NOTIFY + =# NOTIFY mychan, 'hey'; + NOTIFY + =# NOTIFY mychan, 'stop'; + NOTIFY + +You may get output from the Python process such as:: + + Notify(channel='mychan', payload='hello', pid=961823) + Notify(channel='mychan', payload='hey', pid=961823) + Notify(channel='mychan', payload='stop', pid=961823) + there, I stopped + +Alternatively, you can use `~Connection.add_notify_handler()` to register a +callback function, which will be invoked whenever a notification is received, +during the normal query processing; you will be then able to use the +connection normally. Please note that in this case notifications will not be +received immediately, but only during a connection operation, such as a query. + +.. code:: python + + conn.add_notify_handler(lambda n: print(f"got this: {n}")) + + # meanwhile in psql... + # =# NOTIFY mychan, 'hey'; + # NOTIFY + + print(conn.execute("SELECT 1").fetchone()) + # got this: Notify(channel='mychan', payload='hey', pid=961823) + # (1,) + + +.. index:: disconnections + +.. 
_disconnections:
+
+Detecting disconnections
+------------------------
+
+Sometimes it is useful to detect immediately when the connection with the
+database is lost. One brutal way to do so is to poll a connection in a loop
+running an endless stream of :sql:`SELECT 1`... *Don't* do so: polling is *so*
+out of fashion. Besides, it is inefficient (unless what you really want is a
+client-server generator of ones), it generates useless traffic and will only
+detect a disconnection with an average delay of half the polling time.
+
+A more efficient and timely way to detect a server disconnection is to create
+an additional connection and wait for a notification from the OS that this
+connection has something to say: only then can you run some checks. You
+can dedicate a thread (or an asyncio task) to wait on this connection: such a
+thread will perform no activity until awoken by the OS.
+
+In a normal (non-asyncio) program you can use the `selectors` module. Because
+the `!Connection` implements a `~Connection.fileno()` method you can just
+register it as a file-like object. You can run such code in a dedicated thread
+(and using a dedicated connection) if the rest of the program happens to have
+something else to do too.
+
+.. code:: python
+
+    import sys
+    import logging
+    import selectors
+
+    logger = logging.getLogger(__name__)
+
+    # "conn" is a connection reserved for monitoring the server.
+    sel = selectors.DefaultSelector()
+    sel.register(conn, selectors.EVENT_READ)
+    while True:
+        if not sel.select(timeout=60.0):
+            continue  # No FD activity detected in one minute
+
+        # Activity detected. Is the connection still ok?
+        try:
+            conn.execute("SELECT 1")
+        except psycopg.OperationalError:
+            # You were disconnected: do something useful such as panicking
+            logger.error("we lost our database!")
+            sys.exit(1)
+
+In an `asyncio` program you can dedicate a `~asyncio.Task` instead and do
+something similar using `~asyncio.loop.add_reader`:
+
+.. code:: python
+
+    import asyncio
+
+    ev = asyncio.Event()
+    loop = asyncio.get_event_loop()
+    loop.add_reader(conn.fileno(), ev.set)
+
+    while True:
+        try:
+            await asyncio.wait_for(ev.wait(), 60.0)
+        except asyncio.TimeoutError:
+            continue  # No FD activity detected in one minute
+
+        # Activity detected. Is the connection still ok?
+        try:
+            await conn.execute("SELECT 1")
+        except psycopg.OperationalError:
+            # Guess what happened
+            ...
diff --git a/docs/advanced/cursors.rst b/docs/advanced/cursors.rst
new file mode 100644
index 0000000..954d665
--- /dev/null
+++ b/docs/advanced/cursors.rst
@@ -0,0 +1,192 @@
+.. currentmodule:: psycopg
+
+.. index::
+    single: Cursor
+
+.. _cursor-types:
+
+Cursor types
+============
+
+Psycopg can manage different kinds of "cursors", which differ in where the
+state of a query being processed is stored: :ref:`client-side-cursors` and
+:ref:`server-side-cursors`.
+
+.. index::
+    double: Cursor; Client-side
+
+.. _client-side-cursors:
+
+Client-side cursors
+-------------------
+
+Client-side cursors are what Psycopg uses in its normal querying process.
+They are implemented by the `Cursor` and `AsyncCursor` classes. In this
+querying pattern, after a cursor sends a query to the server (usually calling
+`~Cursor.execute()`), the server replies by transferring the whole set of
+requested results to the client. They are stored in the state of the same
+cursor, from which they can be read by Python code (using methods such as
+`~Cursor.fetchone()` and siblings).
+
+This querying process is very scalable because, after a query result has been
+transmitted to the client, the server doesn't keep any state. 
Because the +results are already in the client memory, iterating its rows is very quick. + +The downside of this querying method is that the entire result has to be +transmitted completely to the client (with a time proportional to its size) +and the client needs enough memory to hold it, so it is only suitable for +reasonably small result sets. + + +.. index:: + double: Cursor; Client-binding + +.. _client-side-binding-cursors: + +Client-side-binding cursors +--------------------------- + +.. versionadded:: 3.1 + +The previously described :ref:`client-side cursors <client-side-cursors>` send +the query and the parameters separately to the server. This is the most +efficient way to process parametrised queries and allows to build several +features and optimizations. However, not all types of queries can be bound +server-side; in particular no Data Definition Language query can. See +:ref:`server-side-binding` for the description of these problems. + +The `ClientCursor` (and its `AsyncClientCursor` async counterpart) merge the +query on the client and send the query and the parameters merged together to +the server. This allows to parametrize any type of PostgreSQL statement, not +only queries (:sql:`SELECT`) and Data Manipulation statements (:sql:`INSERT`, +:sql:`UPDATE`, :sql:`DELETE`). + +Using `!ClientCursor`, Psycopg 3 behaviour will be more similar to `psycopg2` +(which only implements client-side binding) and could be useful to port +Psycopg 2 programs more easily to Psycopg 3. The objects in the `sql` module +allow for greater flexibility (for instance to parametrize a table name too, +not only values); however, for simple cases, a `!ClientCursor` could be the +right object. + +In order to obtain `!ClientCursor` from a connection, you can set its +`~Connection.cursor_factory` (at init time or changing its attribute +afterwards): + +.. code:: python + + from psycopg import connect, ClientCursor + + conn = psycopg.connect(DSN, cursor_factory=ClientCursor) + cur = conn.cursor() + # <psycopg.ClientCursor [no result] [IDLE] (database=piro) at 0x7fd977ae2880> + +If you need to create a one-off client-side-binding cursor out of a normal +connection, you can just use the `~ClientCursor` class passing the connection +as argument. + +.. code:: python + + conn = psycopg.connect(DSN) + cur = psycopg.ClientCursor(conn) + +.. warning:: + + Client-side cursors don't support :ref:`binary parameters and return + values <binary-data>` and don't support :ref:`prepared statements + <prepared-statements>`. + +.. tip:: + + The best use for client-side binding cursors is probably to port large + Psycopg 2 code to Psycopg 3, especially for programs making wide use of + Data Definition Language statements. + + The `psycopg.sql` module allows for more generic client-side query + composition, to mix client- and server-side parameters binding, and allows + to parametrize tables and fields names too, or entirely generic SQL + snippets. + +.. index:: + double: Cursor; Server-side + single: Portal + double: Cursor; Named + +.. _server-side-cursors: + +Server-side cursors +------------------- + +PostgreSQL has its own concept of *cursor* too (sometimes also called +*portal*). When a database cursor is created, the query is not necessarily +completely processed: the server might be able to produce results only as they +are needed. Only the results requested are transmitted to the client: if the +query result is very large but the client only needs the first few records it +is possible to transmit only them. 
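+For instance, a minimal sketch (assuming a large ``records`` table and a
+hypothetical ``process()`` function; the `!name` parameter, described below,
+is what makes the cursor server-side):
+
+.. code:: python
+
+    # A named (server-side) cursor: rows are fetched in batches while
+    # iterating, instead of being transferred to the client all at once.
+    with conn.cursor(name="big_scan") as cur:
+        cur.execute("SELECT * FROM records")
+        for record in cur:
+            process(record)  # hypothetical per-row processing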
+ +The downside is that the server needs to keep track of the partially +processed results, so it uses more memory and resources on the server. + +Psycopg allows the use of server-side cursors using the classes `ServerCursor` +and `AsyncServerCursor`. They are usually created by passing the `!name` +parameter to the `~Connection.cursor()` method (reason for which, in +`!psycopg2`, they are usually called *named cursors*). The use of these classes +is similar to their client-side counterparts: their interface is the same, but +behind the scene they send commands to control the state of the cursor on the +server (for instance when fetching new records or when moving using +`~Cursor.scroll()`). + +Using a server-side cursor it is possible to process datasets larger than what +would fit in the client's memory. However for small queries they are less +efficient because it takes more commands to receive their result, so you +should use them only if you need to process huge results or if only a partial +result is needed. + +.. seealso:: + + Server-side cursors are created and managed by `ServerCursor` using SQL + commands such as DECLARE_, FETCH_, MOVE_. The PostgreSQL documentation + gives a good idea of what is possible to do with them. + + .. _DECLARE: https://www.postgresql.org/docs/current/sql-declare.html + .. _FETCH: https://www.postgresql.org/docs/current/sql-fetch.html + .. _MOVE: https://www.postgresql.org/docs/current/sql-move.html + + +.. _cursor-steal: + +"Stealing" an existing cursor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A Psycopg `ServerCursor` can be also used to consume a cursor which was +created in other ways than the :sql:`DECLARE` that `ServerCursor.execute()` +runs behind the scene. + +For instance if you have a `PL/pgSQL function returning a cursor`__: + +.. __: https://www.postgresql.org/docs/current/plpgsql-cursors.html + +.. code:: postgres + + CREATE FUNCTION reffunc(refcursor) RETURNS refcursor AS $$ + BEGIN + OPEN $1 FOR SELECT col FROM test; + RETURN $1; + END; + $$ LANGUAGE plpgsql; + +you can run a one-off command in the same connection to call it (e.g. using +`Connection.execute()`) in order to create the cursor on the server: + +.. code:: python + + conn.execute("SELECT reffunc('curname')") + +after which you can create a server-side cursor declared by the same name, and +directly call the fetch methods, skipping the `~ServerCursor.execute()` call: + +.. code:: python + + cur = conn.cursor('curname') + # no cur.execute() + for record in cur: # or cur.fetchone(), cur.fetchmany()... + # do something with record diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst new file mode 100644 index 0000000..6920bd7 --- /dev/null +++ b/docs/advanced/index.rst @@ -0,0 +1,21 @@ +.. _advanced: + +More advanced topics +==================== + +Once you have familiarised yourself with the :ref:`Psycopg basic operations +<basic>`, you can take a look at the chapter of this section for more advanced +usages. + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + async + typing + rows + pool + cursors + adapt + prepare + pipeline diff --git a/docs/advanced/pipeline.rst b/docs/advanced/pipeline.rst new file mode 100644 index 0000000..980fea7 --- /dev/null +++ b/docs/advanced/pipeline.rst @@ -0,0 +1,324 @@ +.. currentmodule:: psycopg + +.. _pipeline-mode: + +Pipeline mode support +===================== + +.. versionadded:: 3.1 + +The *pipeline mode* allows PostgreSQL client applications to send a query +without having to read the result of the previously sent query. 
Taking +advantage of the pipeline mode, a client will wait less for the server, since +multiple queries/results can be sent/received in a single network roundtrip. +Pipeline mode can provide a significant performance boost to the application. + +Pipeline mode is most useful when the server is distant, i.e., network latency +(“ping time”) is high, and also when many small operations are being performed +in rapid succession. There is usually less benefit in using pipelined commands +when each query takes many multiples of the client/server round-trip time to +execute. A 100-statement operation run on a server 300 ms round-trip-time away +would take 30 seconds in network latency alone without pipelining; with +pipelining it may spend as little as 0.3 s waiting for results from the +server. + +The server executes statements, and returns results, in the order the client +sends them. The server will begin executing the commands in the pipeline +immediately, not waiting for the end of the pipeline. Note that results are +buffered on the server side; the server flushes that buffer when a +:ref:`synchronization point <pipeline-sync>` is established. + +.. seealso:: + + The PostgreSQL documentation about: + + - `pipeline mode`__ + - `extended query message flow`__ + + contains many details around when it is most useful to use the pipeline + mode and about errors management and interaction with transactions. + + .. __: https://www.postgresql.org/docs/current/libpq-pipeline-mode.html + .. __: https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY + + +Client-server messages flow +--------------------------- + +In order to understand better how the pipeline mode works, we should take a +closer look at the `PostgreSQL client-server message flow`__. + +During normal querying, each statement is transmitted by the client to the +server as a stream of request messages, terminating with a **Sync** message to +tell it that it should process the messages sent so far. The server will +execute the statement and describe the results back as a stream of messages, +terminating with a **ReadyForQuery**, telling the client that it may now send a +new query. + +For example, the statement (returning no result): + +.. code:: python + + conn.execute("INSERT INTO mytable (data) VALUES (%s)", ["hello"]) + +results in the following two groups of messages: + +.. table:: + :align: left + + +---------------+-----------------------------------------------------------+ + | Direction | Message | + +===============+===========================================================+ + | Python | - Parse ``INSERT INTO ... (VALUE $1)`` (skipped if | + | | :ref:`the statement is prepared <prepared-statements>`) | + | |>| | - Bind ``'hello'`` | + | | - Describe | + | PostgreSQL | - Execute | + | | - Sync | + +---------------+-----------------------------------------------------------+ + | PostgreSQL | - ParseComplete | + | | - BindComplete | + | |<| | - NoData | + | | - CommandComplete ``INSERT 0 1`` | + | Python | - ReadyForQuery | + +---------------+-----------------------------------------------------------+ + +and the query: + +.. code:: python + + conn.execute("SELECT data FROM mytable WHERE id = %s", [1]) + +results in the two groups of messages: + +.. 
table:: + :align: left + + +---------------+-----------------------------------------------------------+ + | Direction | Message | + +===============+===========================================================+ + | Python | - Parse ``SELECT data FROM mytable WHERE id = $1`` | + | | - Bind ``1`` | + | |>| | - Describe | + | | - Execute | + | PostgreSQL | - Sync | + +---------------+-----------------------------------------------------------+ + | PostgreSQL | - ParseComplete | + | | - BindComplete | + | |<| | - RowDescription ``data`` | + | | - DataRow ``hello`` | + | Python | - CommandComplete ``SELECT 1`` | + | | - ReadyForQuery | + +---------------+-----------------------------------------------------------+ + +The two statements, sent consecutively, pay the communication overhead four +times, once per leg. + +The pipeline mode allows the client to combine several operations in longer +streams of messages to the server, then to receive more than one response in a +single batch. If we execute the two operations above in a pipeline: + +.. code:: python + + with conn.pipeline(): + conn.execute("INSERT INTO mytable (data) VALUES (%s)", ["hello"]) + conn.execute("SELECT data FROM mytable WHERE id = %s", [1]) + +they will result in a single roundtrip between the client and the server: + +.. table:: + :align: left + + +---------------+-----------------------------------------------------------+ + | Direction | Message | + +===============+===========================================================+ + | Python | - Parse ``INSERT INTO ... (VALUE $1)`` | + | | - Bind ``'hello'`` | + | |>| | - Describe | + | | - Execute | + | PostgreSQL | - Parse ``SELECT data FROM mytable WHERE id = $1`` | + | | - Bind ``1`` | + | | - Describe | + | | - Execute | + | | - Sync (sent only once) | + +---------------+-----------------------------------------------------------+ + | PostgreSQL | - ParseComplete | + | | - BindComplete | + | |<| | - NoData | + | | - CommandComplete ``INSERT 0 1`` | + | Python | - ParseComplete | + | | - BindComplete | + | | - RowDescription ``data`` | + | | - DataRow ``hello`` | + | | - CommandComplete ``SELECT 1`` | + | | - ReadyForQuery (sent only once) | + +---------------+-----------------------------------------------------------+ + +.. |<| unicode:: U+25C0 +.. |>| unicode:: U+25B6 + +.. __: https://www.postgresql.org/docs/current/protocol-flow.html + + +.. _pipeline-usage: + +Pipeline mode usage +------------------- + +Psycopg supports the pipeline mode via the `Connection.pipeline()` method. The +method is a context manager: entering the ``with`` block yields a `Pipeline` +object. At the end of block, the connection resumes the normal operation mode. + +Within the pipeline block, you can use normally one or more cursors to execute +several operations, using `Connection.execute()`, `Cursor.execute()` and +`~Cursor.executemany()`. + +.. code:: python + + >>> with conn.pipeline(): + ... conn.execute("INSERT INTO mytable VALUES (%s)", ["hello"]) + ... with conn.cursor() as cur: + ... cur.execute("INSERT INTO othertable VALUES (%s)", ["world"]) + ... cur.executemany( + ... "INSERT INTO elsewhere VALUES (%s)", + ... [("one",), ("two",), ("four",)]) + +Unlike in normal mode, Psycopg will not wait for the server to receive the +result of each query; the client will receive results in batches when the +server flushes it output buffer. + +When a flush (or a sync) is performed, all pending results are sent back to +the cursors which executed them. 
If a cursor had run more than one query, it +will receive more than one result; results after the first will be available, +in their execution order, using `~Cursor.nextset()`: + +.. code:: python + + >>> with conn.pipeline(): + ... with conn.cursor() as cur: + ... cur.execute("INSERT INTO mytable (data) VALUES (%s) RETURNING *", ["hello"]) + ... cur.execute("INSERT INTO mytable (data) VALUES (%s) RETURNING *", ["world"]) + ... while True: + ... print(cur.fetchall()) + ... if not cur.nextset(): + ... break + + [(1, 'hello')] + [(2, 'world')] + +If any statement encounters an error, the server aborts the current +transaction and will not execute any subsequent command in the queue until the +next :ref:`synchronization point <pipeline-sync>`; a `~errors.PipelineAborted` +exception is raised for each such command. Query processing resumes after the +synchronization point. + +.. warning:: + + Certain features are not available in pipeline mode, including: + + - COPY is not supported in pipeline mode by PostgreSQL. + - `Cursor.stream()` doesn't make sense in pipeline mode (its job is the + opposite of batching!) + - `ServerCursor` are currently not implemented in pipeline mode. + +.. note:: + + Starting from Psycopg 3.1, `~Cursor.executemany()` makes use internally of + the pipeline mode; as a consequence there is no need to handle a pipeline + block just to call `!executemany()` once. + + +.. _pipeline-sync: + +Synchronization points +---------------------- + +Flushing query results to the client can happen either when a synchronization +point is established by Psycopg: + +- using the `Pipeline.sync()` method; +- on `Connection.commit()` or `~Connection.rollback()`; +- at the end of a `!Pipeline` block; +- possibly when opening a nested `!Pipeline` block; +- using a fetch method such as `Cursor.fetchone()` (which only flushes the + query but doesn't issue a Sync and doesn't reset a pipeline state error). + +The server might perform a flush on its own initiative, for instance when the +output buffer is full. + +Note that, even in :ref:`autocommit <autocommit>`, the server wraps the +statements sent in pipeline mode in an implicit transaction, which will be +only committed when the Sync is received. As such, a failure in a group of +statements will probably invalidate the effect of statements executed after +the previous Sync, and will propagate to the following Sync. + +For example, in the following block: + +.. code:: python + + >>> with psycopg.connect(autocommit=True) as conn: + ... with conn.pipeline() as p, conn.cursor() as cur: + ... try: + ... cur.execute("INSERT INTO mytable (data) VALUES (%s)", ["one"]) + ... cur.execute("INSERT INTO no_such_table (data) VALUES (%s)", ["two"]) + ... conn.execute("INSERT INTO mytable (data) VALUES (%s)", ["three"]) + ... p.sync() + ... except psycopg.errors.UndefinedTable: + ... pass + ... cur.execute("INSERT INTO mytable (data) VALUES (%s)", ["four"]) + +there will be an error in the block, ``relation "no_such_table" does not +exist`` caused by the insert ``two``, but probably raised by the `!sync()` +call. At at the end of the block, the table will contain: + +.. 
code:: text + + =# SELECT * FROM mytable; + +----+------+ + | id | data | + +----+------+ + | 2 | four | + +----+------+ + (1 row) + +because: + +- the value 1 of the sequence is consumed by the statement ``one``, but + the record discarded because of the error in the same implicit transaction; +- the statement ``three`` is not executed because the pipeline is aborted (so + it doesn't consume a sequence item); +- the statement ``four`` is executed with + success after the Sync has terminated the failed transaction. + +.. warning:: + + The exact Python statement where an exception caused by a server error is + raised is somewhat arbitrary: it depends on when the server flushes its + buffered result. + + If you want to make sure that a group of statements is applied atomically + by the server, do make use of transaction methods such as + `~Connection.commit()` or `~Connection.transaction()`: these methods will + also sync the pipeline and raise an exception if there was any error in + the commands executed so far. + + +The fine prints +--------------- + +.. warning:: + + The Pipeline mode is an experimental feature. + + Its behaviour, especially around error conditions and concurrency, hasn't + been explored as much as the normal request-response messages pattern, and + its async nature makes it inherently more complex. + + As we gain more experience and feedback (which is welcome), we might find + bugs and shortcomings forcing us to change the current interface or + behaviour. + +The pipeline mode is available on any currently supported PostgreSQL version, +but, in order to make use of it, the client must use a libpq from PostgreSQL +14 or higher. You can use `Pipeline.is_supported()` to make sure your client +has the right library. diff --git a/docs/advanced/pool.rst b/docs/advanced/pool.rst new file mode 100644 index 0000000..adea0a7 --- /dev/null +++ b/docs/advanced/pool.rst @@ -0,0 +1,332 @@ +.. currentmodule:: psycopg_pool + +.. _connection-pools: + +Connection pools +================ + +A `connection pool`__ is an object managing a set of connections and allowing +their use in functions needing one. Because the time to establish a new +connection can be relatively long, keeping connections open can reduce latency. + +.. __: https://en.wikipedia.org/wiki/Connection_pool + +This page explains a few basic concepts of Psycopg connection pool's +behaviour. Please refer to the `ConnectionPool` object API for details about +the pool operations. + +.. note:: The connection pool objects are distributed in a package separate + from the main `psycopg` package: use ``pip install "psycopg[pool]"`` or ``pip + install psycopg_pool`` to make the `psycopg_pool` package available. See + :ref:`pool-installation`. + + +Pool life cycle +--------------- + +A simple way to use the pool is to create a single instance of it, as a +global object, and to use this object in the rest of the program, allowing +other functions, modules, threads to use it:: + + # module db.py in your program + from psycopg_pool import ConnectionPool + + pool = ConnectionPool(conninfo, **kwargs) + # the pool starts connecting immediately. + + # in another module + from .db import pool + + def my_function(): + with pool.connection() as conn: + conn.execute(...) + +Ideally you may want to call `~ConnectionPool.close()` when the use of the +pool is finished. Failing to call `!close()` at the end of the program is not +terribly bad: probably it will just result in some warnings printed on stderr. 
+However, if you think that it's sloppy, you could use the `atexit` module to +have `!close()` called at the end of the program. + +If you want to avoid starting to connect to the database at import time, and +want to wait for the application to be ready, you can create the pool using +`!open=False`, and call the `~ConnectionPool.open()` and +`~ConnectionPool.close()` methods when the conditions are right. Certain +frameworks provide callbacks triggered when the program is started and stopped +(for instance `FastAPI startup/shutdown events`__): they are perfect to +initiate and terminate the pool operations:: + + pool = ConnectionPool(conninfo, open=False, **kwargs) + + @app.on_event("startup") + def open_pool(): + pool.open() + + @app.on_event("shutdown") + def close_pool(): + pool.close() + +.. __: https://fastapi.tiangolo.com/advanced/events/#events-startup-shutdown + +Creating a single pool as a global variable is not the mandatory use: your +program can create more than one pool, which might be useful to connect to +more than one database, or to provide different types of connections, for +instance to provide separate read/write and read-only connections. The pool +also acts as a context manager and is open and closed, if necessary, on +entering and exiting the context block:: + + from psycopg_pool import ConnectionPool + + with ConnectionPool(conninfo, **kwargs) as pool: + run_app(pool) + + # the pool is now closed + +When the pool is open, the pool's background workers start creating the +requested `!min_size` connections, while the constructor (or the `!open()` +method) returns immediately. This allows the program some leeway to start +before the target database is up and running. However, if your application is +misconfigured, or the network is down, it means that the program will be able +to start, but the threads requesting a connection will fail with a +`PoolTimeout` only after the timeout on `~ConnectionPool.connection()` is +expired. If this behaviour is not desirable (and you prefer your program to +crash hard and fast, if the surrounding conditions are not right, because +something else will respawn it) you should call the `~ConnectionPool.wait()` +method after creating the pool, or call `!open(wait=True)`: these methods will +block until the pool is full, or will raise a `PoolTimeout` exception if the +pool isn't ready within the allocated time. + + +Connections life cycle +---------------------- + +The pool background workers create connections according to the parameters +`!conninfo`, `!kwargs`, and `!connection_class` passed to `ConnectionPool` +constructor, invoking something like :samp:`{connection_class}({conninfo}, +**{kwargs})`. Once a connection is created it is also passed to the +`!configure()` callback, if provided, after which it is put in the pool (or +passed to a client requesting it, if someone is already knocking at the door). + +If a connection expires (it passes `!max_lifetime`), or is returned to the pool +in broken state, or is found closed by `~ConnectionPool.check()`), then the +pool will dispose of it and will start a new connection attempt in the +background. + + +Using connections from the pool +------------------------------- + +The pool can be used to request connections from multiple threads or +concurrent tasks - it is hardly useful otherwise! 
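+A minimal sketch of this usage pattern (assuming the module-level `!pool`
+object created above, and a trivial per-item task):
+
+.. code:: python
+
+    from concurrent.futures import ThreadPoolExecutor
+
+    def double(n):
+        # Each worker borrows a connection for the duration of the block
+        # and returns it to the pool on exit.
+        with pool.connection() as conn:
+            return conn.execute("SELECT 2 * %s::int", [n]).fetchone()[0]
+
+    with ThreadPoolExecutor(max_workers=4) as executor:
+        results = list(executor.map(double, range(100)))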
If more connections than the +ones available in the pool are requested, the requesting threads are queued +and are served a connection as soon as one is available, either because +another client has finished using it or because the pool is allowed to grow +(when `!max_size` > `!min_size`) and a new connection is ready. + +The main way to use the pool is to obtain a connection using the +`~ConnectionPool.connection()` context, which returns a `~psycopg.Connection` +or subclass:: + + with my_pool.connection() as conn: + conn.execute("what you want") + +The `!connection()` context behaves like the `~psycopg.Connection` object +context: at the end of the block, if there is a transaction open, it will be +committed, or rolled back if the context is exited with as exception. + +At the end of the block the connection is returned to the pool and shouldn't +be used anymore by the code which obtained it. If a `!reset()` function is +specified in the pool constructor, it is called on the connection before +returning it to the pool. Note that the `!reset()` function is called in a +worker thread, so that the thread which used the connection can keep its +execution without being slowed down by it. + + +Pool connection and sizing +-------------------------- + +A pool can have a fixed size (specifying no `!max_size` or `!max_size` = +`!min_size`) or a dynamic size (when `!max_size` > `!min_size`). In both +cases, as soon as the pool is created, it will try to acquire `!min_size` +connections in the background. + +If an attempt to create a connection fails, a new attempt will be made soon +after, using an exponential backoff to increase the time between attempts, +until a maximum of `!reconnect_timeout` is reached. When that happens, the pool +will call the `!reconnect_failed()` function, if provided to the pool, and just +start a new connection attempt. You can use this function either to send +alerts or to interrupt the program and allow the rest of your infrastructure +to restart it. + +If more than `!min_size` connections are requested concurrently, new ones are +created, up to `!max_size`. Note that the connections are always created by the +background workers, not by the thread asking for the connection: if a client +requests a new connection, and a previous client terminates its job before the +new connection is ready, the waiting client will be served the existing +connection. This is especially useful in scenarios where the time to establish +a connection dominates the time for which the connection is used (see `this +analysis`__, for instance). + +.. __: https://github.com/brettwooldridge/HikariCP/blob/dev/documents/ + Welcome-To-The-Jungle.md + +If a pool grows above `!min_size`, but its usage decreases afterwards, a number +of connections are eventually closed: one every time a connection is unused +after the `!max_idle` time specified in the pool constructor. + + +What's the right size for the pool? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Big question. Who knows. However, probably not as large as you imagine. Please +take a look at `this analysis`__ for some ideas. + +.. __: https://github.com/brettwooldridge/HikariCP/wiki/About-Pool-Sizing + +Something useful you can do is probably to use the +`~ConnectionPool.get_stats()` method and monitor the behaviour of your program +to tune the configuration parameters. The size of the pool can also be changed +at runtime using the `~ConnectionPool.resize()` method. + + +.. _null-pool: + +Null connection pools +--------------------- + +.. 
versionadded:: 3.1
+
+Sometimes you may want to leave the choice of using or not using a connection
+pool as a configuration parameter of your application. For instance, you might
+want to use a pool if you are deploying a "large instance" of your application
+and can dedicate a handful of connections to it; conversely you might not want
+to use it if you deploy the application in several instances, behind a load
+balancer, and/or using an external connection pool process such as PgBouncer.
+
+Switching between using or not using a pool requires some code change, because
+the `ConnectionPool` API is different from the normal `~psycopg.connect()`
+function and because the pool can perform additional connection configuration
+(in the `!configure` parameter) that, if the pool is removed, should be
+performed in some different code path of your application.
+
+The `!psycopg_pool` 3.1 package introduces the `NullConnectionPool` class.
+This class has the same interface, and largely the same behaviour, as the
+`!ConnectionPool`, but doesn't create any connection beforehand. When a
+connection is returned, unless there are other clients already waiting, it
+is closed immediately and not kept in the pool state.
+
+A null pool is not only a configuration convenience, but can also be used to
+regulate access to the server by a client program. If `!max_size` is set to
+a value greater than 0, the pool will make sure that no more than `!max_size`
+connections are created at any given time. If more clients ask for further
+connections, they will be queued and served a connection as soon as a previous
+client has finished using it, just like with the basic pool. Other mechanisms
+to throttle client requests (such as `!timeout` or `!max_waiting`) are
+respected too.
+
+.. note::
+
+    Queued clients will be handed an already established connection, as soon
+    as a previous client has finished using it (and after the pool has
+    returned it to idle state and called `!reset()` on it, if necessary).
+
+Because normally (i.e. unless queued) every client will be served a new
+connection, the time to obtain the connection is paid by the waiting client;
+background workers are not normally involved in obtaining new connections.
+
+
+Connection quality
+------------------
+
+The state of the connection is verified when a connection is returned to the
+pool: if a connection is broken during its usage it will be discarded on
+return and a new connection will be created.
+
+.. warning::
+
+    The health of the connection is not checked when the pool gives it to a
+    client.
+
+Why not? Because doing so would require an extra network roundtrip: we want to
+save you from its latency. Before getting too angry about it, just think that
+the connection can be lost at any moment while your program is using it. As
+your program should already be able to cope with a loss of connection during
+its normal operation, it should also be able to tolerate being served a broken
+connection: unpleasant but not the end of the world.
+
+.. warning::
+
+    The health of the connection is not checked when the connection is in the
+    pool.
+
+Does the pool keep a watchful eye on the quality of the connections inside it?
+No, it doesn't. Why not? Because you will do it for us! Your program is only
+a big ruse to make sure the connections are still alive...
+
+Not (entirely) trolling: if you are using a connection pool, we assume that
+you are using and returning connections at a good pace. 
If the pool had to +check for the quality of a broken connection before your program notices it, +it should be polling each connection even faster than your program uses them. +Your database server wouldn't be amused... + +Can you do something better than that? Of course you can, there is always a +better way than polling. You can use the same recipe of :ref:`disconnections`, +reserving a connection and using a thread to monitor for any activity +happening on it. If any activity is detected, you can call the pool +`~ConnectionPool.check()` method, which will run a quick check on each +connection in the pool, removing the ones found in broken state, and using the +background workers to replace them with fresh ones. + +If you set up a similar check in your program, in case the database connection +is temporarily lost, we cannot do anything for the threads which had taken +already a connection from the pool, but no other thread should be served a +broken connection, because `!check()` would empty the pool and refill it with +working connections, as soon as they are available. + +Faster than you can say poll. Or pool. + + +.. _pool-stats: + +Pool stats +---------- + +The pool can return information about its usage using the methods +`~ConnectionPool.get_stats()` or `~ConnectionPool.pop_stats()`. Both methods +return the same values, but the latter reset the counters after its use. The +values can be sent to a monitoring system such as Graphite_ or Prometheus_. + +.. _Graphite: https://graphiteapp.org/ +.. _Prometheus: https://prometheus.io/ + +The following values should be provided, but please don't consider them as a +rigid interface: it is possible that they might change in the future. Keys +whose value is 0 may not be returned. + + +======================= ===================================================== +Metric Meaning +======================= ===================================================== + ``pool_min`` Current value for `~ConnectionPool.min_size` + ``pool_max`` Current value for `~ConnectionPool.max_size` + ``pool_size`` Number of connections currently managed by the pool + (in the pool, given to clients, being prepared) + ``pool_available`` Number of connections currently idle in the pool + ``requests_waiting`` Number of requests currently waiting in a queue to + receive a connection + ``usage_ms`` Total usage time of the connections outside the pool + ``requests_num`` Number of connections requested to the pool + ``requests_queued`` Number of requests queued because a connection wasn't + immediately available in the pool + ``requests_wait_ms`` Total time in the queue for the clients waiting + ``requests_errors`` Number of connection requests resulting in an error + (timeouts, queue full...) + ``returns_bad`` Number of connections returned to the pool in a bad + state + ``connections_num`` Number of connection attempts made by the pool to the + server + ``connections_ms`` Total time spent to establish connections with the + server + ``connections_errors`` Number of failed connection attempts + ``connections_lost`` Number of connections lost identified by + `~ConnectionPool.check()` +======================= ===================================================== diff --git a/docs/advanced/prepare.rst b/docs/advanced/prepare.rst new file mode 100644 index 0000000..e41bcae --- /dev/null +++ b/docs/advanced/prepare.rst @@ -0,0 +1,57 @@ +.. currentmodule:: psycopg + +.. index:: + single: Prepared statements + +.. 
_prepared-statements: + +Prepared statements +=================== + +Psycopg uses an automatic system to manage *prepared statements*. When a +query is prepared, its parsing and planning is stored in the server session, +so that further executions of the same query on the same connection (even with +different parameters) are optimised. + +A query is prepared automatically after it is executed more than +`~Connection.prepare_threshold` times on a connection. `!psycopg` will make +sure that no more than `~Connection.prepared_max` statements are planned: if +further queries are executed, the least recently used ones are deallocated and +the associated resources freed. + +Statement preparation can be controlled in several ways: + +- You can decide to prepare a query immediately by passing `!prepare=True` to + `Connection.execute()` or `Cursor.execute()`. The query is prepared, if it + wasn't already, and executed as prepared from its first use. + +- Conversely, passing `!prepare=False` to `!execute()` will avoid to prepare + the query, regardless of the number of times it is executed. The default for + the parameter is `!None`, meaning that the query is prepared if the + conditions described above are met. + +- You can disable the use of prepared statements on a connection by setting + its `~Connection.prepare_threshold` attribute to `!None`. + +.. versionchanged:: 3.1 + You can set `!prepare_threshold` as a `~Connection.connect()` keyword + parameter too. + +.. seealso:: + + The `PREPARE`__ PostgreSQL documentation contains plenty of details about + prepared statements in PostgreSQL. + + Note however that Psycopg doesn't use SQL statements such as + :sql:`PREPARE` and :sql:`EXECUTE`, but protocol level commands such as the + ones exposed by :pq:`PQsendPrepare`, :pq:`PQsendQueryPrepared`. + + .. __: https://www.postgresql.org/docs/current/sql-prepare.html + +.. warning:: + + Using external connection poolers, such as PgBouncer, is not compatible + with prepared statements, because the same client connection may change + the server session it refers to. If such middleware is used you should + disable prepared statements, by setting the `Connection.prepare_threshold` + attribute to `!None`. diff --git a/docs/advanced/rows.rst b/docs/advanced/rows.rst new file mode 100644 index 0000000..c23efe5 --- /dev/null +++ b/docs/advanced/rows.rst @@ -0,0 +1,116 @@ +.. currentmodule:: psycopg + +.. index:: row factories + +.. _row-factories: + +Row factories +============= + +Cursor's `fetch*` methods, by default, return the records received from the +database as tuples. This can be changed to better suit the needs of the +programmer by using custom *row factories*. + +The module `psycopg.rows` exposes several row factories ready to be used. For +instance, if you want to return your records as dictionaries, you can use +`~psycopg.rows.dict_row`:: + + >>> from psycopg.rows import dict_row + + >>> conn = psycopg.connect(DSN, row_factory=dict_row) + + >>> conn.execute("select 'John Doe' as name, 33 as age").fetchone() + {'name': 'John Doe', 'age': 33} + +The `!row_factory` parameter is supported by the `~Connection.connect()` +method and the `~Connection.cursor()` method. Later usage of `!row_factory` +overrides a previous one. 
It is also possible to change the +`Connection.row_factory` or `Cursor.row_factory` attributes to change what +they return:: + + >>> cur = conn.cursor(row_factory=dict_row) + >>> cur.execute("select 'John Doe' as name, 33 as age").fetchone() + {'name': 'John Doe', 'age': 33} + + >>> from psycopg.rows import namedtuple_row + >>> cur.row_factory = namedtuple_row + >>> cur.execute("select 'John Doe' as name, 33 as age").fetchone() + Row(name='John Doe', age=33) + +If you want to return objects of your choice you can use a row factory +*generator*, for instance `~psycopg.rows.class_row` or +`~psycopg.rows.args_row`, or you can :ref:`write your own row factory +<row-factory-create>`:: + + >>> from dataclasses import dataclass + + >>> @dataclass + ... class Person: + ... name: str + ... age: int + ... weight: Optional[int] = None + + >>> from psycopg.rows import class_row + >>> cur = conn.cursor(row_factory=class_row(Person)) + >>> cur.execute("select 'John Doe' as name, 33 as age").fetchone() + Person(name='John Doe', age=33, weight=None) + + +.. index:: + single: Row Maker + single: Row Factory + +.. _row-factory-create: + +Creating new row factories +-------------------------- + +A *row factory* is a callable that accepts a `Cursor` object and returns +another callable, a *row maker*, which takes raw data (as a sequence of +values) and returns the desired object. + +The role of the row factory is to inspect a query result (it is called after a +query is executed and properties such as `~Cursor.description` and +`~Cursor.pgresult` are available on the cursor) and to prepare a callable +which is efficient to call repeatedly (because, for instance, the names of the +columns are extracted, sanitised, and stored in local variables). + +Formally, these objects are represented by the `~psycopg.rows.RowFactory` and +`~psycopg.rows.RowMaker` protocols. + +`~RowFactory` objects can be implemented as a class, for instance: + +.. code:: python + + from typing import Any, Sequence + from psycopg import Cursor + + class DictRowFactory: + def __init__(self, cursor: Cursor[Any]): + self.fields = [c.name for c in cursor.description] + + def __call__(self, values: Sequence[Any]) -> dict[str, Any]: + return dict(zip(self.fields, values)) + +or as a plain function: + +.. code:: python + + def dict_row_factory(cursor: Cursor[Any]) -> RowMaker[dict[str, Any]]: + fields = [c.name for c in cursor.description] + + def make_row(values: Sequence[Any]) -> dict[str, Any]: + return dict(zip(fields, values)) + + return make_row + +These can then be used by specifying a `row_factory` argument in +`Connection.connect()`, `Connection.cursor()`, or by setting the +`Connection.row_factory` attribute. + +.. code:: python + + conn = psycopg.connect(row_factory=DictRowFactory) + cur = conn.execute("SELECT first_name, last_name, age FROM persons") + person = cur.fetchone() + print(f"{person['first_name']} {person['last_name']}") diff --git a/docs/advanced/typing.rst b/docs/advanced/typing.rst new file mode 100644 index 0000000..71b4e41 --- /dev/null +++ b/docs/advanced/typing.rst @@ -0,0 +1,180 @@ +.. currentmodule:: psycopg + +.. _static-typing: + +Static Typing +============= + +Psycopg source code is annotated according to :pep:`0484` type hints and is +checked using the current version of Mypy_ in ``--strict`` mode. + +If your application is checked using Mypy too you can make use of Psycopg +types to validate the correct use of Psycopg objects and of the data returned +by the database. + +.. 
_Mypy: http://mypy-lang.org/ + + +Generic types +------------- + +Psycopg `Connection` and `Cursor` objects are `~typing.Generic` objects and +support a `!Row` parameter which is the type of the records returned. + +By default methods such as `Cursor.fetchall()` return normal tuples of unknown +size and content. As such, the `connect()` function returns an object of type +`!psycopg.Connection[Tuple[Any, ...]]` and `Connection.cursor()` returns an +object of type `!psycopg.Cursor[Tuple[Any, ...]]`. If you are writing generic +plumbing code it might be practical to use annotations such as +`!Connection[Any]` and `!Cursor[Any]`. + +.. code:: python + + conn = psycopg.connect() # type is psycopg.Connection[Tuple[Any, ...]] + + cur = conn.cursor() # type is psycopg.Cursor[Tuple[Any, ...]] + + rec = cur.fetchone() # type is Optional[Tuple[Any, ...]] + + recs = cur.fetchall() # type is List[Tuple[Any, ...]] + + +.. _row-factory-static: + +Type of rows returned +--------------------- + +If you want to use connections and cursors returning your data as different +types, for instance as dictionaries, you can use the `!row_factory` argument +of the `~Connection.connect()` and the `~Connection.cursor()` method, which +will control what type of record is returned by the fetch methods of the +cursors and annotate the returned objects accordingly. See +:ref:`row-factories` for more details. + +.. code:: python + + dconn = psycopg.connect(row_factory=dict_row) + # dconn type is psycopg.Connection[Dict[str, Any]] + + dcur = conn.cursor(row_factory=dict_row) + dcur = dconn.cursor() + # dcur type is psycopg.Cursor[Dict[str, Any]] in both cases + + drec = dcur.fetchone() + # drec type is Optional[Dict[str, Any]] + + +.. _example-pydantic: + +Example: returning records as Pydantic models +--------------------------------------------- + +Using Pydantic_ it is possible to enforce static typing at runtime. Using a +Pydantic model factory the code can be checked statically using Mypy and +querying the database will raise an exception if the rows returned is not +compatible with the model. + +.. _Pydantic: https://pydantic-docs.helpmanual.io/ + +The following example can be checked with ``mypy --strict`` without reporting +any issue. Pydantic will also raise a runtime error in case the +`!Person` is used with a query that returns incompatible data. + +.. code:: python + + from datetime import date + from typing import Optional + + import psycopg + from psycopg.rows import class_row + from pydantic import BaseModel + + class Person(BaseModel): + id: int + first_name: str + last_name: str + dob: Optional[date] + + def fetch_person(id: int) -> Person: + with psycopg.connect() as conn: + with conn.cursor(row_factory=class_row(Person)) as cur: + cur.execute( + """ + SELECT id, first_name, last_name, dob + FROM (VALUES + (1, 'John', 'Doe', '2000-01-01'::date), + (2, 'Jane', 'White', NULL) + ) AS data (id, first_name, last_name, dob) + WHERE id = %(id)s; + """, + {"id": id}, + ) + obj = cur.fetchone() + + # reveal_type(obj) would return 'Optional[Person]' here + + if not obj: + raise KeyError(f"person {id} not found") + + # reveal_type(obj) would return 'Person' here + + return obj + + for id in [1, 2]: + p = fetch_person(id) + if p.dob: + print(f"{p.first_name} was born in {p.dob.year}") + else: + print(f"Who knows when {p.first_name} was born") + + +.. 
_literal-string: + +Checking literal strings in queries +----------------------------------- + +The `~Cursor.execute()` method and similar should only receive a literal +string as input, according to :pep:`675`. This means that the query should +come from a literal string in your code, not from an arbitrary string +expression. + +For instance, passing an argument to the query should be done via the second +argument to `!execute()`, not by string composition: + +.. code:: python + + def get_record(conn: psycopg.Connection[Any], id: int) -> Any: + cur = conn.execute("SELECT * FROM my_table WHERE id = %s" % id) # BAD! + return cur.fetchone() + + # the function should be implemented as: + + def get_record(conn: psycopg.Connection[Any], id: int) -> Any: + cur = conn.execute("select * FROM my_table WHERE id = %s", (id,)) + return cur.fetchone() + +If you are composing a query dynamically you should use the `sql.SQL` object +and similar to escape safely table and field names. The parameter of the +`!SQL()` object should be a literal string: + +.. code:: python + + def count_records(conn: psycopg.Connection[Any], table: str) -> int: + query = "SELECT count(*) FROM %s" % table # BAD! + return conn.execute(query).fetchone()[0] + + # the function should be implemented as: + + def count_records(conn: psycopg.Connection[Any], table: str) -> int: + query = sql.SQL("SELECT count(*) FROM {}").format(sql.Identifier(table)) + return conn.execute(query).fetchone()[0] + +At the time of writing, no Python static analyzer implements this check (`mypy +doesn't implement it`__, Pyre_ does, but `doesn't work with psycopg yet`__). +Once the type checkers support will be complete, the above bad statements +should be reported as errors. + +.. __: https://github.com/python/mypy/issues/12554 +.. __: https://github.com/facebook/pyre-check/issues/636 + +.. _Pyre: https://pyre-check.org/ diff --git a/docs/api/abc.rst b/docs/api/abc.rst new file mode 100644 index 0000000..9514e9b --- /dev/null +++ b/docs/api/abc.rst @@ -0,0 +1,75 @@ +`!abc` -- Psycopg abstract classes +================================== + +The module exposes Psycopg definitions which can be used for static type +checking. + +.. module:: psycopg.abc + +.. autoclass:: Dumper(cls, context=None) + + :param cls: The type that will be managed by this dumper. + :type cls: type + :param context: The context where the transformation is performed. If not + specified the conversion might be inaccurate, for instance it will not + be possible to know the connection encoding or the server date format. + :type context: `AdaptContext` or None + + A partial implementation of this protocol (implementing everything except + `dump()`) is available as `psycopg.adapt.Dumper`. + + .. autoattribute:: format + + .. automethod:: dump + + The format returned by dump shouldn't contain quotes or escaped + values. + + .. automethod:: quote + + .. tip:: + + This method will be used by `~psycopg.sql.Literal` to convert a + value client-side. + + This method only makes sense for text dumpers; the result of calling + it on a binary dumper is undefined. It might scratch your car, or burn + your cake. Don't tell me I didn't warn you. + + .. autoattribute:: oid + + If the OID is not specified, PostgreSQL will try to infer the type + from the context, but this may fail in some contexts and may require a + cast (e.g. specifying :samp:`%s::{type}` for its placeholder). 
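+      A minimal sketch (assuming a hypothetical `!Point` class, dumped as a
+      PostgreSQL :sql:`point` literal by subclassing `psycopg.adapt.Dumper`;
+      because no `!oid` is set, the query casts the placeholder explicitly)::
+
+          from psycopg.adapt import Dumper
+
+          class PointDumper(Dumper):
+              def dump(self, obj):
+                  return f"({obj.x}, {obj.y})".encode()
+
+          conn.adapters.register_dumper(Point, PointDumper)
+          conn.execute("INSERT INTO shapes (pos) VALUES (%s::point)", [Point(1, 2)])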
+ + You can use the `psycopg.adapters`\ ``.``\ + `~psycopg.adapt.AdaptersMap.types` registry to find the OID of builtin + types, and you can use `~psycopg.types.TypeInfo` to extend the + registry to custom types. + + .. automethod:: get_key + .. automethod:: upgrade + + +.. autoclass:: Loader(oid, context=None) + + :param oid: The type that will be managed by this dumper. + :type oid: int + :param context: The context where the transformation is performed. If not + specified the conversion might be inaccurate, for instance it will not + be possible to know the connection encoding or the server date format. + :type context: `AdaptContext` or None + + A partial implementation of this protocol (implementing everything except + `load()`) is available as `psycopg.adapt.Loader`. + + .. autoattribute:: format + + .. automethod:: load + + +.. autoclass:: AdaptContext + :members: + + .. seealso:: :ref:`adaptation` for an explanation about how contexts are + connected. diff --git a/docs/api/adapt.rst b/docs/api/adapt.rst new file mode 100644 index 0000000..e47816c --- /dev/null +++ b/docs/api/adapt.rst @@ -0,0 +1,91 @@ +`adapt` -- Types adaptation +=========================== + +.. module:: psycopg.adapt + +The `!psycopg.adapt` module exposes a set of objects useful for the +configuration of *data adaptation*, which is the conversion of Python objects +to PostgreSQL data types and back. + +These objects are useful if you need to configure data adaptation, i.e. +if you need to change the default way that Psycopg converts between types or +if you want to adapt custom data types and objects. You don't need this object +in the normal use of Psycopg. + +See :ref:`adaptation` for an overview of the Psycopg adaptation system. + +.. _abstract base class: https://docs.python.org/glossary.html#term-abstract-base-class + + +Dumpers and loaders +------------------- + +.. autoclass:: Dumper(cls, context=None) + + This is an `abstract base class`_, partially implementing the + `~psycopg.abc.Dumper` protocol. Subclasses *must* at least implement the + `.dump()` method and optionally override other members. + + .. automethod:: dump + + .. attribute:: format + :type: psycopg.pq.Format + :value: TEXT + + Class attribute. Set it to `~psycopg.pq.Format.BINARY` if the class + `dump()` methods converts the object to binary format. + + .. automethod:: quote + + .. automethod:: get_key + + .. automethod:: upgrade + + +.. autoclass:: Loader(oid, context=None) + + This is an `abstract base class`_, partially implementing the + `~psycopg.abc.Loader` protocol. Subclasses *must* at least implement the + `.load()` method and optionally override other members. + + .. automethod:: load + + .. attribute:: format + :type: psycopg.pq.Format + :value: TEXT + + Class attribute. Set it to `~psycopg.pq.Format.BINARY` if the class + `load()` methods converts the object from binary format. + + +Other objects used in adaptations +--------------------------------- + +.. autoclass:: PyFormat + :members: + + +.. autoclass:: AdaptersMap + + .. seealso:: :ref:`adaptation` for an explanation about how contexts are + connected. + + .. automethod:: register_dumper + .. automethod:: register_loader + + .. attribute:: types + + The object where to look up for types information (such as the mapping + between type names and oids in the specified context). + + :type: `~psycopg.types.TypesRegistry` + + .. automethod:: get_dumper + .. automethod:: get_dumper_by_oid + .. automethod:: get_loader + + +.. 
autoclass:: Transformer(context=None) + + :param context: The context where the transformer should operate. + :type context: `~psycopg.abc.AdaptContext` diff --git a/docs/api/connections.rst b/docs/api/connections.rst new file mode 100644 index 0000000..db25382 --- /dev/null +++ b/docs/api/connections.rst @@ -0,0 +1,489 @@ +.. currentmodule:: psycopg + +Connection classes +================== + +The `Connection` and `AsyncConnection` classes are the main wrappers for a +PostgreSQL database session. You can imagine them similar to a :program:`psql` +session. + +One of the differences compared to :program:`psql` is that a `Connection` +usually handles a transaction automatically: other sessions will not be able +to see the changes until you have committed them, more or less explicitly. +Take a look to :ref:`transactions` for the details. + + +The `!Connection` class +----------------------- + +.. autoclass:: Connection() + + This class implements a `DBAPI-compliant interface`__. It is what you want + to use if you write a "classic", blocking program (eventually using + threads or Eventlet/gevent for concurrency). If your program uses `asyncio` + you might want to use `AsyncConnection` instead. + + .. __: https://www.python.org/dev/peps/pep-0249/#connection-objects + + Connections behave as context managers: on block exit, the current + transaction will be committed (or rolled back, in case of exception) and + the connection will be closed. + + .. automethod:: connect + + :param conninfo: The `connection string`__ (a ``postgresql://`` url or + a list of ``key=value`` pairs) to specify where and how to connect. + :param kwargs: Further parameters specifying the connection string. + They override the ones specified in `!conninfo`. + :param autocommit: If `!True` don't start transactions automatically. + See :ref:`transactions` for details. + :param row_factory: The row factory specifying what type of records + to create fetching data (default: `~psycopg.rows.tuple_row()`). See + :ref:`row-factories` for details. + :param cursor_factory: Initial value for the `cursor_factory` attribute + of the connection (new in Psycopg 3.1). + :param prepare_threshold: Initial value for the `prepare_threshold` + attribute of the connection (new in Psycopg 3.1). + + More specialized use: + + :param context: A context to copy the initial adapters configuration + from. It might be an `~psycopg.adapt.AdaptersMap` with customized + loaders and dumpers, used as a template to create several connections. + See :ref:`adaptation` for further details. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-CONNSTRING + + This method is also aliased as `psycopg.connect()`. + + .. seealso:: + + - the list of `the accepted connection parameters`__ + - the `environment variables`__ affecting connection + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS + .. __: https://www.postgresql.org/docs/current/libpq-envars.html + + .. versionchanged:: 3.1 + added `!prepare_threshold` and `!cursor_factory` parameters. + + .. automethod:: close + + .. note:: + + You can use:: + + with psycopg.connect() as conn: + ... + + to close the connection automatically when the block is exited. + See :ref:`with-connection`. + + .. autoattribute:: closed + .. autoattribute:: broken + + .. method:: cursor(*, binary: bool = False, \ + row_factory: Optional[RowFactory] = None) -> Cursor + .. 
method:: cursor(name: str, *, binary: bool = False, \ + row_factory: Optional[RowFactory] = None, \ + scrollable: Optional[bool] = None, withhold: bool = False) -> ServerCursor + :noindex: + + Return a new cursor to send commands and queries to the connection. + + :param name: If not specified create a client-side cursor, if + specified create a server-side cursor. See + :ref:`cursor-types` for details. + :param binary: If `!True` return binary values from the database. All + the types returned by the query must have a binary + loader. See :ref:`binary-data` for details. + :param row_factory: If specified override the `row_factory` set on the + connection. See :ref:`row-factories` for details. + :param scrollable: Specify the `~ServerCursor.scrollable` property of + the server-side cursor created. + :param withhold: Specify the `~ServerCursor.withhold` property of + the server-side cursor created. + :return: A cursor of the class specified by `cursor_factory` (or + `server_cursor_factory` if `!name` is specified). + + .. note:: + + You can use:: + + with conn.cursor() as cur: + ... + + to close the cursor automatically when the block is exited. + + .. autoattribute:: cursor_factory + + The type, or factory function, returned by `cursor()` and `execute()`. + + Default is `psycopg.Cursor`. + + .. autoattribute:: server_cursor_factory + + The type, or factory function, returned by `cursor()` when a name is + specified. + + Default is `psycopg.ServerCursor`. + + .. autoattribute:: row_factory + + The row factory defining the type of rows returned by + `~Cursor.fetchone()` and the other cursor fetch methods. + + The default is `~psycopg.rows.tuple_row`, which means that the fetch + methods will return simple tuples. + + .. seealso:: See :ref:`row-factories` for details about defining the + objects returned by cursors. + + .. automethod:: execute + + :param query: The query to execute. + :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed` + :param params: The parameters to pass to the query, if any. + :type params: Sequence or Mapping + :param prepare: Force (`!True`) or disallow (`!False`) preparation of + the query. By default (`!None`) prepare automatically. See + :ref:`prepared-statements`. + :param binary: If `!True` the cursor will return binary values from the + database. All the types returned by the query must have a binary + loader. See :ref:`binary-data` for details. + + The method simply creates a `Cursor` instance, `~Cursor.execute()` the + query requested, and returns it. + + See :ref:`query-parameters` for all the details about executing + queries. + + .. automethod:: pipeline + + The method is a context manager: you should call it using:: + + with conn.pipeline() as p: + ... + + At the end of the block, a synchronization point is established and + the connection returns in normal mode. + + You can call the method recursively from within a pipeline block. + Innermost blocks will establish a synchronization point on exit, but + pipeline mode will be kept until the outermost block exits. + + See :ref:`pipeline-mode` for details. + + .. versionadded:: 3.1 + + + .. rubric:: Transaction management methods + + For details see :ref:`transactions`. + + .. automethod:: commit + .. automethod:: rollback + .. automethod:: transaction + + .. note:: + + The method must be called with a syntax such as:: + + with conn.transaction(): + ... + + with conn.transaction() as tx: + ... + + The latter is useful if you need to interact with the + `Transaction` object. 
See :ref:`transaction-context` for details. + + Inside a transaction block it will not be possible to call `commit()` + or `rollback()`. + + .. autoattribute:: autocommit + + The property is writable for sync connections, read-only for async + ones: you should call `!await` `~AsyncConnection.set_autocommit` + :samp:`({value})` instead. + + The following three properties control the characteristics of new + transactions. See :ref:`transaction-characteristics` for details. + + .. autoattribute:: isolation_level + + `!None` means use the default set in the default_transaction_isolation__ + configuration parameter of the server. + + .. __: https://www.postgresql.org/docs/current/runtime-config-client.html + #GUC-DEFAULT-TRANSACTION-ISOLATION + + .. autoattribute:: read_only + + `!None` means use the default set in the default_transaction_read_only__ + configuration parameter of the server. + + .. __: https://www.postgresql.org/docs/current/runtime-config-client.html + #GUC-DEFAULT-TRANSACTION-READ-ONLY + + .. autoattribute:: deferrable + + `!None` means use the default set in the default_transaction_deferrable__ + configuration parameter of the server. + + .. __: https://www.postgresql.org/docs/current/runtime-config-client.html + #GUC-DEFAULT-TRANSACTION-DEFERRABLE + + + .. rubric:: Checking and configuring the connection state + + .. attribute:: pgconn + :type: psycopg.pq.PGconn + + The `~pq.PGconn` libpq connection wrapper underlying the `!Connection`. + + It can be used to send low level commands to PostgreSQL and access + features not currently wrapped by Psycopg. + + .. autoattribute:: info + + .. autoattribute:: prepare_threshold + + See :ref:`prepared-statements` for details. + + + .. autoattribute:: prepared_max + + If more queries need to be prepared, old ones are deallocated__. + + .. __: https://www.postgresql.org/docs/current/sql-deallocate.html + + + .. rubric:: Methods you can use to do something cool + + .. automethod:: cancel + + .. automethod:: notifies + + Notifies are received after using :sql:`LISTEN` in a connection, when + any sessions in the database generates a :sql:`NOTIFY` on one of the + listened channels. + + .. automethod:: add_notify_handler + + See :ref:`async-notify` for details. + + .. automethod:: remove_notify_handler + + .. automethod:: add_notice_handler + + See :ref:`async-messages` for details. + + .. automethod:: remove_notice_handler + + .. automethod:: fileno + + + .. _tpc-methods: + + .. rubric:: Two-Phase Commit support methods + + .. versionadded:: 3.1 + + .. seealso:: :ref:`two-phase-commit` for an introductory explanation of + these methods. + + .. automethod:: xid + + .. automethod:: tpc_begin + + :param xid: The id of the transaction + :type xid: Xid or str + + This method should be called outside of a transaction (i.e. nothing + may have executed since the last `commit()` or `rollback()` and + `~ConnectionInfo.transaction_status` is `~pq.TransactionStatus.IDLE`). + + Furthermore, it is an error to call `!commit()` or `!rollback()` + within the TPC transaction: in this case a `ProgrammingError` + is raised. + + The `!xid` may be either an object returned by the `xid()` method or a + plain string: the latter allows to create a transaction using the + provided string as PostgreSQL transaction id. See also + `tpc_recover()`. + + + .. automethod:: tpc_prepare + + A `ProgrammingError` is raised if this method is used outside of a TPC + transaction. 
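+
+        For instance, a minimal sketch of a complete two-phase commit flow
+        (the transaction identifiers and the table name are illustrative)::
+
+            xid = conn.xid(1, "gtrid-example", "bqual-example")
+            conn.tpc_begin(xid)
+            conn.execute("INSERT INTO mytable VALUES (42)")
+            conn.tpc_prepare()
+            conn.tpc_commit()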
+ + After calling `!tpc_prepare()`, no statements can be executed until + `tpc_commit()` or `tpc_rollback()` will be + called. + + .. seealso:: The |PREPARE TRANSACTION|_ PostgreSQL command. + + .. |PREPARE TRANSACTION| replace:: :sql:`PREPARE TRANSACTION` + .. _PREPARE TRANSACTION: https://www.postgresql.org/docs/current/static/sql-prepare-transaction.html + + + .. automethod:: tpc_commit + + :param xid: The id of the transaction + :type xid: Xid or str + + When called with no arguments, `!tpc_commit()` commits a TPC + transaction previously prepared with `tpc_prepare()`. + + If `!tpc_commit()` is called prior to `!tpc_prepare()`, a single phase + commit is performed. A transaction manager may choose to do this if + only a single resource is participating in the global transaction. + + When called with a transaction ID `!xid`, the database commits the + given transaction. If an invalid transaction ID is provided, a + `ProgrammingError` will be raised. This form should be called outside + of a transaction, and is intended for use in recovery. + + On return, the TPC transaction is ended. + + .. seealso:: The |COMMIT PREPARED|_ PostgreSQL command. + + .. |COMMIT PREPARED| replace:: :sql:`COMMIT PREPARED` + .. _COMMIT PREPARED: https://www.postgresql.org/docs/current/static/sql-commit-prepared.html + + + .. automethod:: tpc_rollback + + :param xid: The id of the transaction + :type xid: Xid or str + + When called with no arguments, `!tpc_rollback()` rolls back a TPC + transaction. It may be called before or after `tpc_prepare()`. + + When called with a transaction ID `!xid`, it rolls back the given + transaction. If an invalid transaction ID is provided, a + `ProgrammingError` is raised. This form should be called outside of a + transaction, and is intended for use in recovery. + + On return, the TPC transaction is ended. + + .. seealso:: The |ROLLBACK PREPARED|_ PostgreSQL command. + + .. |ROLLBACK PREPARED| replace:: :sql:`ROLLBACK PREPARED` + .. _ROLLBACK PREPARED: https://www.postgresql.org/docs/current/static/sql-rollback-prepared.html + + + .. automethod:: tpc_recover + + Returns a list of `Xid` representing pending transactions, suitable + for use with `tpc_commit()` or `tpc_rollback()`. + + If a transaction was not initiated by Psycopg, the returned Xids will + have attributes `~Xid.format_id` and `~Xid.bqual` set to `!None` and + the `~Xid.gtrid` set to the PostgreSQL transaction ID: such Xids are + still usable for recovery. Psycopg uses the same algorithm of the + `PostgreSQL JDBC driver`__ to encode a XA triple in a string, so + transactions initiated by a program using such driver should be + unpacked correctly. + + .. __: https://jdbc.postgresql.org/ + + Xids returned by `!tpc_recover()` also have extra attributes + `~Xid.prepared`, `~Xid.owner`, `~Xid.database` populated with the + values read from the server. + + .. seealso:: the |pg_prepared_xacts|_ system view. + + .. |pg_prepared_xacts| replace:: `pg_prepared_xacts` + .. _pg_prepared_xacts: https://www.postgresql.org/docs/current/static/view-pg-prepared-xacts.html + + +The `!AsyncConnection` class +---------------------------- + +.. autoclass:: AsyncConnection() + + This class implements a DBAPI-inspired interface, with all the blocking + methods implemented as coroutines. Unless specified otherwise, + non-blocking methods are shared with the `Connection` class. + + The following methods have the same behaviour of the matching `!Connection` + methods, but should be called using the `await` keyword. + + .. 
automethod:: connect + + .. versionchanged:: 3.1 + + Automatically resolve domain names asynchronously. In previous + versions, name resolution blocks, unless the `!hostaddr` + parameter is specified, or the `~psycopg._dns.resolve_hostaddr_async()` + function is used. + + .. automethod:: close + + .. note:: You can use ``async with`` to close the connection + automatically when the block is exited, but be careful about + the async quirkness: see :ref:`async-with` for details. + + .. method:: cursor(*, binary: bool = False, \ + row_factory: Optional[RowFactory] = None) -> AsyncCursor + .. method:: cursor(name: str, *, binary: bool = False, \ + row_factory: Optional[RowFactory] = None, \ + scrollable: Optional[bool] = None, withhold: bool = False) -> AsyncServerCursor + :noindex: + + .. note:: + + You can use:: + + async with conn.cursor() as cur: + ... + + to close the cursor automatically when the block is exited. + + .. autoattribute:: cursor_factory + + Default is `psycopg.AsyncCursor`. + + .. autoattribute:: server_cursor_factory + + Default is `psycopg.AsyncServerCursor`. + + .. autoattribute:: row_factory + + .. automethod:: execute + + .. automethod:: pipeline + + .. note:: + + It must be called as:: + + async with conn.pipeline() as p: + ... + + .. automethod:: commit + .. automethod:: rollback + + .. automethod:: transaction + + .. note:: + + It must be called as:: + + async with conn.transaction() as tx: + ... + + .. automethod:: notifies + .. automethod:: set_autocommit + .. automethod:: set_isolation_level + .. automethod:: set_read_only + .. automethod:: set_deferrable + + .. automethod:: tpc_prepare + .. automethod:: tpc_commit + .. automethod:: tpc_rollback + .. automethod:: tpc_recover diff --git a/docs/api/conninfo.rst b/docs/api/conninfo.rst new file mode 100644 index 0000000..9e5b01d --- /dev/null +++ b/docs/api/conninfo.rst @@ -0,0 +1,24 @@ +.. _psycopg.conninfo: + +`conninfo` -- manipulate connection strings +=========================================== + +This module contains a few utility functions to manipulate database +connection strings. + +.. module:: psycopg.conninfo + +.. autofunction:: conninfo_to_dict + + .. code:: python + + >>> conninfo_to_dict("postgres://jeff@example.com/db", user="piro") + {'user': 'piro', 'dbname': 'db', 'host': 'example.com'} + + +.. autofunction:: make_conninfo + + .. code:: python + + >>> make_conninfo("dbname=db user=jeff", user="piro", port=5432) + 'dbname=db user=piro port=5432' diff --git a/docs/api/copy.rst b/docs/api/copy.rst new file mode 100644 index 0000000..81a96e2 --- /dev/null +++ b/docs/api/copy.rst @@ -0,0 +1,117 @@ +.. currentmodule:: psycopg + +COPY-related objects +==================== + +The main objects (`Copy`, `AsyncCopy`) present the main interface to exchange +data during a COPY operations. These objects are normally obtained by the +methods `Cursor.copy()` and `AsyncCursor.copy()`; however, they can be also +created directly, for instance to write to a destination which is not a +database (e.g. using a `~psycopg.copy.FileWriter`). + +See :ref:`copy` for details. + + +Main Copy objects +----------------- + +.. autoclass:: Copy() + + The object is normally returned by `!with` `Cursor.copy()`. + + .. automethod:: write_row + + The data in the tuple will be converted as configured on the cursor; + see :ref:`adaptation` for details. + + .. automethod:: write + .. automethod:: read + + Instead of using `!read()` you can iterate on the `!Copy` object to + read its data row by row, using ``for row in copy: ...``. 
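+
+        For example, a minimal sketch that streams a table out to a local
+        file (the table and file names are illustrative)::
+
+            with open("data.out", "wb") as f:
+                with cur.copy("COPY mytable TO STDOUT") as copy:
+                    for data in copy:
+                        f.write(data)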
+
+    .. automethod:: rows
+
+        Equivalent of iterating on `read_row()` until it returns `!None`.
+
+    .. automethod:: read_row
+    .. automethod:: set_types
+
+
+.. autoclass:: AsyncCopy()
+
+    The object is normally returned by ``async with`` `AsyncCursor.copy()`.
+    Its methods are similar to the ones of the `Copy` object, but they offer
+    an `asyncio` interface (`await`, `async for`, `async with`).
+
+    .. automethod:: write_row
+    .. automethod:: write
+    .. automethod:: read
+
+        Instead of using `!read()` you can iterate on the `!AsyncCopy` object
+        to read its data row by row, using ``async for row in copy: ...``.
+
+    .. automethod:: rows
+
+        Use it as ``async for record in copy.rows(): ...``
+
+    .. automethod:: read_row
+
+
+.. _copy-writers:
+
+Writer objects
+--------------
+
+.. currentmodule:: psycopg.copy
+
+.. versionadded:: 3.1
+
+Copy writers are helper objects to specify where to write COPY-formatted data.
+By default, data is written to the database (using the `LibpqWriter`). It is
+possible to write copy data for offline use by using a `FileWriter`, or to
+customize writing further by implementing your own `Writer` or `AsyncWriter`
+subclass.
+
+Writer instances can be used by passing them to the cursor
+`~psycopg.Cursor.copy()` method or to the `~psycopg.Copy` constructor, as the
+`!writer` argument.
+
+.. autoclass:: Writer
+
+    This is an abstract base class: subclasses are required to implement their
+    `write()` method.
+
+    .. automethod:: write
+    .. automethod:: finish
+
+
+.. autoclass:: LibpqWriter
+
+    This is the writer used by default if none is specified.
+
+
+.. autoclass:: FileWriter
+
+    This writer should be used without executing a :sql:`COPY` operation on
+    the database. For example, if `records` is a list of tuples containing
+    data to save in COPY format to a file (e.g. for later import), it can be
+    used as:
+
+    .. code:: python
+
+        with open("target-file.pgcopy", "wb") as f:
+            with Copy(cur, writer=FileWriter(f)) as copy:
+                for record in records:
+                    copy.write_row(record)
+
+
+.. autoclass:: AsyncWriter
+
+    The methods of this class have the same semantics as the ones of `Writer`,
+    but they offer an async interface.
+
+    .. automethod:: write
+    .. automethod:: finish
+
+.. autoclass:: AsyncLibpqWriter
diff --git a/docs/api/crdb.rst b/docs/api/crdb.rst
new file mode 100644
index 0000000..de8344e
--- /dev/null
+++ b/docs/api/crdb.rst
@@ -0,0 +1,120 @@
+`crdb` -- CockroachDB support
+=============================
+
+.. module:: psycopg.crdb
+
+.. versionadded:: 3.1
+
+CockroachDB_ is a distributed database using the same frontend-backend
+protocol as PostgreSQL. As such, Psycopg can be used to write Python programs
+interacting with CockroachDB.
+
+.. _CockroachDB: https://www.cockroachlabs.com/
+
+Opening a connection to a CRDB database using `psycopg.connect()` provides a
+largely working object. However, using the `psycopg.crdb.connect()` function
+instead, Psycopg will create more specialised objects and provide a types
+mapping tweaked to the CockroachDB data model.
+
+
+.. _crdb-differences:
+
+Main differences from PostgreSQL
+--------------------------------
+
+CockroachDB behaviour is `different from PostgreSQL`__: please refer to the
+database documentation for details. These are some of the main differences
+affecting Psycopg behaviour:
+
+.. __: https://www.cockroachlabs.com/docs/stable/postgresql-compatibility.html
+
+- `~psycopg.Connection.cancel()` doesn't work before CockroachDB 22.1.
On + older versions, you can use `CANCEL QUERY`_ instead (but from a different + connection). + +- :ref:`server-side-cursors` are well supported only from CockroachDB 22.1.3. + +- `~psycopg.ConnectionInfo.backend_pid` is only populated from CockroachDB + 22.1. Note however that you cannot use the PID to terminate the session; use + `SHOW session_id`_ to find the id of a session, which you may terminate with + `CANCEL SESSION`_ in lieu of PostgreSQL's :sql:`pg_terminate_backend()`. + +- Several data types are missing or slightly different from PostgreSQL (see + `adapters` for an overview of the differences). + +- The :ref:`two-phase commit protocol <two-phase-commit>` is not supported. + +- :sql:`LISTEN` and :sql:`NOTIFY` are not supported. However the `CHANGEFEED`_ + command, in conjunction with `~psycopg.Cursor.stream()`, can provide push + notifications. + +.. _CANCEL QUERY: https://www.cockroachlabs.com/docs/stable/cancel-query.html +.. _SHOW session_id: https://www.cockroachlabs.com/docs/stable/show-vars.html +.. _CANCEL SESSION: https://www.cockroachlabs.com/docs/stable/cancel-session.html +.. _CHANGEFEED: https://www.cockroachlabs.com/docs/stable/changefeed-for.html + + +.. _crdb-objects: + +CockroachDB-specific objects +---------------------------- + +.. autofunction:: connect + + This is an alias of the class method `CrdbConnection.connect`. + + If you need an asynchronous connection use the `AsyncCrdbConnection.connect()` + method instead. + + +.. autoclass:: CrdbConnection + + `psycopg.Connection` subclass. + + .. automethod:: is_crdb + + :param conn: the connection to check + :type conn: `~psycopg.Connection`, `~psycopg.AsyncConnection`, `~psycopg.pq.PGconn` + + +.. autoclass:: AsyncCrdbConnection + + `psycopg.AsyncConnection` subclass. + + +.. autoclass:: CrdbConnectionInfo + + The object is returned by the `~psycopg.Connection.info` attribute of + `CrdbConnection` and `AsyncCrdbConnection`. + + The object behaves like `!ConnectionInfo`, with the following differences: + + .. autoattribute:: vendor + + The `CockroachDB` string. + + .. autoattribute:: server_version + + +.. data:: adapters + + The default adapters map establishing how Python and CockroachDB types are + converted into each other. + + The map is used as a template when new connections are created, using + `psycopg.crdb.connect()` (similarly to the way `psycopg.adapters` is used + as template for new PostgreSQL connections). + + This registry contains only the types and adapters supported by + CockroachDB. Several PostgreSQL types and adapters are missing or + different from PostgreSQL, among which: + + - Composite types + - :sql:`range`, :sql:`multirange` types + - The :sql:`hstore` type + - Geometric types + - Nested arrays + - Arrays of :sql:`jsonb` + - The :sql:`cidr` data type + - The :sql:`json` type is an alias for :sql:`jsonb` + - The :sql:`int` type is an alias for :sql:`int8`, not `int4`. diff --git a/docs/api/cursors.rst b/docs/api/cursors.rst new file mode 100644 index 0000000..9c5b478 --- /dev/null +++ b/docs/api/cursors.rst @@ -0,0 +1,517 @@ +.. currentmodule:: psycopg + +Cursor classes +============== + +The `Cursor` and `AsyncCursor` classes are the main objects to send commands +to a PostgreSQL database session. They are normally created by the +connection's `~Connection.cursor()` method. + +Using the `!name` parameter on `!cursor()` will create a `ServerCursor` or +`AsyncServerCursor`, which can be used to retrieve partial results from a +database. 
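+
+For example, a minimal sketch of both kinds of cursor (the table name is
+illustrative):
+
+.. code:: python
+
+    with psycopg.connect() as conn:
+        # A client-side cursor: the result set is transferred to the
+        # client when execute() returns
+        cur = conn.cursor()
+        cur.execute("SELECT * FROM mytable")
+        rows = cur.fetchall()
+
+        # A named, server-side cursor: rows are fetched from the server
+        # on demand
+        scur = conn.cursor("my_cursor")
+        scur.execute("SELECT * FROM mytable")
+        batch = scur.fetchmany(100)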
+ +A `Connection` can create several cursors, but only one at time can perform +operations, so they are not the best way to achieve parallelism (you may want +to operate with several connections instead). All the cursors on the same +connection have a view of the same session, so they can see each other's +uncommitted data. + + +The `!Cursor` class +------------------- + +.. autoclass:: Cursor + + This class implements a `DBAPI-compliant interface`__. It is what the + classic `Connection.cursor()` method returns. `AsyncConnection.cursor()` + will create instead `AsyncCursor` objects, which have the same set of + method but expose an `asyncio` interface and require `!async` and + `!await` keywords to operate. + + .. __: dbapi-cursor_ + .. _dbapi-cursor: https://www.python.org/dev/peps/pep-0249/#cursor-objects + + + Cursors behave as context managers: on block exit they are closed and + further operation will not be possible. Closing a cursor will not + terminate a transaction or a session though. + + .. attribute:: connection + :type: Connection + + The connection this cursor is using. + + .. automethod:: close + + .. note:: + + You can use:: + + with conn.cursor() as cur: + ... + + to close the cursor automatically when the block is exited. See + :ref:`usage`. + + .. autoattribute:: closed + + .. rubric:: Methods to send commands + + .. automethod:: execute + + :param query: The query to execute. + :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed` + :param params: The parameters to pass to the query, if any. + :type params: Sequence or Mapping + :param prepare: Force (`!True`) or disallow (`!False`) preparation of + the query. By default (`!None`) prepare automatically. See + :ref:`prepared-statements`. + :param binary: Specify whether the server should return data in binary + format (`!True`) or in text format (`!False`). By default + (`!None`) return data as requested by the cursor's `~Cursor.format`. + + Return the cursor itself, so that it will be possible to chain a fetch + operation after the call. + + See :ref:`query-parameters` for all the details about executing + queries. + + .. versionchanged:: 3.1 + + The `query` argument must be a `~typing.StringLiteral`. If you + need to compose a query dynamically, please use `sql.SQL` and + related objects. + + See :pep:`675` for details. + + .. automethod:: executemany + + :param query: The query to execute + :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed` + :param params_seq: The parameters to pass to the query + :type params_seq: Sequence of Sequences or Mappings + :param returning: If `!True`, fetch the results of the queries executed + :type returning: `!bool` + + This is more efficient than performing separate queries, but in case of + several :sql:`INSERT` (and with some SQL creativity for massive + :sql:`UPDATE` too) you may consider using `copy()`. + + If the queries return data you want to read (e.g. when executing an + :sql:`INSERT ... RETURNING` or a :sql:`SELECT` with a side-effect), + you can specify `!returning=True`; the results will be available in + the cursor's state and can be read using `fetchone()` and similar + methods. Each input parameter will produce a separate result set: use + `nextset()` to read the results of the queries after the first one. + + See :ref:`query-parameters` for all the details about executing + queries. + + .. versionchanged:: 3.1 + + - Added `!returning` parameter to receive query results. 
+ - Performance optimised by making use of the pipeline mode, when + using libpq 14 or newer. + + .. automethod:: copy + + :param statement: The copy operation to execute + :type statement: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed` + :param params: The parameters to pass to the statement, if any. + :type params: Sequence or Mapping + + .. note:: + + The method must be called with:: + + with cursor.copy() as copy: + ... + + See :ref:`copy` for information about :sql:`COPY`. + + .. versionchanged:: 3.1 + Added parameters support. + + .. automethod:: stream + + This command is similar to execute + iter; however it supports endless + data streams. The feature is not available in PostgreSQL, but some + implementations exist: Materialize `TAIL`__ and CockroachDB + `CHANGEFEED`__ for instance. + + The feature, and the API supporting it, are still experimental. + Beware... 👀 + + .. __: https://materialize.com/docs/sql/tail/#main + .. __: https://www.cockroachlabs.com/docs/stable/changefeed-for.html + + The parameters are the same of `execute()`. + + .. warning:: + + Failing to consume the iterator entirely will result in a + connection left in `~psycopg.ConnectionInfo.transaction_status` + `~pq.TransactionStatus.ACTIVE` state: this connection will refuse + to receive further commands (with a message such as *another + command is already in progress*). + + If there is a chance that the generator is not consumed entirely, + in order to restore the connection to a working state you can call + `~generator.close` on the generator object returned by `!stream()`. The + `contextlib.closing` function might be particularly useful to make + sure that `!close()` is called: + + .. code:: + + with closing(cur.stream("select generate_series(1, 10000)")) as gen: + for rec in gen: + something(rec) # might fail + + Without calling `!close()`, in case of error, the connection will + be `!ACTIVE` and unusable. If `!close()` is called, the connection + might be `!INTRANS` or `!INERROR`, depending on whether the server + managed to send the entire resultset to the client. An autocommit + connection will be `!IDLE` instead. + + + .. attribute:: format + + The format of the data returned by the queries. It can be selected + initially e.g. specifying `Connection.cursor`\ `!(binary=True)` and + changed during the cursor's lifetime. It is also possible to override + the value for single queries, e.g. specifying `execute`\ + `!(binary=True)`. + + :type: `pq.Format` + :default: `~pq.Format.TEXT` + + .. seealso:: :ref:`binary-data` + + + .. rubric:: Methods to retrieve results + + Fetch methods are only available if the last operation produced results, + e.g. a :sql:`SELECT` or a command with :sql:`RETURNING`. They will raise + an exception if used with operations that don't return result, such as an + :sql:`INSERT` with no :sql:`RETURNING` or an :sql:`ALTER TABLE`. + + .. note:: + + Cursors are iterable objects, so just using the:: + + for record in cursor: + ... + + syntax will iterate on the records in the current recordset. + + .. autoattribute:: row_factory + + The property affects the objects returned by the `fetchone()`, + `fetchmany()`, `fetchall()` methods. The default + (`~psycopg.rows.tuple_row`) returns a tuple for each record fetched. + + See :ref:`row-factories` for details. + + .. automethod:: fetchone + .. automethod:: fetchmany + .. automethod:: fetchall + .. automethod:: nextset + .. automethod:: scroll + + .. 
attribute:: pgresult + :type: Optional[psycopg.pq.PGresult] + + The result returned by the last query and currently exposed by the + cursor, if available, else `!None`. + + It can be used to obtain low level info about the last query result + and to access to features not currently wrapped by Psycopg. + + + .. rubric:: Information about the data + + .. autoattribute:: description + + .. autoattribute:: statusmessage + + This is the status tag you typically see in :program:`psql` after + a successful command, such as ``CREATE TABLE`` or ``UPDATE 42``. + + .. autoattribute:: rowcount + .. autoattribute:: rownumber + + .. attribute:: _query + + An helper object used to convert queries and parameters before sending + them to PostgreSQL. + + .. note:: + This attribute is exposed because it might be helpful to debug + problems when the communication between Python and PostgreSQL + doesn't work as expected. For this reason, the attribute is + available when a query fails too. + + .. warning:: + You shouldn't consider it part of the public interface of the + object: it might change without warnings. + + Except this warning, I guess. + + If you would like to build reliable features using this object, + please get in touch so we can try and design an useful interface + for it. + + Among the properties currently exposed by this object: + + - `!query` (`!bytes`): the query effectively sent to PostgreSQL. It + will have Python placeholders (``%s``\-style) replaced with + PostgreSQL ones (``$1``, ``$2``\-style). + + - `!params` (sequence of `!bytes`): the parameters passed to + PostgreSQL, adapted to the database format. + + - `!types` (sequence of `!int`): the OID of the parameters passed to + PostgreSQL. + + - `!formats` (sequence of `pq.Format`): whether the parameter format + is text or binary. + + +The `!ClientCursor` class +------------------------- + +.. seealso:: See :ref:`client-side-binding-cursors` for details. + +.. autoclass:: ClientCursor + + This `Cursor` subclass has exactly the same interface of its parent class, + but, instead of sending query and parameters separately to the server, it + merges them on the client and sends them as a non-parametric query on the + server. This allows, for instance, to execute parametrized data definition + statements and other :ref:`problematic queries <server-side-binding>`. + + .. versionadded:: 3.1 + + .. automethod:: mogrify + + :param query: The query to execute. + :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed` + :param params: The parameters to pass to the query, if any. + :type params: Sequence or Mapping + + +The `!ServerCursor` class +-------------------------- + +.. seealso:: See :ref:`server-side-cursors` for details. + +.. autoclass:: ServerCursor + + This class also implements a `DBAPI-compliant interface`__. It is created + by `Connection.cursor()` specifying the `!name` parameter. Using this + object results in the creation of an equivalent PostgreSQL cursor in the + server. DBAPI-extension methods (such as `~Cursor.copy()` or + `~Cursor.stream()`) are not implemented on this object: use a normal + `Cursor` instead. + + .. __: dbapi-cursor_ + + Most attribute and methods behave exactly like in `Cursor`, here are + documented the differences: + + .. autoattribute:: name + .. autoattribute:: scrollable + + .. seealso:: The PostgreSQL DECLARE_ statement documentation + for the description of :sql:`[NO] SCROLL`. + + .. autoattribute:: withhold + + .. 
seealso:: The PostgreSQL DECLARE_ statement documentation + for the description of :sql:`{WITH|WITHOUT} HOLD`. + + .. _DECLARE: https://www.postgresql.org/docs/current/sql-declare.html + + + .. automethod:: close + + .. warning:: Closing a server-side cursor is more important than + closing a client-side one because it also releases the resources + on the server, which otherwise might remain allocated until the + end of the session (memory, locks). Using the pattern:: + + with conn.cursor(): + ... + + is especially useful so that the cursor is closed at the end of + the block. + + .. automethod:: execute + + :param query: The query to execute. + :type query: `!str`, `!bytes`, `sql.SQL`, or `sql.Composed` + :param params: The parameters to pass to the query, if any. + :type params: Sequence or Mapping + :param binary: Specify whether the server should return data in binary + format (`!True`) or in text format (`!False`). By default + (`!None`) return data as requested by the cursor's `~Cursor.format`. + + Create a server cursor with given `!name` and the `!query` in argument. + + If using :sql:`DECLARE` is not appropriate (for instance because the + cursor is returned by calling a stored procedure) you can avoid to use + `!execute()`, crete the cursor in other ways, and use directly the + `!fetch*()` methods instead. See :ref:`cursor-steal` for an example. + + Using `!execute()` more than once will close the previous cursor and + open a new one with the same name. + + .. automethod:: executemany + .. automethod:: fetchone + .. automethod:: fetchmany + .. automethod:: fetchall + + These methods use the FETCH_ SQL statement to retrieve some of the + records from the cursor's current position. + + .. _FETCH: https://www.postgresql.org/docs/current/sql-fetch.html + + .. note:: + + You can also iterate on the cursor to read its result one at + time with:: + + for record in cur: + ... + + In this case, the records are not fetched one at time from the + server but they are retrieved in batches of `itersize` to reduce + the number of server roundtrips. + + .. autoattribute:: itersize + + Number of records to fetch at time when iterating on the cursor. The + default is 100. + + .. automethod:: scroll + + This method uses the MOVE_ SQL statement to move the current position + in the server-side cursor, which will affect following `!fetch*()` + operations. If you need to scroll backwards you should probably + call `~Connection.cursor()` using `scrollable=True`. + + Note that PostgreSQL doesn't provide a reliable way to report when a + cursor moves out of bound, so the method might not raise `!IndexError` + when it happens, but it might rather stop at the cursor boundary. + + .. _MOVE: https://www.postgresql.org/docs/current/sql-fetch.html + + +The `!AsyncCursor` class +------------------------ + +.. autoclass:: AsyncCursor + + This class implements a DBAPI-inspired interface, with all the blocking + methods implemented as coroutines. Unless specified otherwise, + non-blocking methods are shared with the `Cursor` class. + + The following methods have the same behaviour of the matching `!Cursor` + methods, but should be called using the `await` keyword. + + .. attribute:: connection + :type: AsyncConnection + + .. automethod:: close + + .. note:: + + You can use:: + + async with conn.cursor(): + ... + + to close the cursor automatically when the block is exited. + + .. automethod:: execute + .. automethod:: executemany + .. automethod:: copy + + .. 
note:: + + The method must be called with:: + + async with cursor.copy() as copy: + ... + + .. automethod:: stream + + .. note:: + + The method must be called with:: + + async for record in cursor.stream(query): + ... + + .. automethod:: fetchone + .. automethod:: fetchmany + .. automethod:: fetchall + .. automethod:: scroll + + .. note:: + + You can also use:: + + async for record in cursor: + ... + + to iterate on the async cursor results. + + +The `!AsyncClientCursor` class +------------------------------ + +.. autoclass:: AsyncClientCursor + + This class is the `!async` equivalent of the `ClientCursor`. The + difference are the same shown in `AsyncCursor`. + + .. versionadded:: 3.1 + + + +The `!AsyncServerCursor` class +------------------------------ + +.. autoclass:: AsyncServerCursor + + This class implements a DBAPI-inspired interface as the `AsyncCursor` + does, but wraps a server-side cursor like the `ServerCursor` class. It is + created by `AsyncConnection.cursor()` specifying the `!name` parameter. + + The following are the methods exposing a different (async) interface from + the `ServerCursor` counterpart, but sharing the same semantics. + + .. automethod:: close + + .. note:: + You can close the cursor automatically using:: + + async with conn.cursor("name") as cursor: + ... + + .. automethod:: execute + .. automethod:: executemany + .. automethod:: fetchone + .. automethod:: fetchmany + .. automethod:: fetchall + + .. note:: + + You can also iterate on the cursor using:: + + async for record in cur: + ... + + .. automethod:: scroll diff --git a/docs/api/dns.rst b/docs/api/dns.rst new file mode 100644 index 0000000..186bde3 --- /dev/null +++ b/docs/api/dns.rst @@ -0,0 +1,145 @@ +`_dns` -- DNS resolution utilities +================================== + +.. module:: psycopg._dns + +This module contains a few experimental utilities to interact with the DNS +server before performing a connection. + +.. warning:: + This module is experimental and its interface could change in the future, + without warning or respect for the version scheme. It is provided here to + allow experimentation before making it more stable. + +.. warning:: + This module depends on the `dnspython`_ package. The package is currently + not installed automatically as a Psycopg dependency and must be installed + manually: + + .. code:: sh + + $ pip install "dnspython >= 2.1" + + .. _dnspython: https://dnspython.readthedocs.io/ + + +.. function:: resolve_srv(params) + + Apply SRV DNS lookup as defined in :RFC:`2782`. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. + :type params: `!dict` + :return: An updated list of connection parameters. + + For every host defined in the ``params["host"]`` list (comma-separated), + perform SRV lookup if the host is in the form ``_Service._Proto.Target``. + If lookup is successful, return a params dict with hosts and ports replaced + with the looked-up entries. + + Raise `~psycopg.OperationalError` if no lookup is successful and no host + (looked up or unchanged) could be returned. + + In addition to the rules defined by RFC 2782 about the host name pattern, + perform SRV lookup also if the the port is the string ``SRV`` (case + insensitive). + + .. warning:: + This is an experimental functionality. + + .. 
note:: + One possible way to use this function automatically is to subclass + `~psycopg.Connection`, extending the + `~psycopg.Connection._get_connection_params()` method:: + + import psycopg._dns # not imported automatically + + class SrvCognizantConnection(psycopg.Connection): + @classmethod + def _get_connection_params(cls, conninfo, **kwargs): + params = super()._get_connection_params(conninfo, **kwargs) + params = psycopg._dns.resolve_srv(params) + return params + + # The name will be resolved to db1.example.com + cnn = SrvCognizantConnection.connect("host=_postgres._tcp.db.psycopg.org") + + +.. function:: resolve_srv_async(params) + :async: + + Async equivalent of `resolve_srv()`. + + +.. automethod:: psycopg.Connection._get_connection_params + + .. warning:: + This is an experimental method. + + This method is a subclass hook allowing to manipulate the connection + parameters before performing the connection. Make sure to call the + `!super()` implementation before further manipulation of the arguments:: + + @classmethod + def _get_connection_params(cls, conninfo, **kwargs): + params = super()._get_connection_params(conninfo, **kwargs) + # do something with the params + return params + + +.. automethod:: psycopg.AsyncConnection._get_connection_params + + .. warning:: + This is an experimental method. + + +.. function:: resolve_hostaddr_async(params) + :async: + + Perform async DNS lookup of the hosts and return a new params dict. + + .. deprecated:: 3.1 + The use of this function is not necessary anymore, because + `psycopg.AsyncConnection.connect()` performs non-blocking name + resolution automatically. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. + :type params: `!dict` + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses dynamically. + + The function may change the input ``host``, ``hostname``, ``port`` to allow + connecting without further DNS lookups, eventually removing hosts that are + not resolved, keeping the lists of hosts and ports consistent. + + Raise `~psycopg.OperationalError` if connection is not possible (e.g. no + host resolve, inconsistent lists length). + + See `the PostgreSQL docs`__ for explanation of how these params are used, + and how they support multiple entries. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-PARAMKEYWORDS + + .. warning:: + Before psycopg 3.1, this function doesn't handle the ``/etc/hosts`` file. + + .. note:: + Starting from psycopg 3.1, a similar operation is performed + automatically by `!AsyncConnection._get_connection_params()`, so this + function is unneeded. + + In psycopg 3.0, one possible way to use this function automatically is + to subclass `~psycopg.AsyncConnection`, extending the + `~psycopg.AsyncConnection._get_connection_params()` method:: + + import psycopg._dns # not imported automatically + + class AsyncDnsConnection(psycopg.AsyncConnection): + @classmethod + async def _get_connection_params(cls, conninfo, **kwargs): + params = await super()._get_connection_params(conninfo, **kwargs) + params = await psycopg._dns.resolve_hostaddr_async(params) + return params diff --git a/docs/api/errors.rst b/docs/api/errors.rst new file mode 100644 index 0000000..2fca7c6 --- /dev/null +++ b/docs/api/errors.rst @@ -0,0 +1,540 @@ +`errors` -- Package exceptions +============================== + +.. module:: psycopg.errors + +.. 
index:: + single: Error; Class + +This module exposes objects to represent and examine database errors. + + +.. currentmodule:: psycopg + +.. index:: + single: Exceptions; DB-API + +.. _dbapi-exceptions: + +DB-API exceptions +----------------- + +In compliance with the DB-API, all the exceptions raised by Psycopg +derive from the following classes: + +.. parsed-literal:: + + `!Exception` + \|__ `Warning` + \|__ `Error` + \|__ `InterfaceError` + \|__ `DatabaseError` + \|__ `DataError` + \|__ `OperationalError` + \|__ `IntegrityError` + \|__ `InternalError` + \|__ `ProgrammingError` + \|__ `NotSupportedError` + +These classes are exposed both by this module and the root `psycopg` module. + +.. autoexception:: Error() + + .. autoattribute:: diag + .. autoattribute:: sqlstate + + The code of the error, if received from the server. + + This attribute is also available as class attribute on the + :ref:`sqlstate-exceptions` classes. + + .. autoattribute:: pgconn + + Most likely it will be in `~psycopg.pq.ConnStatus.BAD` state; + however it might be useful to verify precisely what went wrong, for + instance checking the `~psycopg.pq.PGconn.needs_password` and + `~psycopg.pq.PGconn.used_password` attributes. + + .. versionadded:: 3.1 + + .. autoattribute:: pgresult + + .. versionadded:: 3.1 + + +.. autoexception:: Warning() +.. autoexception:: InterfaceError() +.. autoexception:: DatabaseError() +.. autoexception:: DataError() +.. autoexception:: OperationalError() +.. autoexception:: IntegrityError() +.. autoexception:: InternalError() +.. autoexception:: ProgrammingError() +.. autoexception:: NotSupportedError() + + +Other Psycopg errors +^^^^^^^^^^^^^^^^^^^^ + +.. currentmodule:: psycopg.errors + + +In addition to the standard DB-API errors, Psycopg defines a few more specific +ones. + +.. autoexception:: ConnectionTimeout() +.. autoexception:: PipelineAborted() + + + +.. index:: + single: Exceptions; PostgreSQL + +Error diagnostics +----------------- + +.. autoclass:: Diagnostic() + + The object is available as the `~psycopg.Error`.\ `~psycopg.Error.diag` + attribute and is passed to the callback functions registered with + `~psycopg.Connection.add_notice_handler()`. + + All the information available from the :pq:`PQresultErrorField()` function + are exposed as attributes by the object. For instance the `!severity` + attribute returns the `!PG_DIAG_SEVERITY` code. Please refer to the + PostgreSQL documentation for the meaning of all the attributes. + + The attributes available are: + + .. attribute:: + column_name + constraint_name + context + datatype_name + internal_position + internal_query + message_detail + message_hint + message_primary + schema_name + severity + severity_nonlocalized + source_file + source_function + source_line + sqlstate + statement_position + table_name + + A string with the error field if available; `!None` if not available. + The attribute value is available only for errors sent by the server: + not all the fields are available for all the errors and for all the + server versions. + + +.. _sqlstate-exceptions: + +SQLSTATE exceptions +------------------- + +Errors coming from a database server (as opposite as ones generated +client-side, such as connection failed) usually have a 5-letters error code +called SQLSTATE (available in the `~Diagnostic.sqlstate` attribute of the +error's `~psycopg.Error.diag` attribute). 
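+
+For example, a minimal sketch of examining the diagnostics of a server error
+(the query is only illustrative):
+
+.. code-block:: python
+
+    try:
+        cur.execute("SELECT 1 / 0")
+    except psycopg.Error as e:
+        print(e.diag.sqlstate)          # '22012'
+        print(e.diag.message_primary)   # 'division by zero'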
+
+Psycopg exposes a different class for each SQLSTATE value, allowing you to
+write idiomatic error handling code according to specific conditions happening
+in the database:
+
+.. code-block:: python
+
+    try:
+        cur.execute("LOCK TABLE mytable IN ACCESS EXCLUSIVE MODE NOWAIT")
+    except psycopg.errors.LockNotAvailable:
+        locked = True
+
+The exception names are generated from the PostgreSQL source code and include
+classes for every error defined by PostgreSQL in versions between 9.6 and 15.
+Every class in the module is named after what is referred to as the "condition
+name" `in the documentation`__, converted to CamelCase: e.g. the error 22012,
+``division_by_zero``, is exposed by this module as the class `!DivisionByZero`.
+There is a handful of... exceptions to this rule, required to disambiguate
+name clashes: please refer to the :ref:`table below <exceptions-list>` for all
+the classes defined.
+
+.. __: https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
+
+Every exception class is a subclass of one of the :ref:`standard DB-API
+exceptions <dbapi-exceptions>`, thus exposing the `~psycopg.Error` interface.
+
+.. versionchanged:: 3.1.4
+    Added exceptions introduced in PostgreSQL 15.
+
+.. autofunction:: lookup
+
+    Example: if you have code using constant names or SQL codes, you can use
+    them to look up the exception class.
+
+    .. code-block:: python
+
+        try:
+            cur.execute("LOCK TABLE mytable IN ACCESS EXCLUSIVE MODE NOWAIT")
+        except psycopg.errors.lookup("UNDEFINED_TABLE"):
+            missing = True
+        except psycopg.errors.lookup("55P03"):
+            locked = True
+
+
+.. _exceptions-list:
+
+List of known exceptions
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+The following are all the SQLSTATE-related error classes defined by this
+module, together with the base DBAPI exception they derive from.
+
+.. 
autogenerated: start + +========= ================================================== ==================== +SQLSTATE Exception Base exception +========= ================================================== ==================== +**Class 02** - No Data (this is also a warning class per the SQL standard) +--------------------------------------------------------------------------------- +``02000`` `!NoData` `!DatabaseError` +``02001`` `!NoAdditionalDynamicResultSetsReturned` `!DatabaseError` +**Class 03** - SQL Statement Not Yet Complete +--------------------------------------------------------------------------------- +``03000`` `!SqlStatementNotYetComplete` `!DatabaseError` +**Class 08** - Connection Exception +--------------------------------------------------------------------------------- +``08000`` `!ConnectionException` `!OperationalError` +``08001`` `!SqlclientUnableToEstablishSqlconnection` `!OperationalError` +``08003`` `!ConnectionDoesNotExist` `!OperationalError` +``08004`` `!SqlserverRejectedEstablishmentOfSqlconnection` `!OperationalError` +``08006`` `!ConnectionFailure` `!OperationalError` +``08007`` `!TransactionResolutionUnknown` `!OperationalError` +``08P01`` `!ProtocolViolation` `!OperationalError` +**Class 09** - Triggered Action Exception +--------------------------------------------------------------------------------- +``09000`` `!TriggeredActionException` `!DatabaseError` +**Class 0A** - Feature Not Supported +--------------------------------------------------------------------------------- +``0A000`` `!FeatureNotSupported` `!NotSupportedError` +**Class 0B** - Invalid Transaction Initiation +--------------------------------------------------------------------------------- +``0B000`` `!InvalidTransactionInitiation` `!DatabaseError` +**Class 0F** - Locator Exception +--------------------------------------------------------------------------------- +``0F000`` `!LocatorException` `!DatabaseError` +``0F001`` `!InvalidLocatorSpecification` `!DatabaseError` +**Class 0L** - Invalid Grantor +--------------------------------------------------------------------------------- +``0L000`` `!InvalidGrantor` `!DatabaseError` +``0LP01`` `!InvalidGrantOperation` `!DatabaseError` +**Class 0P** - Invalid Role Specification +--------------------------------------------------------------------------------- +``0P000`` `!InvalidRoleSpecification` `!DatabaseError` +**Class 0Z** - Diagnostics Exception +--------------------------------------------------------------------------------- +``0Z000`` `!DiagnosticsException` `!DatabaseError` +``0Z002`` `!StackedDiagnosticsAccessedWithoutActiveHandler` `!DatabaseError` +**Class 20** - Case Not Found +--------------------------------------------------------------------------------- +``20000`` `!CaseNotFound` `!ProgrammingError` +**Class 21** - Cardinality Violation +--------------------------------------------------------------------------------- +``21000`` `!CardinalityViolation` `!ProgrammingError` +**Class 22** - Data Exception +--------------------------------------------------------------------------------- +``22000`` `!DataException` `!DataError` +``22001`` `!StringDataRightTruncation` `!DataError` +``22002`` `!NullValueNoIndicatorParameter` `!DataError` +``22003`` `!NumericValueOutOfRange` `!DataError` +``22004`` `!NullValueNotAllowed` `!DataError` +``22005`` `!ErrorInAssignment` `!DataError` +``22007`` `!InvalidDatetimeFormat` `!DataError` +``22008`` `!DatetimeFieldOverflow` `!DataError` +``22009`` `!InvalidTimeZoneDisplacementValue` `!DataError` 
+``2200B`` `!EscapeCharacterConflict` `!DataError` +``2200C`` `!InvalidUseOfEscapeCharacter` `!DataError` +``2200D`` `!InvalidEscapeOctet` `!DataError` +``2200F`` `!ZeroLengthCharacterString` `!DataError` +``2200G`` `!MostSpecificTypeMismatch` `!DataError` +``2200H`` `!SequenceGeneratorLimitExceeded` `!DataError` +``2200L`` `!NotAnXmlDocument` `!DataError` +``2200M`` `!InvalidXmlDocument` `!DataError` +``2200N`` `!InvalidXmlContent` `!DataError` +``2200S`` `!InvalidXmlComment` `!DataError` +``2200T`` `!InvalidXmlProcessingInstruction` `!DataError` +``22010`` `!InvalidIndicatorParameterValue` `!DataError` +``22011`` `!SubstringError` `!DataError` +``22012`` `!DivisionByZero` `!DataError` +``22013`` `!InvalidPrecedingOrFollowingSize` `!DataError` +``22014`` `!InvalidArgumentForNtileFunction` `!DataError` +``22015`` `!IntervalFieldOverflow` `!DataError` +``22016`` `!InvalidArgumentForNthValueFunction` `!DataError` +``22018`` `!InvalidCharacterValueForCast` `!DataError` +``22019`` `!InvalidEscapeCharacter` `!DataError` +``2201B`` `!InvalidRegularExpression` `!DataError` +``2201E`` `!InvalidArgumentForLogarithm` `!DataError` +``2201F`` `!InvalidArgumentForPowerFunction` `!DataError` +``2201G`` `!InvalidArgumentForWidthBucketFunction` `!DataError` +``2201W`` `!InvalidRowCountInLimitClause` `!DataError` +``2201X`` `!InvalidRowCountInResultOffsetClause` `!DataError` +``22021`` `!CharacterNotInRepertoire` `!DataError` +``22022`` `!IndicatorOverflow` `!DataError` +``22023`` `!InvalidParameterValue` `!DataError` +``22024`` `!UnterminatedCString` `!DataError` +``22025`` `!InvalidEscapeSequence` `!DataError` +``22026`` `!StringDataLengthMismatch` `!DataError` +``22027`` `!TrimError` `!DataError` +``2202E`` `!ArraySubscriptError` `!DataError` +``2202G`` `!InvalidTablesampleRepeat` `!DataError` +``2202H`` `!InvalidTablesampleArgument` `!DataError` +``22030`` `!DuplicateJsonObjectKeyValue` `!DataError` +``22031`` `!InvalidArgumentForSqlJsonDatetimeFunction` `!DataError` +``22032`` `!InvalidJsonText` `!DataError` +``22033`` `!InvalidSqlJsonSubscript` `!DataError` +``22034`` `!MoreThanOneSqlJsonItem` `!DataError` +``22035`` `!NoSqlJsonItem` `!DataError` +``22036`` `!NonNumericSqlJsonItem` `!DataError` +``22037`` `!NonUniqueKeysInAJsonObject` `!DataError` +``22038`` `!SingletonSqlJsonItemRequired` `!DataError` +``22039`` `!SqlJsonArrayNotFound` `!DataError` +``2203A`` `!SqlJsonMemberNotFound` `!DataError` +``2203B`` `!SqlJsonNumberNotFound` `!DataError` +``2203C`` `!SqlJsonObjectNotFound` `!DataError` +``2203D`` `!TooManyJsonArrayElements` `!DataError` +``2203E`` `!TooManyJsonObjectMembers` `!DataError` +``2203F`` `!SqlJsonScalarRequired` `!DataError` +``2203G`` `!SqlJsonItemCannotBeCastToTargetType` `!DataError` +``22P01`` `!FloatingPointException` `!DataError` +``22P02`` `!InvalidTextRepresentation` `!DataError` +``22P03`` `!InvalidBinaryRepresentation` `!DataError` +``22P04`` `!BadCopyFileFormat` `!DataError` +``22P05`` `!UntranslatableCharacter` `!DataError` +``22P06`` `!NonstandardUseOfEscapeCharacter` `!DataError` +**Class 23** - Integrity Constraint Violation +--------------------------------------------------------------------------------- +``23000`` `!IntegrityConstraintViolation` `!IntegrityError` +``23001`` `!RestrictViolation` `!IntegrityError` +``23502`` `!NotNullViolation` `!IntegrityError` +``23503`` `!ForeignKeyViolation` `!IntegrityError` +``23505`` `!UniqueViolation` `!IntegrityError` +``23514`` `!CheckViolation` `!IntegrityError` +``23P01`` `!ExclusionViolation` `!IntegrityError` +**Class 
24** - Invalid Cursor State +--------------------------------------------------------------------------------- +``24000`` `!InvalidCursorState` `!InternalError` +**Class 25** - Invalid Transaction State +--------------------------------------------------------------------------------- +``25000`` `!InvalidTransactionState` `!InternalError` +``25001`` `!ActiveSqlTransaction` `!InternalError` +``25002`` `!BranchTransactionAlreadyActive` `!InternalError` +``25003`` `!InappropriateAccessModeForBranchTransaction` `!InternalError` +``25004`` `!InappropriateIsolationLevelForBranchTransaction` `!InternalError` +``25005`` `!NoActiveSqlTransactionForBranchTransaction` `!InternalError` +``25006`` `!ReadOnlySqlTransaction` `!InternalError` +``25007`` `!SchemaAndDataStatementMixingNotSupported` `!InternalError` +``25008`` `!HeldCursorRequiresSameIsolationLevel` `!InternalError` +``25P01`` `!NoActiveSqlTransaction` `!InternalError` +``25P02`` `!InFailedSqlTransaction` `!InternalError` +``25P03`` `!IdleInTransactionSessionTimeout` `!InternalError` +**Class 26** - Invalid SQL Statement Name +--------------------------------------------------------------------------------- +``26000`` `!InvalidSqlStatementName` `!ProgrammingError` +**Class 27** - Triggered Data Change Violation +--------------------------------------------------------------------------------- +``27000`` `!TriggeredDataChangeViolation` `!OperationalError` +**Class 28** - Invalid Authorization Specification +--------------------------------------------------------------------------------- +``28000`` `!InvalidAuthorizationSpecification` `!OperationalError` +``28P01`` `!InvalidPassword` `!OperationalError` +**Class 2B** - Dependent Privilege Descriptors Still Exist +--------------------------------------------------------------------------------- +``2B000`` `!DependentPrivilegeDescriptorsStillExist` `!InternalError` +``2BP01`` `!DependentObjectsStillExist` `!InternalError` +**Class 2D** - Invalid Transaction Termination +--------------------------------------------------------------------------------- +``2D000`` `!InvalidTransactionTermination` `!InternalError` +**Class 2F** - SQL Routine Exception +--------------------------------------------------------------------------------- +``2F000`` `!SqlRoutineException` `!OperationalError` +``2F002`` `!ModifyingSqlDataNotPermitted` `!OperationalError` +``2F003`` `!ProhibitedSqlStatementAttempted` `!OperationalError` +``2F004`` `!ReadingSqlDataNotPermitted` `!OperationalError` +``2F005`` `!FunctionExecutedNoReturnStatement` `!OperationalError` +**Class 34** - Invalid Cursor Name +--------------------------------------------------------------------------------- +``34000`` `!InvalidCursorName` `!ProgrammingError` +**Class 38** - External Routine Exception +--------------------------------------------------------------------------------- +``38000`` `!ExternalRoutineException` `!OperationalError` +``38001`` `!ContainingSqlNotPermitted` `!OperationalError` +``38002`` `!ModifyingSqlDataNotPermittedExt` `!OperationalError` +``38003`` `!ProhibitedSqlStatementAttemptedExt` `!OperationalError` +``38004`` `!ReadingSqlDataNotPermittedExt` `!OperationalError` +**Class 39** - External Routine Invocation Exception +--------------------------------------------------------------------------------- +``39000`` `!ExternalRoutineInvocationException` `!OperationalError` +``39001`` `!InvalidSqlstateReturned` `!OperationalError` +``39004`` `!NullValueNotAllowedExt` `!OperationalError` +``39P01`` `!TriggerProtocolViolated` 
`!OperationalError` +``39P02`` `!SrfProtocolViolated` `!OperationalError` +``39P03`` `!EventTriggerProtocolViolated` `!OperationalError` +**Class 3B** - Savepoint Exception +--------------------------------------------------------------------------------- +``3B000`` `!SavepointException` `!OperationalError` +``3B001`` `!InvalidSavepointSpecification` `!OperationalError` +**Class 3D** - Invalid Catalog Name +--------------------------------------------------------------------------------- +``3D000`` `!InvalidCatalogName` `!ProgrammingError` +**Class 3F** - Invalid Schema Name +--------------------------------------------------------------------------------- +``3F000`` `!InvalidSchemaName` `!ProgrammingError` +**Class 40** - Transaction Rollback +--------------------------------------------------------------------------------- +``40000`` `!TransactionRollback` `!OperationalError` +``40001`` `!SerializationFailure` `!OperationalError` +``40002`` `!TransactionIntegrityConstraintViolation` `!OperationalError` +``40003`` `!StatementCompletionUnknown` `!OperationalError` +``40P01`` `!DeadlockDetected` `!OperationalError` +**Class 42** - Syntax Error or Access Rule Violation +--------------------------------------------------------------------------------- +``42000`` `!SyntaxErrorOrAccessRuleViolation` `!ProgrammingError` +``42501`` `!InsufficientPrivilege` `!ProgrammingError` +``42601`` `!SyntaxError` `!ProgrammingError` +``42602`` `!InvalidName` `!ProgrammingError` +``42611`` `!InvalidColumnDefinition` `!ProgrammingError` +``42622`` `!NameTooLong` `!ProgrammingError` +``42701`` `!DuplicateColumn` `!ProgrammingError` +``42702`` `!AmbiguousColumn` `!ProgrammingError` +``42703`` `!UndefinedColumn` `!ProgrammingError` +``42704`` `!UndefinedObject` `!ProgrammingError` +``42710`` `!DuplicateObject` `!ProgrammingError` +``42712`` `!DuplicateAlias` `!ProgrammingError` +``42723`` `!DuplicateFunction` `!ProgrammingError` +``42725`` `!AmbiguousFunction` `!ProgrammingError` +``42803`` `!GroupingError` `!ProgrammingError` +``42804`` `!DatatypeMismatch` `!ProgrammingError` +``42809`` `!WrongObjectType` `!ProgrammingError` +``42830`` `!InvalidForeignKey` `!ProgrammingError` +``42846`` `!CannotCoerce` `!ProgrammingError` +``42883`` `!UndefinedFunction` `!ProgrammingError` +``428C9`` `!GeneratedAlways` `!ProgrammingError` +``42939`` `!ReservedName` `!ProgrammingError` +``42P01`` `!UndefinedTable` `!ProgrammingError` +``42P02`` `!UndefinedParameter` `!ProgrammingError` +``42P03`` `!DuplicateCursor` `!ProgrammingError` +``42P04`` `!DuplicateDatabase` `!ProgrammingError` +``42P05`` `!DuplicatePreparedStatement` `!ProgrammingError` +``42P06`` `!DuplicateSchema` `!ProgrammingError` +``42P07`` `!DuplicateTable` `!ProgrammingError` +``42P08`` `!AmbiguousParameter` `!ProgrammingError` +``42P09`` `!AmbiguousAlias` `!ProgrammingError` +``42P10`` `!InvalidColumnReference` `!ProgrammingError` +``42P11`` `!InvalidCursorDefinition` `!ProgrammingError` +``42P12`` `!InvalidDatabaseDefinition` `!ProgrammingError` +``42P13`` `!InvalidFunctionDefinition` `!ProgrammingError` +``42P14`` `!InvalidPreparedStatementDefinition` `!ProgrammingError` +``42P15`` `!InvalidSchemaDefinition` `!ProgrammingError` +``42P16`` `!InvalidTableDefinition` `!ProgrammingError` +``42P17`` `!InvalidObjectDefinition` `!ProgrammingError` +``42P18`` `!IndeterminateDatatype` `!ProgrammingError` +``42P19`` `!InvalidRecursion` `!ProgrammingError` +``42P20`` `!WindowingError` `!ProgrammingError` +``42P21`` `!CollationMismatch` `!ProgrammingError` +``42P22`` 
`!IndeterminateCollation` `!ProgrammingError` +**Class 44** - WITH CHECK OPTION Violation +--------------------------------------------------------------------------------- +``44000`` `!WithCheckOptionViolation` `!ProgrammingError` +**Class 53** - Insufficient Resources +--------------------------------------------------------------------------------- +``53000`` `!InsufficientResources` `!OperationalError` +``53100`` `!DiskFull` `!OperationalError` +``53200`` `!OutOfMemory` `!OperationalError` +``53300`` `!TooManyConnections` `!OperationalError` +``53400`` `!ConfigurationLimitExceeded` `!OperationalError` +**Class 54** - Program Limit Exceeded +--------------------------------------------------------------------------------- +``54000`` `!ProgramLimitExceeded` `!OperationalError` +``54001`` `!StatementTooComplex` `!OperationalError` +``54011`` `!TooManyColumns` `!OperationalError` +``54023`` `!TooManyArguments` `!OperationalError` +**Class 55** - Object Not In Prerequisite State +--------------------------------------------------------------------------------- +``55000`` `!ObjectNotInPrerequisiteState` `!OperationalError` +``55006`` `!ObjectInUse` `!OperationalError` +``55P02`` `!CantChangeRuntimeParam` `!OperationalError` +``55P03`` `!LockNotAvailable` `!OperationalError` +``55P04`` `!UnsafeNewEnumValueUsage` `!OperationalError` +**Class 57** - Operator Intervention +--------------------------------------------------------------------------------- +``57000`` `!OperatorIntervention` `!OperationalError` +``57014`` `!QueryCanceled` `!OperationalError` +``57P01`` `!AdminShutdown` `!OperationalError` +``57P02`` `!CrashShutdown` `!OperationalError` +``57P03`` `!CannotConnectNow` `!OperationalError` +``57P04`` `!DatabaseDropped` `!OperationalError` +``57P05`` `!IdleSessionTimeout` `!OperationalError` +**Class 58** - System Error (errors external to PostgreSQL itself) +--------------------------------------------------------------------------------- +``58000`` `!SystemError` `!OperationalError` +``58030`` `!IoError` `!OperationalError` +``58P01`` `!UndefinedFile` `!OperationalError` +``58P02`` `!DuplicateFile` `!OperationalError` +**Class 72** - Snapshot Failure +--------------------------------------------------------------------------------- +``72000`` `!SnapshotTooOld` `!DatabaseError` +**Class F0** - Configuration File Error +--------------------------------------------------------------------------------- +``F0000`` `!ConfigFileError` `!OperationalError` +``F0001`` `!LockFileExists` `!OperationalError` +**Class HV** - Foreign Data Wrapper Error (SQL/MED) +--------------------------------------------------------------------------------- +``HV000`` `!FdwError` `!OperationalError` +``HV001`` `!FdwOutOfMemory` `!OperationalError` +``HV002`` `!FdwDynamicParameterValueNeeded` `!OperationalError` +``HV004`` `!FdwInvalidDataType` `!OperationalError` +``HV005`` `!FdwColumnNameNotFound` `!OperationalError` +``HV006`` `!FdwInvalidDataTypeDescriptors` `!OperationalError` +``HV007`` `!FdwInvalidColumnName` `!OperationalError` +``HV008`` `!FdwInvalidColumnNumber` `!OperationalError` +``HV009`` `!FdwInvalidUseOfNullPointer` `!OperationalError` +``HV00A`` `!FdwInvalidStringFormat` `!OperationalError` +``HV00B`` `!FdwInvalidHandle` `!OperationalError` +``HV00C`` `!FdwInvalidOptionIndex` `!OperationalError` +``HV00D`` `!FdwInvalidOptionName` `!OperationalError` +``HV00J`` `!FdwOptionNameNotFound` `!OperationalError` +``HV00K`` `!FdwReplyHandle` `!OperationalError` +``HV00L`` `!FdwUnableToCreateExecution` 
`!OperationalError` +``HV00M`` `!FdwUnableToCreateReply` `!OperationalError` +``HV00N`` `!FdwUnableToEstablishConnection` `!OperationalError` +``HV00P`` `!FdwNoSchemas` `!OperationalError` +``HV00Q`` `!FdwSchemaNotFound` `!OperationalError` +``HV00R`` `!FdwTableNotFound` `!OperationalError` +``HV010`` `!FdwFunctionSequenceError` `!OperationalError` +``HV014`` `!FdwTooManyHandles` `!OperationalError` +``HV021`` `!FdwInconsistentDescriptorInformation` `!OperationalError` +``HV024`` `!FdwInvalidAttributeValue` `!OperationalError` +``HV090`` `!FdwInvalidStringLengthOrBufferLength` `!OperationalError` +``HV091`` `!FdwInvalidDescriptorFieldIdentifier` `!OperationalError` +**Class P0** - PL/pgSQL Error +--------------------------------------------------------------------------------- +``P0000`` `!PlpgsqlError` `!ProgrammingError` +``P0001`` `!RaiseException` `!ProgrammingError` +``P0002`` `!NoDataFound` `!ProgrammingError` +``P0003`` `!TooManyRows` `!ProgrammingError` +``P0004`` `!AssertFailure` `!ProgrammingError` +**Class XX** - Internal Error +--------------------------------------------------------------------------------- +``XX000`` `!InternalError_` `!InternalError` +``XX001`` `!DataCorrupted` `!InternalError` +``XX002`` `!IndexCorrupted` `!InternalError` +========= ================================================== ==================== + +.. autogenerated: end + +.. versionadded:: 3.1.4 + Exception `!SqlJsonItemCannotBeCastToTargetType`, introduced in PostgreSQL + 15. diff --git a/docs/api/index.rst b/docs/api/index.rst new file mode 100644 index 0000000..b99550d --- /dev/null +++ b/docs/api/index.rst @@ -0,0 +1,29 @@ +Psycopg 3 API +============= + +.. _api: + +This sections is a reference for all the public objects exposed by the +`psycopg` module. For a more conceptual description you can take a look at +:ref:`basic` and :ref:`advanced`. + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + module + connections + cursors + copy + objects + sql + rows + errors + pool + conninfo + adapt + types + abc + pq + crdb + dns diff --git a/docs/api/module.rst b/docs/api/module.rst new file mode 100644 index 0000000..3c3d3c4 --- /dev/null +++ b/docs/api/module.rst @@ -0,0 +1,59 @@ +The `!psycopg` module +===================== + +Psycopg implements the `Python Database DB API 2.0 specification`__. As such +it also exposes the `module-level objects`__ required by the specifications. + +.. __: https://www.python.org/dev/peps/pep-0249/ +.. __: https://www.python.org/dev/peps/pep-0249/#module-interface + +.. module:: psycopg + +.. autofunction:: connect + + This is an alias of the class method `Connection.connect`: see its + documentation for details. + + If you need an asynchronous connection use `AsyncConnection.connect` + instead. + + +.. rubric:: Exceptions + +The standard `DBAPI exceptions`__ are exposed both by the `!psycopg` module +and by the `psycopg.errors` module. The latter also exposes more specific +exceptions, mapping to the database error states (see +:ref:`sqlstate-exceptions`). + +.. __: https://www.python.org/dev/peps/pep-0249/#exceptions + +.. parsed-literal:: + + `!Exception` + \|__ `Warning` + \|__ `Error` + \|__ `InterfaceError` + \|__ `DatabaseError` + \|__ `DataError` + \|__ `OperationalError` + \|__ `IntegrityError` + \|__ `InternalError` + \|__ `ProgrammingError` + \|__ `NotSupportedError` + + +.. data:: adapters + + The default adapters map establishing how Python and PostgreSQL types are + converted into each other. 
+ + This map is used as a template when new connections are created, using + `psycopg.connect()`. Its `~psycopg.adapt.AdaptersMap.types` attribute is a + `~psycopg.types.TypesRegistry` containing information about every + PostgreSQL builtin type, useful for adaptation customisation (see + :ref:`adaptation`):: + + >>> psycopg.adapters.types["int4"] + <TypeInfo: int4 (oid: 23, array oid: 1007)> + + :type: `~psycopg.adapt.AdaptersMap` diff --git a/docs/api/objects.rst b/docs/api/objects.rst new file mode 100644 index 0000000..f085ed9 --- /dev/null +++ b/docs/api/objects.rst @@ -0,0 +1,256 @@ +.. currentmodule:: psycopg + +Other top-level objects +======================= + +Connection information +---------------------- + +.. autoclass:: ConnectionInfo() + + The object is usually returned by `Connection.info`. + + .. autoattribute:: dsn + + .. note:: The `get_parameters()` method returns the same information + as a dict. + + .. autoattribute:: status + + The status can be one of a number of values. However, only two of + these are seen outside of an asynchronous connection procedure: + `~pq.ConnStatus.OK` and `~pq.ConnStatus.BAD`. A good connection to the + database has the status `!OK`. Ordinarily, an `!OK` status will remain + so until `Connection.close()`, but a communications failure might + result in the status changing to `!BAD` prematurely. + + .. autoattribute:: transaction_status + + The status can be `~pq.TransactionStatus.IDLE` (currently idle), + `~pq.TransactionStatus.ACTIVE` (a command is in progress), + `~pq.TransactionStatus.INTRANS` (idle, in a valid transaction block), + or `~pq.TransactionStatus.INERROR` (idle, in a failed transaction + block). `~pq.TransactionStatus.UNKNOWN` is reported if the connection + is bad. `!ACTIVE` is reported only when a query has been sent to the + server and not yet completed. + + .. autoattribute:: pipeline_status + + .. autoattribute:: backend_pid + .. autoattribute:: vendor + + Normally it is `PostgreSQL`; it may be different if connected to + a different database. + + .. versionadded:: 3.1 + + .. autoattribute:: server_version + + The number is formed by converting the major, minor, and revision + numbers into two-decimal-digit numbers and appending them together. + Starting from PostgreSQL 10 the minor version was dropped, so the + second group of digits is always 00. For example, version 9.3.5 is + returned as 90305, version 10.2 as 100002. + + .. autoattribute:: error_message + + .. automethod:: get_parameters + + .. note:: The `dsn` attribute returns the same information in the form + as a string. + + .. autoattribute:: timezone + + .. code:: pycon + + >>> conn.info.timezone + zoneinfo.ZoneInfo(key='Europe/Rome') + + .. autoattribute:: host + + This can be a host name, an IP address, or a directory path if the + connection is via Unix socket. (The path case can be distinguished + because it will always be an absolute path, beginning with ``/``.) + + .. autoattribute:: hostaddr + + Only available if the libpq used is at least from PostgreSQL 12. + Raise `~psycopg.NotSupportedError` otherwise. + + .. autoattribute:: port + .. autoattribute:: dbname + .. autoattribute:: user + .. autoattribute:: password + .. autoattribute:: options + .. automethod:: parameter_status + + Example of parameters are ``server_version``, + ``standard_conforming_strings``... See :pq:`PQparameterStatus()` for + all the available parameters. + + .. 
autoattribute:: encoding + + The value returned is always normalized to the Python codec + `~codecs.CodecInfo.name`:: + + conn.execute("SET client_encoding TO LATIN9") + conn.info.encoding + 'iso8859-15' + + A few PostgreSQL encodings are not available in Python and cannot be + selected (currently ``EUC_TW``, ``MULE_INTERNAL``). The PostgreSQL + ``SQL_ASCII`` encoding has the special meaning of "no encoding": see + :ref:`adapt-string` for details. + + .. seealso:: + + The `PostgreSQL supported encodings`__. + + .. __: https://www.postgresql.org/docs/current/multibyte.html + + +The description `Column` object +------------------------------- + +.. autoclass:: Column() + + An object describing a column of data from a database result, `as described + by the DBAPI`__, so it can also be unpacked as a 7-items tuple. + + The object is returned by `Cursor.description`. + + .. __: https://www.python.org/dev/peps/pep-0249/#description + + .. autoattribute:: name + .. autoattribute:: type_code + .. autoattribute:: display_size + .. autoattribute:: internal_size + .. autoattribute:: precision + .. autoattribute:: scale + + +Notifications +------------- + +.. autoclass:: Notify() + + The object is usually returned by `Connection.notifies()`. + + .. attribute:: channel + :type: str + + The name of the channel on which the notification was received. + + .. attribute:: payload + :type: str + + The message attached to the notification. + + .. attribute:: pid + :type: int + + The PID of the backend process which sent the notification. + + +Pipeline-related objects +------------------------ + +See :ref:`pipeline-mode` for details. + +.. autoclass:: Pipeline + + This objects is returned by `Connection.pipeline()`. + + .. automethod:: sync + .. automethod:: is_supported + + +.. autoclass:: AsyncPipeline + + This objects is returned by `AsyncConnection.pipeline()`. + + .. automethod:: sync + + +Transaction-related objects +--------------------------- + +See :ref:`transactions` for details about these objects. + +.. autoclass:: IsolationLevel + :members: + + The value is usually used with the `Connection.isolation_level` property. + + Check the PostgreSQL documentation for a description of the effects of the + different `levels of transaction isolation`__. + + .. __: https://www.postgresql.org/docs/current/transaction-iso.html + + +.. autoclass:: Transaction() + + .. autoattribute:: savepoint_name + .. autoattribute:: connection + + +.. autoclass:: AsyncTransaction() + + .. autoattribute:: connection + + +.. autoexception:: Rollback + + It can be used as: + + - ``raise Rollback``: roll back the operation that happened in the current + transaction block and continue the program after the block. + + - ``raise Rollback()``: same effect as above + + - :samp:`raise Rollback({tx})`: roll back any operation that happened in + the `Transaction` `!tx` (returned by a statement such as :samp:`with + conn.transaction() as {tx}:` and all the blocks nested within. The + program will continue after the `!tx` block. + + +Two-Phase Commit related objects +-------------------------------- + +.. autoclass:: Xid() + + See :ref:`two-phase-commit` for details. + + .. autoattribute:: format_id + + Format Identifier of the two-phase transaction. + + .. autoattribute:: gtrid + + Global Transaction Identifier of the two-phase transaction. + + If the Xid doesn't follow the XA standard, it will be the PostgreSQL + ID of the transaction (in which case `format_id` and `bqual` will be + `!None`). + + .. 
autoattribute:: bqual + + Branch Qualifier of the two-phase transaction. + + .. autoattribute:: prepared + + Timestamp at which the transaction was prepared for commit. + + Only available on transactions recovered by `~Connection.tpc_recover()`. + + .. autoattribute:: owner + + Named of the user that executed the transaction. + + Only available on recovered transactions. + + .. autoattribute:: database + + Named of the database in which the transaction was executed. + + Only available on recovered transactions. diff --git a/docs/api/pool.rst b/docs/api/pool.rst new file mode 100644 index 0000000..76ccc74 --- /dev/null +++ b/docs/api/pool.rst @@ -0,0 +1,331 @@ +`!psycopg_pool` -- Connection pool implementations +================================================== + +.. index:: + double: Connection; Pool + +.. module:: psycopg_pool + +A connection pool is an object used to create and maintain a limited amount of +PostgreSQL connections, reducing the time requested by the program to obtain a +working connection and allowing an arbitrary large number of concurrent +threads or tasks to use a controlled amount of resources on the server. See +:ref:`connection-pools` for more details and usage pattern. + +This package exposes a few connection pool classes: + +- `ConnectionPool` is a synchronous connection pool yielding + `~psycopg.Connection` objects and can be used by multithread applications. + +- `AsyncConnectionPool` has an interface similar to `!ConnectionPool`, but + with `asyncio` functions replacing blocking functions, and yields + `~psycopg.AsyncConnection` instances. + +- `NullConnectionPool` is a `!ConnectionPool` subclass exposing the same + interface of its parent, but not keeping any unused connection in its state. + See :ref:`null-pool` for details about related use cases. + +- `AsyncNullConnectionPool` has the same behaviour of the + `!NullConnectionPool`, but with the same async interface of the + `!AsyncConnectionPool`. + +.. note:: The `!psycopg_pool` package is distributed separately from the main + `psycopg` package: use ``pip install "psycopg[pool]"``, or ``pip install + psycopg_pool``, to make it available. See :ref:`pool-installation`. + + The version numbers indicated in this page refer to the `!psycopg_pool` + package, not to `psycopg`. + + +The `!ConnectionPool` class +--------------------------- + +.. autoclass:: ConnectionPool + + This class implements a connection pool serving `~psycopg.Connection` + instances (or subclasses). The constructor has *alot* of arguments, but + only `!conninfo` and `!min_size` are the fundamental ones, all the other + arguments have meaningful defaults and can probably be tweaked later, if + required. + + :param conninfo: The connection string. See + `~psycopg.Connection.connect()` for details. + :type conninfo: `!str` + + :param min_size: The minimum number of connection the pool will hold. The + pool will actively try to create new connections if some + are lost (closed, broken) and will try to never go below + `!min_size`. + :type min_size: `!int`, default: 4 + + :param max_size: The maximum number of connections the pool will hold. If + `!None`, or equal to `!min_size`, the pool will not grow or + shrink. If larger than `!min_size`, the pool can grow if + more than `!min_size` connections are requested at the same + time and will shrink back after the extra connections have + been unused for more than `!max_idle` seconds. + :type max_size: `!int`, default: `!None` + + :param kwargs: Extra arguments to pass to `!connect()`. 
Note that this is + *one dict argument* of the pool constructor, which is + expanded as `connect()` keyword parameters. + + :type kwargs: `!dict` + + :param connection_class: The class of the connections to serve. It should + be a `!Connection` subclass. + :type connection_class: `!type`, default: `~psycopg.Connection` + + :param open: If `!True`, open the pool, creating the required connections, + on init. If `!False`, open the pool when `!open()` is called or + when the pool context is entered. See the `open()` method + documentation for more details. + :type open: `!bool`, default: `!True` + + :param configure: A callback to configure a connection after creation. + Useful, for instance, to configure its adapters. If the + connection is used to run internal queries (to inspect the + database) make sure to close an eventual transaction + before leaving the function. + :type configure: `Callable[[Connection], None]` + + :param reset: A callback to reset a function after it has been returned to + the pool. The connection is guaranteed to be passed to the + `!reset()` function in "idle" state (no transaction). When + leaving the `!reset()` function the connection must be left in + *idle* state, otherwise it is discarded. + :type reset: `Callable[[Connection], None]` + + :param name: An optional name to give to the pool, useful, for instance, to + identify it in the logs if more than one pool is used. if not + specified pick a sequential name such as ``pool-1``, + ``pool-2``, etc. + :type name: `!str` + + :param timeout: The default maximum time in seconds that a client can wait + to receive a connection from the pool (using `connection()` + or `getconn()`). Note that these methods allow to override + the `!timeout` default. + :type timeout: `!float`, default: 30 seconds + + :param max_waiting: Maximum number of requests that can be queued to the + pool, after which new requests will fail, raising + `TooManyRequests`. 0 means no queue limit. + :type max_waiting: `!int`, default: 0 + + :param max_lifetime: The maximum lifetime of a connection in the pool, in + seconds. Connections used for longer get closed and + replaced by a new one. The amount is reduced by a + random 10% to avoid mass eviction. + :type max_lifetime: `!float`, default: 1 hour + + :param max_idle: Maximum time, in seconds, that a connection can stay unused + in the pool before being closed, and the pool shrunk. This + only happens to connections more than `!min_size`, if + `!max_size` allowed the pool to grow. + :type max_idle: `!float`, default: 10 minutes + + :param reconnect_timeout: Maximum time, in seconds, the pool will try to + create a connection. If a connection attempt + fails, the pool will try to reconnect a few + times, using an exponential backoff and some + random factor to avoid mass attempts. If repeated + attempts fail, after `!reconnect_timeout` second + the connection attempt is aborted and the + `!reconnect_failed()` callback invoked. + :type reconnect_timeout: `!float`, default: 5 minutes + + :param reconnect_failed: Callback invoked if an attempt to create a new + connection fails for more than `!reconnect_timeout` + seconds. The user may decide, for instance, to + terminate the program (executing `sys.exit()`). + By default don't do anything: restart a new + connection attempt (if the number of connection + fell below `!min_size`). + :type reconnect_failed: ``Callable[[ConnectionPool], None]`` + + :param num_workers: Number of background worker threads used to maintain the + pool state. 
Background workers are used for example to + create new connections and to clean up connections when + they are returned to the pool. + :type num_workers: `!int`, default: 3 + + .. versionchanged:: 3.1 + + added `!open` parameter to init method. + + .. note:: In a future version, the default value for the `!open` parameter + might be changed to `!False`. If you rely on this behaviour (e.g. if + you don't use the pool as a context manager) you might want to specify + this parameter explicitly. + + .. automethod:: connection + + .. code:: python + + with my_pool.connection() as conn: + conn.execute(...) + + # the connection is now back in the pool + + .. automethod:: open + + .. versionadded:: 3.1 + + + .. automethod:: close + + .. note:: + + The pool can be also used as a context manager, in which case it will + be opened (if necessary) on entering the block and closed on exiting it: + + .. code:: python + + with ConnectionPool(...) as pool: + # code using the pool + + .. automethod:: wait + + .. attribute:: name + :type: str + + The name of the pool set on creation, or automatically generated if not + set. + + .. autoattribute:: min_size + .. autoattribute:: max_size + + The current minimum and maximum size of the pool. Use `resize()` to + change them at runtime. + + .. automethod:: resize + .. automethod:: check + .. automethod:: get_stats + .. automethod:: pop_stats + + See :ref:`pool-stats` for the metrics returned. + + .. rubric:: Functionalities you may not need + + .. automethod:: getconn + .. automethod:: putconn + + +Pool exceptions +--------------- + +.. autoclass:: PoolTimeout() + + Subclass of `~psycopg.OperationalError` + +.. autoclass:: PoolClosed() + + Subclass of `~psycopg.OperationalError` + +.. autoclass:: TooManyRequests() + + Subclass of `~psycopg.OperationalError` + + +The `!AsyncConnectionPool` class +-------------------------------- + +`!AsyncConnectionPool` has a very similar interface to the `ConnectionPool` +class but its blocking methods are implemented as `!async` coroutines. It +returns instances of `~psycopg.AsyncConnection`, or of its subclass if +specified so in the `!connection_class` parameter. + +Only the functions with different signature from `!ConnectionPool` are +listed here. + +.. autoclass:: AsyncConnectionPool + + :param connection_class: The class of the connections to serve. It should + be an `!AsyncConnection` subclass. + :type connection_class: `!type`, default: `~psycopg.AsyncConnection` + + :param configure: A callback to configure a connection after creation. + :type configure: `async Callable[[AsyncConnection], None]` + + :param reset: A callback to reset a function after it has been returned to + the pool. + :type reset: `async Callable[[AsyncConnection], None]` + + .. automethod:: connection + + .. code:: python + + async with my_pool.connection() as conn: + await conn.execute(...) + + # the connection is now back in the pool + + .. automethod:: open + .. automethod:: close + + .. note:: + + The pool can be also used as an async context manager, in which case it + will be opened (if necessary) on entering the block and closed on + exiting it: + + .. code:: python + + async with AsyncConnectionPool(...) as pool: + # code using the pool + + All the other constructor parameters are the same of `!ConnectionPool`. + + .. automethod:: wait + .. automethod:: resize + .. automethod:: check + .. automethod:: getconn + .. automethod:: putconn + + +Null connection pools +--------------------- + +.. 
versionadded:: 3.1 + +The `NullConnectionPool` is a `ConnectionPool` subclass which doesn't create +connections preemptively and doesn't keep unused connections in its state. See +:ref:`null-pool` for further details. + +The interface of the object is entirely compatible with its parent class. Its +behaviour is similar, with the following differences: + +.. autoclass:: NullConnectionPool + + All the other constructor parameters are the same as in `ConnectionPool`. + + :param min_size: Always 0, cannot be changed. + :type min_size: `!int`, default: 0 + + :param max_size: If None or 0, create a new connection at every request, + without a maximum. If greater than 0, don't create more + than `!max_size` connections and queue the waiting clients. + :type max_size: `!int`, default: None + + :param reset: It is only called when there are waiting clients in the + queue, before giving them a connection already open. If no + client is waiting, the connection is closed and discarded + without a fuss. + :type reset: `Callable[[Connection], None]` + + :param max_idle: Ignored, as null pools don't leave idle connections + sitting around. + + .. automethod:: wait + .. automethod:: resize + .. automethod:: check + + +The `AsyncNullConnectionPool` is, similarly, an `AsyncConnectionPool` subclass +with the same behaviour of the `NullConnectionPool`. + +.. autoclass:: AsyncNullConnectionPool + + The interface is the same of its parent class `AsyncConnectionPool`. The + behaviour is different in the same way described for `NullConnectionPool`. diff --git a/docs/api/pq.rst b/docs/api/pq.rst new file mode 100644 index 0000000..3d9c033 --- /dev/null +++ b/docs/api/pq.rst @@ -0,0 +1,218 @@ +.. _psycopg.pq: + +`pq` -- libpq wrapper module +============================ + +.. index:: + single: libpq + +.. module:: psycopg.pq + +Psycopg is built around the libpq_, the PostgreSQL client library, which +performs most of the network communications and returns query results in C +structures. + +.. _libpq: https://www.postgresql.org/docs/current/libpq.html + +The low-level functions of the library are exposed by the objects in the +`!psycopg.pq` module. + + +.. _pq-impl: + +``pq`` module implementations +----------------------------- + +There are actually several implementations of the module, all offering the +same interface. Current implementations are: + +- ``python``: a pure-python implementation, implemented using the `ctypes` + module. It is less performing than the others, but it doesn't need a C + compiler to install. It requires the libpq installed in the system. + +- ``c``: a C implementation of the libpq wrapper (more precisely, implemented + in Cython_). It is much better performing than the ``python`` + implementation, however it requires development packages installed on the + client machine. It can be installed using the ``c`` extra, i.e. running + ``pip install "psycopg[c]"``. + +- ``binary``: a pre-compiled C implementation, bundled with all the required + libraries. It is the easiest option to deal with, fast to install and it + should require no development tool or client library, however it may be not + available for every platform. You can install it using the ``binary`` extra, + i.e. running ``pip install "psycopg[binary]"``. + +.. _Cython: https://cython.org/ + +The implementation currently used is available in the `~psycopg.pq.__impl__` +module constant. + +At import time, Psycopg 3 will try to use the best implementation available +and will fail if none is usable. 
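As a quick check, the implementation actually loaded can be inspected at
runtime; here is a minimal sketch using the `!__impl__` constant and the
`!version()` function documented below on this page:

.. code:: python

    import psycopg.pq

    # Which wrapper was selected at import time: "python", "c" or "binary"
    print(psycopg.pq.__impl__)

    # Version of the libpq library in use, e.g. 150002 for libpq 15.2
    print(psycopg.pq.version())
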
You can force the use of a specific +implementation by exporting the env var :envvar:`PSYCOPG_IMPL`: importing the +library will fail if the requested implementation is not available:: + + $ PSYCOPG_IMPL=c python -c "import psycopg" + Traceback (most recent call last): + ... + ImportError: couldn't import requested psycopg 'c' implementation: No module named 'psycopg_c' + + +Module content +-------------- + +.. autodata:: __impl__ + + The choice of implementation is automatic but can be forced setting the + :envvar:`PSYCOPG_IMPL` env var. + + +.. autofunction:: version + + .. seealso:: the :pq:`PQlibVersion()` function + + +.. autodata:: __build_version__ + +.. autofunction:: error_message + + +Objects wrapping libpq structures and functions +----------------------------------------------- + +.. admonition:: TODO + + finish documentation + +.. autoclass:: PGconn() + + .. autoattribute:: pgconn_ptr + .. automethod:: get_cancel + .. autoattribute:: needs_password + .. autoattribute:: used_password + + .. automethod:: encrypt_password + + .. code:: python + + >>> enc = conn.info.encoding + >>> encrypted = conn.pgconn.encrypt_password(password.encode(enc), rolename.encode(enc)) + b'SCRAM-SHA-256$4096:... + + .. automethod:: trace + .. automethod:: set_trace_flags + .. automethod:: untrace + + .. code:: python + + >>> conn.pgconn.trace(sys.stderr.fileno()) + >>> conn.pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + >>> conn.execute("select now()") + F 13 Parse "" "BEGIN" 0 + F 14 Bind "" "" 0 0 1 0 + F 6 Describe P "" + F 9 Execute "" 0 + F 4 Sync + B 4 ParseComplete + B 4 BindComplete + B 4 NoData + B 10 CommandComplete "BEGIN" + B 5 ReadyForQuery T + F 17 Query "select now()" + B 28 RowDescription 1 "now" NNNN 0 NNNN 8 -1 0 + B 39 DataRow 1 29 '2022-09-14 14:12:16.648035+02' + B 13 CommandComplete "SELECT 1" + B 5 ReadyForQuery T + <psycopg.Cursor [TUPLES_OK] [INTRANS] (database=postgres) at 0x7f18a18ba040> + >>> conn.pgconn.untrace() + + +.. autoclass:: PGresult() + + .. autoattribute:: pgresult_ptr + + +.. autoclass:: Conninfo +.. autoclass:: Escaping + +.. autoclass:: PGcancel() + :members: + + +Enumerations +------------ + +.. autoclass:: ConnStatus + :members: + + There are other values in this enum, but only `OK` and `BAD` are seen + after a connection has been established. Other statuses might only be seen + during the connection phase and are considered internal. + + .. seealso:: :pq:`PQstatus()` returns this value. + + +.. autoclass:: PollingStatus + :members: + + .. seealso:: :pq:`PQconnectPoll` for a description of these states. + + +.. autoclass:: TransactionStatus + :members: + + .. seealso:: :pq:`PQtransactionStatus` for a description of these states. + + +.. autoclass:: ExecStatus + :members: + + .. seealso:: :pq:`PQresultStatus` for a description of these states. + + +.. autoclass:: PipelineStatus + :members: + + .. seealso:: :pq:`PQpipelineStatus` for a description of these states. + + +.. autoclass:: Format + :members: + + +.. autoclass:: DiagnosticField + + Available attributes: + + .. attribute:: + SEVERITY + SEVERITY_NONLOCALIZED + SQLSTATE + MESSAGE_PRIMARY + MESSAGE_DETAIL + MESSAGE_HINT + STATEMENT_POSITION + INTERNAL_POSITION + INTERNAL_QUERY + CONTEXT + SCHEMA_NAME + TABLE_NAME + COLUMN_NAME + DATATYPE_NAME + CONSTRAINT_NAME + SOURCE_FILE + SOURCE_LINE + SOURCE_FUNCTION + + .. seealso:: :pq:`PQresultErrorField` for a description of these values. + + +.. autoclass:: Ping + :members: + + .. 
seealso:: :pq:`PQpingParams` for a description of these values. + +.. autoclass:: Trace + :members: + + .. seealso:: :pq:`PQsetTraceFlags` for a description of these values. diff --git a/docs/api/rows.rst b/docs/api/rows.rst new file mode 100644 index 0000000..204f1ea --- /dev/null +++ b/docs/api/rows.rst @@ -0,0 +1,74 @@ +.. _psycopg.rows: + +`rows` -- row factory implementations +===================================== + +.. module:: psycopg.rows + +The module exposes a few generic `~psycopg.RowFactory` implementation, which +can be used to retrieve data from the database in more complex structures than +the basic tuples. + +Check out :ref:`row-factories` for information about how to use these objects. + +.. autofunction:: tuple_row +.. autofunction:: dict_row +.. autofunction:: namedtuple_row +.. autofunction:: class_row + + This is not a row factory, but rather a factory of row factories. + Specifying `!row_factory=class_row(MyClass)` will create connections and + cursors returning `!MyClass` objects on fetch. + + Example:: + + from dataclasses import dataclass + import psycopg + from psycopg.rows import class_row + + @dataclass + class Person: + first_name: str + last_name: str + age: int = None + + conn = psycopg.connect() + cur = conn.cursor(row_factory=class_row(Person)) + + cur.execute("select 'John' as first_name, 'Smith' as last_name").fetchone() + # Person(first_name='John', last_name='Smith', age=None) + +.. autofunction:: args_row +.. autofunction:: kwargs_row + + +Formal rows protocols +--------------------- + +These objects can be used to describe your own rows adapter for static typing +checks, such as mypy_. + +.. _mypy: https://mypy.readthedocs.io/ + + +.. autoclass:: psycopg.rows.RowMaker() + + .. method:: __call__(values: Sequence[Any]) -> Row + + Convert a sequence of values from the database to a finished object. + + +.. autoclass:: psycopg.rows.RowFactory() + + .. method:: __call__(cursor: Cursor[Row]) -> RowMaker[Row] + + Inspect the result on a cursor and return a `RowMaker` to convert rows. + +.. autoclass:: psycopg.rows.AsyncRowFactory() + +.. autoclass:: psycopg.rows.BaseRowFactory() + +Note that it's easy to implement an object implementing both `!RowFactory` and +`!AsyncRowFactory`: usually, everything you need to implement a row factory is +to access the cursor's `~psycopg.Cursor.description`, which is provided by +both the cursor flavours. diff --git a/docs/api/sql.rst b/docs/api/sql.rst new file mode 100644 index 0000000..6959fee --- /dev/null +++ b/docs/api/sql.rst @@ -0,0 +1,151 @@ +`sql` -- SQL string composition +=============================== + +.. index:: + double: Binding; Client-Side + +.. module:: psycopg.sql + +The module contains objects and functions useful to generate SQL dynamically, +in a convenient and safe way. SQL identifiers (e.g. names of tables and +fields) cannot be passed to the `~psycopg.Cursor.execute()` method like query +arguments:: + + # This will not work + table_name = 'my_table' + cur.execute("INSERT INTO %s VALUES (%s, %s)", [table_name, 10, 20]) + +The SQL query should be composed before the arguments are merged, for +instance:: + + # This works, but it is not optimal + table_name = 'my_table' + cur.execute( + "INSERT INTO %s VALUES (%%s, %%s)" % table_name, + [10, 20]) + +This sort of works, but it is an accident waiting to happen: the table name +may be an invalid SQL literal and need quoting; even more serious is the +security problem in case the table name comes from an untrusted source. 
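To make the problem concrete, here is a minimal sketch (assuming a cursor
`!cur`, as in the snippets above) showing how a name that requires quoting
breaks the naive string-formatting approach::

    # An identifier needing quotes breaks the string-formatted query
    table_name = 'my table'   # note the space in the name
    cur.execute(
        "INSERT INTO %s VALUES (%%s, %%s)" % table_name,
        [10, 20])
    # raises an error along the lines of:
    # psycopg.errors.SyntaxError: syntax error at or near "table"
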
The +name should be escaped using `~psycopg.pq.Escaping.escape_identifier()`:: + + from psycopg.pq import Escaping + + # This works, but it is not optimal + table_name = 'my_table' + cur.execute( + "INSERT INTO %s VALUES (%%s, %%s)" % Escaping.escape_identifier(table_name), + [10, 20]) + +This is now safe, but it somewhat ad-hoc. In case, for some reason, it is +necessary to include a value in the query string (as opposite as in a value) +the merging rule is still different. It is also still relatively dangerous: if +`!escape_identifier()` is forgotten somewhere, the program will usually work, +but will eventually crash in the presence of a table or field name with +containing characters to escape, or will present a potentially exploitable +weakness. + +The objects exposed by the `!psycopg.sql` module allow generating SQL +statements on the fly, separating clearly the variable parts of the statement +from the query parameters:: + + from psycopg import sql + + cur.execute( + sql.SQL("INSERT INTO {} VALUES (%s, %s)") + .format(sql.Identifier('my_table')), + [10, 20]) + + +Module usage +------------ + +Usually you should express the template of your query as an `SQL` instance +with ``{}``\-style placeholders and use `~SQL.format()` to merge the variable +parts into them, all of which must be `Composable` subclasses. You can still +have ``%s``\-style placeholders in your query and pass values to +`~psycopg.Cursor.execute()`: such value placeholders will be untouched by +`!format()`:: + + query = sql.SQL("SELECT {field} FROM {table} WHERE {pkey} = %s").format( + field=sql.Identifier('my_name'), + table=sql.Identifier('some_table'), + pkey=sql.Identifier('id')) + +The resulting object is meant to be passed directly to cursor methods such as +`~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`, +`~psycopg.Cursor.copy()`, but can also be used to compose a query as a Python +string, using the `~Composable.as_string()` method:: + + cur.execute(query, (42,)) + full_query = query.as_string(cur) + +If part of your query is a variable sequence of arguments, such as a +comma-separated list of field names, you can use the `SQL.join()` method to +pass them to the query:: + + query = sql.SQL("SELECT {fields} FROM {table}").format( + fields=sql.SQL(',').join([ + sql.Identifier('field1'), + sql.Identifier('field2'), + sql.Identifier('field3'), + ]), + table=sql.Identifier('some_table')) + + +`!sql` objects +-------------- + +The `!sql` objects are in the following inheritance hierarchy: + +| `Composable`: the base class exposing the common interface +| ``|__`` `SQL`: a literal snippet of an SQL query +| ``|__`` `Identifier`: a PostgreSQL identifier or dot-separated sequence of identifiers +| ``|__`` `Literal`: a value hardcoded into a query +| ``|__`` `Placeholder`: a `%s`\ -style placeholder whose value will be added later e.g. by `~psycopg.Cursor.execute()` +| ``|__`` `Composed`: a sequence of `!Composable` instances. + + +.. autoclass:: Composable() + + .. automethod:: as_bytes + .. automethod:: as_string + + +.. autoclass:: SQL + + .. versionchanged:: 3.1 + + The input object should be a `~typing.LiteralString`. See :pep:`675` + for details. + + .. automethod:: format + + .. automethod:: join + + +.. autoclass:: Identifier + +.. autoclass:: Literal + + .. versionchanged:: 3.1 + Add a type cast to the representation if useful in ambiguous context + (e.g. ``'2000-01-01'::date``) + +.. autoclass:: Placeholder + +.. autoclass:: Composed + + .. 
automethod:: join + + +Utility functions +----------------- + +.. autofunction:: quote + +.. data:: + NULL + DEFAULT + + `sql.SQL` objects often useful in queries. diff --git a/docs/api/types.rst b/docs/api/types.rst new file mode 100644 index 0000000..f04659e --- /dev/null +++ b/docs/api/types.rst @@ -0,0 +1,168 @@ +.. currentmodule:: psycopg.types + +.. _psycopg.types: + +`!types` -- Types information and adapters +========================================== + +.. module:: psycopg.types + +The `!psycopg.types` package exposes: + +- objects to describe PostgreSQL types, such as `TypeInfo`, `TypesRegistry`, + to help or :ref:`customise the types conversion <adaptation>`; + +- concrete implementations of `~psycopg.abc.Loader` and `~psycopg.abc.Dumper` + protocols to :ref:`handle builtin data types <types-adaptation>`; + +- helper objects to represent PostgreSQL data types which :ref:`don't have a + straightforward Python representation <extra-adaptation>`, such as + `~range.Range`. + + +Types information +----------------- + +The `TypeInfo` object describes simple information about a PostgreSQL data +type, such as its name, oid and array oid. `!TypeInfo` subclasses may hold more +information, for instance the components of a composite type. + +You can use `TypeInfo.fetch()` to query information from a database catalog, +which is then used by helper functions, such as +`~psycopg.types.hstore.register_hstore()`, to register adapters on types whose +OID is not known upfront or to create more specialised adapters. + +The `!TypeInfo` object doesn't instruct Psycopg to convert a PostgreSQL type +into a Python type: this is the role of a `~psycopg.abc.Loader`. However it +can extend the behaviour of other adapters: if you create a loader for +`!MyType`, using the `TypeInfo` information, Psycopg will be able to manage +seamlessly arrays of `!MyType` or ranges and composite types using `!MyType` +as a subtype. + +.. seealso:: :ref:`adaptation` describes how to convert from Python objects to + PostgreSQL types and back. + +.. code:: python + + from psycopg.adapt import Loader + from psycopg.types import TypeInfo + + t = TypeInfo.fetch(conn, "mytype") + t.register(conn) + + for record in conn.execute("SELECT mytypearray FROM mytable"): + # records will return lists of "mytype" as string + + class MyTypeLoader(Loader): + def load(self, data): + # parse the data and return a MyType instance + + conn.adapters.register_loader("mytype", MyTypeLoader) + + for record in conn.execute("SELECT mytypearray FROM mytable"): + # records will return lists of MyType instances + + +.. autoclass:: TypeInfo + + .. method:: fetch(conn, name) + :classmethod: + + .. method:: fetch(aconn, name) + :classmethod: + :async: + :noindex: + + Query a system catalog to read information about a type. + + :param conn: the connection to query + :type conn: ~psycopg.Connection or ~psycopg.AsyncConnection + :param name: the name of the type to query. It can include a schema + name. + :type name: `!str` or `~psycopg.sql.Identifier` + :return: a `!TypeInfo` object (or subclass) populated with the type + information, `!None` if not found. + + If the connection is async, `!fetch()` will behave as a coroutine and + the caller will need to `!await` on it to get the result:: + + t = await TypeInfo.fetch(aconn, "mytype") + + .. automethod:: register + + :param context: the context where the type is registered, for instance + a `~psycopg.Connection` or `~psycopg.Cursor`. `!None` registers + the `!TypeInfo` globally. 
+ :type context: Optional[~psycopg.abc.AdaptContext] + + Registering the `TypeInfo` in a context allows the adapters of that + context to look up type information: for instance it allows to + recognise automatically arrays of that type and load them from the + database as a list of the base type. + + +In order to get information about dynamic PostgreSQL types, Psycopg offers a +few `!TypeInfo` subclasses, whose `!fetch()` method can extract more complete +information about the type, such as `~psycopg.types.composite.CompositeInfo`, +`~psycopg.types.range.RangeInfo`, `~psycopg.types.multirange.MultirangeInfo`, +`~psycopg.types.enum.EnumInfo`. + +`!TypeInfo` objects are collected in `TypesRegistry` instances, which help type +information lookup. Every `~psycopg.adapt.AdaptersMap` exposes its type map on +its `~psycopg.adapt.AdaptersMap.types` attribute. + +.. autoclass:: TypesRegistry + + `!TypeRegistry` instances are typically exposed by + `~psycopg.adapt.AdaptersMap` objects in adapt contexts such as + `~psycopg.Connection` or `~psycopg.Cursor` (e.g. `!conn.adapters.types`). + + The global registry, from which the others inherit from, is available as + `psycopg.adapters`\ `!.types`. + + .. automethod:: __getitem__ + + .. code:: python + + >>> import psycopg + + >>> psycopg.adapters.types["text"] + <TypeInfo: text (oid: 25, array oid: 1009)> + + >>> psycopg.adapters.types[23] + <TypeInfo: int4 (oid: 23, array oid: 1007)> + + .. automethod:: get + + .. automethod:: get_oid + + .. code:: python + + >>> psycopg.adapters.types.get_oid("text[]") + 1009 + + .. automethod:: get_by_subtype + + +.. _json-adapters: + +JSON adapters +------------- + +See :ref:`adapt-json` for details. + +.. currentmodule:: psycopg.types.json + +.. autoclass:: Json +.. autoclass:: Jsonb + +Wrappers to signal to convert `!obj` to a json or jsonb PostgreSQL value. + +Any object supported by the underlying `!dumps()` function can be wrapped. + +If a `!dumps` function is passed to the wrapper, use it to dump the wrapped +object. Otherwise use the function specified by `set_json_dumps()`. + + +.. autofunction:: set_json_dumps +.. autofunction:: set_json_loads diff --git a/docs/basic/adapt.rst b/docs/basic/adapt.rst new file mode 100644 index 0000000..1538327 --- /dev/null +++ b/docs/basic/adapt.rst @@ -0,0 +1,522 @@ +.. currentmodule:: psycopg + +.. index:: + single: Adaptation + pair: Objects; Adaptation + single: Data types; Adaptation + +.. _types-adaptation: + +Adapting basic Python types +=========================== + +Many standard Python types are adapted into SQL and returned as Python +objects when a query is executed. + +Converting the following data types between Python and PostgreSQL works +out-of-the-box and doesn't require any configuration. In case you need to +customise the conversion you should take a look at :ref:`adaptation`. + + +.. index:: + pair: Boolean; Adaptation + +.. _adapt-bool: + +Booleans adaptation +------------------- + +Python `bool` values `!True` and `!False` are converted to the equivalent +`PostgreSQL boolean type`__:: + + >>> cur.execute("SELECT %s, %s", (True, False)) + # equivalent to "SELECT true, false" + +.. __: https://www.postgresql.org/docs/current/datatype-boolean.html + + +.. index:: + single: Adaptation; numbers + single: Integer; Adaptation + single: Float; Adaptation + single: Decimal; Adaptation + +.. _adapt-numbers: + +Numbers adaptation +------------------ + +.. 
seealso:: + + - `PostgreSQL numeric types + <https://www.postgresql.org/docs/current/static/datatype-numeric.html>`__ + +- Python `int` values can be converted to PostgreSQL :sql:`smallint`, + :sql:`integer`, :sql:`bigint`, or :sql:`numeric`, according to their numeric + value. Psycopg will choose the smallest data type available, because + PostgreSQL can automatically cast a type up (e.g. passing a `smallint` where + PostgreSQL expect an `integer` is gladly accepted) but will not cast down + automatically (e.g. if a function has an :sql:`integer` argument, passing it + a :sql:`bigint` value will fail, even if the value is 1). + +- Python `float` values are converted to PostgreSQL :sql:`float8`. + +- Python `~decimal.Decimal` values are converted to PostgreSQL :sql:`numeric`. + +On the way back, smaller types (:sql:`int2`, :sql:`int4`, :sql:`float4`) are +promoted to the larger Python counterpart. + +.. note:: + + Sometimes you may prefer to receive :sql:`numeric` data as `!float` + instead, for performance reason or ease of manipulation: you can configure + an adapter to :ref:`cast PostgreSQL numeric to Python float + <adapt-example-float>`. This of course may imply a loss of precision. + + +.. index:: + pair: Strings; Adaptation + single: Unicode; Adaptation + pair: Encoding; SQL_ASCII + +.. _adapt-string: + +Strings adaptation +------------------ + +.. seealso:: + + - `PostgreSQL character types + <https://www.postgresql.org/docs/current/datatype-character.html>`__ + +Python `str` are converted to PostgreSQL string syntax, and PostgreSQL types +such as :sql:`text` and :sql:`varchar` are converted back to Python `!str`: + +.. code:: python + + conn = psycopg.connect() + conn.execute( + "INSERT INTO menu (id, entry) VALUES (%s, %s)", + (1, "Crème Brûlée at 4.99€")) + conn.execute("SELECT entry FROM menu WHERE id = 1").fetchone()[0] + 'Crème Brûlée at 4.99€' + +PostgreSQL databases `have an encoding`__, and `the session has an encoding`__ +too, exposed in the `!Connection.info.`\ `~ConnectionInfo.encoding` +attribute. If your database and connection are in UTF-8 encoding you will +likely have no problem, otherwise you will have to make sure that your +application only deals with the non-ASCII chars that the database can handle; +failing to do so may result in encoding/decoding errors: + +.. __: https://www.postgresql.org/docs/current/sql-createdatabase.html +.. __: https://www.postgresql.org/docs/current/multibyte.html + +.. code:: python + + # The encoding is set at connection time according to the db configuration + conn.info.encoding + 'utf-8' + + # The Latin-9 encoding can manage some European accented letters + # and the Euro symbol + conn.execute("SET client_encoding TO LATIN9") + conn.execute("SELECT entry FROM menu WHERE id = 1").fetchone()[0] + 'Crème Brûlée at 4.99€' + + # The Latin-1 encoding doesn't have a representation for the Euro symbol + conn.execute("SET client_encoding TO LATIN1") + conn.execute("SELECT entry FROM menu WHERE id = 1").fetchone()[0] + # Traceback (most recent call last) + # ... + # UntranslatableCharacter: character with byte sequence 0xe2 0x82 0xac + # in encoding "UTF8" has no equivalent in encoding "LATIN1" + +In rare cases you may have strings with unexpected encodings in the database. +Using the ``SQL_ASCII`` client encoding will disable decoding of the data +coming from the database, which will be returned as `bytes`: + +.. 
code:: python + + conn.execute("SET client_encoding TO SQL_ASCII") + conn.execute("SELECT entry FROM menu WHERE id = 1").fetchone()[0] + b'Cr\xc3\xa8me Br\xc3\xbbl\xc3\xa9e at 4.99\xe2\x82\xac' + +Alternatively you can cast the unknown encoding data to :sql:`bytea` to +retrieve it as bytes, leaving other strings unaltered: see :ref:`adapt-binary` + +Note that PostgreSQL text cannot contain the ``0x00`` byte. If you need to +store Python strings that may contain binary zeros you should use a +:sql:`bytea` field. + + +.. index:: + single: bytea; Adaptation + single: bytes; Adaptation + single: bytearray; Adaptation + single: memoryview; Adaptation + single: Binary string + +.. _adapt-binary: + +Binary adaptation +----------------- + +Python types representing binary objects (`bytes`, `bytearray`, `memoryview`) +are converted by default to :sql:`bytea` fields. By default data received is +returned as `!bytes`. + +If you are storing large binary data in bytea fields (such as binary documents +or images) you should probably use the binary format to pass and return +values, otherwise binary data will undergo `ASCII escaping`__, taking some CPU +time and more bandwidth. See :ref:`binary-data` for details. + +.. __: https://www.postgresql.org/docs/current/datatype-binary.html + + +.. _adapt-date: + +Date/time types adaptation +-------------------------- + +.. seealso:: + + - `PostgreSQL date/time types + <https://www.postgresql.org/docs/current/datatype-datetime.html>`__ + +- Python `~datetime.date` objects are converted to PostgreSQL :sql:`date`. +- Python `~datetime.datetime` objects are converted to PostgreSQL + :sql:`timestamp` (if they don't have a `!tzinfo` set) or :sql:`timestamptz` + (if they do). +- Python `~datetime.time` objects are converted to PostgreSQL :sql:`time` + (if they don't have a `!tzinfo` set) or :sql:`timetz` (if they do). +- Python `~datetime.timedelta` objects are converted to PostgreSQL + :sql:`interval`. + +PostgreSQL :sql:`timestamptz` values are returned with a timezone set to the +`connection TimeZone setting`__, which is available as a Python +`~zoneinfo.ZoneInfo` object in the `!Connection.info`.\ `~ConnectionInfo.timezone` +attribute:: + + >>> conn.info.timezone + zoneinfo.ZoneInfo(key='Europe/London') + + >>> conn.execute("select '2048-07-08 12:00'::timestamptz").fetchone()[0] + datetime.datetime(2048, 7, 8, 12, 0, tzinfo=zoneinfo.ZoneInfo(key='Europe/London')) + +.. note:: + PostgreSQL :sql:`timestamptz` doesn't store "a timestamp with a timezone + attached": it stores a timestamp always in UTC, which is converted, on + output, to the connection TimeZone setting:: + + >>> conn.execute("SET TIMEZONE to 'Europe/Rome'") # UTC+2 in summer + + >>> conn.execute("SELECT '2042-07-01 12:00Z'::timestamptz").fetchone()[0] # UTC input + datetime.datetime(2042, 7, 1, 14, 0, tzinfo=zoneinfo.ZoneInfo(key='Europe/Rome')) + + Check out the `PostgreSQL documentation about timezones`__ for all the + details. + + .. __: https://www.postgresql.org/docs/current/datatype-datetime.html + #DATATYPE-TIMEZONES + +.. __: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-TIMEZONE + + +.. _adapt-json: + +JSON adaptation +--------------- + +Psycopg can map between Python objects and PostgreSQL `json/jsonb +types`__, allowing to customise the load and dump function used. + +.. 
__: https://www.postgresql.org/docs/current/datatype-json.html + +Because several Python objects could be considered JSON (dicts, lists, +scalars, even date/time if using a dumps function customised to use them), +Psycopg requires you to wrap the object to dump as JSON into a wrapper: +either `psycopg.types.json.Json` or `~psycopg.types.json.Jsonb`. + +.. code:: python + + from psycopg.types.json import Jsonb + + thing = {"foo": ["bar", 42]} + conn.execute("INSERT INTO mytable VALUES (%s)", [Jsonb(thing)]) + +By default Psycopg uses the standard library `json.dumps` and `json.loads` +functions to serialize and de-serialize Python objects to JSON. If you want to +customise how serialization happens, for instance changing serialization +parameters or using a different JSON library, you can specify your own +functions using the `psycopg.types.json.set_json_dumps()` and +`~psycopg.types.json.set_json_loads()` functions, to apply either globally or +to a specific context (connection or cursor). + +.. code:: python + + from functools import partial + from psycopg.types.json import Jsonb, set_json_dumps, set_json_loads + import ujson + + # Use a faster dump function + set_json_dumps(ujson.dumps) + + # Return floating point values as Decimal, just in one connection + set_json_loads(partial(json.loads, parse_float=Decimal), conn) + + conn.execute("SELECT %s", [Jsonb({"value": 123.45})]).fetchone()[0] + # {'value': Decimal('123.45')} + +If you need an even more specific dump customisation only for certain objects +(including different configurations in the same query) you can specify a +`!dumps` parameter in the +`~psycopg.types.json.Json`/`~psycopg.types.json.Jsonb` wrapper, which will +take precedence over what is specified by `!set_json_dumps()`. + +.. code:: python + + from uuid import UUID, uuid4 + + class UUIDEncoder(json.JSONEncoder): + """A JSON encoder which can dump UUID.""" + def default(self, obj): + if isinstance(obj, UUID): + return str(obj) + return json.JSONEncoder.default(self, obj) + + uuid_dumps = partial(json.dumps, cls=UUIDEncoder) + obj = {"uuid": uuid4()} + cnn.execute("INSERT INTO objs VALUES %s", [Json(obj, dumps=uuid_dumps)]) + # will insert: {'uuid': '0a40799d-3980-4c65-8315-2956b18ab0e1'} + + +.. _adapt-list: + +Lists adaptation +---------------- + +Python `list` objects are adapted to `PostgreSQL arrays`__ and back. Only +lists containing objects of the same type can be dumped to PostgreSQL (but the +list may contain `!None` elements). + +.. __: https://www.postgresql.org/docs/current/arrays.html + +.. note:: + + If you have a list of values which you want to use with the :sql:`IN` + operator... don't. It won't work (neither with a list nor with a tuple):: + + >>> conn.execute("SELECT * FROM mytable WHERE id IN %s", [[10,20,30]]) + Traceback (most recent call last): + File "<stdin>", line 1, in <module> + psycopg.errors.SyntaxError: syntax error at or near "$1" + LINE 1: SELECT * FROM mytable WHERE id IN $1 + ^ + + What you want to do instead is to use the `'= ANY()' expression`__ and pass + the values as a list (not a tuple). + + >>> conn.execute("SELECT * FROM mytable WHERE id = ANY(%s)", [[10,20,30]]) + + This has also the advantage of working with an empty list, whereas ``IN + ()`` is not valid SQL. + + .. __: https://www.postgresql.org/docs/current/functions-comparisons.html + #id-1.5.8.30.16 + + +.. 
_adapt-uuid: + +UUID adaptation +--------------- + +Python `uuid.UUID` objects are adapted to PostgreSQL `UUID type`__ and back:: + + >>> conn.execute("select gen_random_uuid()").fetchone()[0] + UUID('97f0dd62-3bd2-459e-89b8-a5e36ea3c16c') + + >>> from uuid import uuid4 + >>> conn.execute("select gen_random_uuid() = %s", [uuid4()]).fetchone()[0] + False # long shot + +.. __: https://www.postgresql.org/docs/current/datatype-uuid.html + + +.. _adapt-network: + +Network data types adaptation +----------------------------- + +Objects from the `ipaddress` module are converted to PostgreSQL `network +address types`__: + +- `~ipaddress.IPv4Address`, `~ipaddress.IPv4Interface` objects are converted + to the PostgreSQL :sql:`inet` type. On the way back, :sql:`inet` values + indicating a single address are converted to `!IPv4Address`, otherwise they + are converted to `!IPv4Interface` + +- `~ipaddress.IPv4Network` objects are converted to the :sql:`cidr` type and + back. + +- `~ipaddress.IPv6Address`, `~ipaddress.IPv6Interface`, + `~ipaddress.IPv6Network` objects follow the same rules, with IPv6 + :sql:`inet` and :sql:`cidr` values. + +.. __: https://www.postgresql.org/docs/current/datatype-net-types.html#DATATYPE-CIDR + +.. code:: python + + >>> conn.execute("select '192.168.0.1'::inet, '192.168.0.1/24'::inet").fetchone() + (IPv4Address('192.168.0.1'), IPv4Interface('192.168.0.1/24')) + + >>> conn.execute("select '::ffff:1.2.3.0/120'::cidr").fetchone()[0] + IPv6Network('::ffff:102:300/120') + + +.. _adapt-enum: + +Enum adaptation +--------------- + +.. versionadded:: 3.1 + +Psycopg can adapt Python `~enum.Enum` subclasses into PostgreSQL enum types +(created with the |CREATE TYPE AS ENUM|_ command). + +.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)` +.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html + +In order to set up a bidirectional enum mapping, you should get information +about the PostgreSQL enum using the `~types.enum.EnumInfo` class and +register it using `~types.enum.register_enum()`. The behaviour of unregistered +and registered enums is different. + +- If the enum is not registered with `register_enum()`: + + - Pure `!Enum` classes are dumped as normal strings, using their member + names as value. The unknown oid is used, so PostgreSQL should be able to + use this string in most contexts (such as an enum or a text field). + + .. versionchanged:: 3.1 + In previous version dumping pure enums is not supported and raise a + "cannot adapt" error. + + - Mix-in enums are dumped according to their mix-in type (because a `class + MyIntEnum(int, Enum)` is more specifically an `!int` than an `!Enum`, so + it's dumped by default according to `!int` rules). + + - PostgreSQL enums are loaded as Python strings. If you want to load arrays + of such enums you will have to find their OIDs using `types.TypeInfo.fetch()` + and register them using `~types.TypeInfo.register()`. + +- If the enum is registered (using `~types.enum.EnumInfo`\ `!.fetch()` and + `~types.enum.register_enum()`): + + - Enums classes, both pure and mixed-in, are dumped by name. + + - The registered PostgreSQL enum is loaded back as the registered Python + enum members. + +.. autoclass:: psycopg.types.enum.EnumInfo + + `!EnumInfo` is a subclass of `~psycopg.types.TypeInfo`: refer to the + latter's documentation for generic usage, especially the + `~psycopg.types.TypeInfo.fetch()` method. + + .. 
attribute:: labels + + After `~psycopg.types.TypeInfo.fetch()`, it contains the labels defined + in the PostgreSQL enum type. + + .. attribute:: enum + + After `register_enum()` is called, it will contain the Python type + mapping to the registered enum. + +.. autofunction:: psycopg.types.enum.register_enum + + After registering, fetching data of the registered enum will cast + PostgreSQL enum labels into corresponding Python enum members. + + If no `!enum` is specified, a new `Enum` is created based on + PostgreSQL enum labels. + +Example:: + + >>> from enum import Enum, auto + >>> from psycopg.types.enum import EnumInfo, register_enum + + >>> class UserRole(Enum): + ... ADMIN = auto() + ... EDITOR = auto() + ... GUEST = auto() + + >>> conn.execute("CREATE TYPE user_role AS ENUM ('ADMIN', 'EDITOR', 'GUEST')") + + >>> info = EnumInfo.fetch(conn, "user_role") + >>> register_enum(info, conn, UserRole) + + >>> some_editor = info.enum.EDITOR + >>> some_editor + <UserRole.EDITOR: 2> + + >>> conn.execute( + ... "SELECT pg_typeof(%(editor)s), %(editor)s", + ... {"editor": some_editor} + ... ).fetchone() + ('user_role', <UserRole.EDITOR: 2>) + + >>> conn.execute( + ... "SELECT ARRAY[%s, %s]", + ... [UserRole.ADMIN, UserRole.GUEST] + ... ).fetchone() + [<UserRole.ADMIN: 1>, <UserRole.GUEST: 3>] + +If the Python and the PostgreSQL enum don't match 1:1 (for instance if members +have a different name, or if more than one Python enum should map to the same +PostgreSQL enum, or vice versa), you can specify the exceptions using the +`!mapping` parameter. + +`!mapping` should be a dictionary with Python enum members as keys and the +matching PostgreSQL enum labels as values, or a list of `(member, label)` +pairs with the same meaning (useful when some members are repeated). Order +matters: if an element on either side is specified more than once, the last +pair in the sequence will take precedence:: + + # Legacy roles, defined in medieval times. + >>> conn.execute( + ... "CREATE TYPE abbey_role AS ENUM ('ABBOT', 'SCRIBE', 'MONK', 'GUEST')") + + >>> info = EnumInfo.fetch(conn, "abbey_role") + >>> register_enum(info, conn, UserRole, mapping=[ + ... (UserRole.ADMIN, "ABBOT"), + ... (UserRole.EDITOR, "SCRIBE"), + ... (UserRole.EDITOR, "MONK")]) + + >>> conn.execute("SELECT '{ABBOT,SCRIBE,MONK,GUEST}'::abbey_role[]").fetchone()[0] + [<UserRole.ADMIN: 1>, + <UserRole.EDITOR: 2>, + <UserRole.EDITOR: 2>, + <UserRole.GUEST: 3>] + + >>> conn.execute("SELECT %s::text[]", [list(UserRole)]).fetchone()[0] + ['ABBOT', 'MONK', 'GUEST'] + +A particularly useful case is when the PostgreSQL labels match the *values* of +a `!str`\-based Enum. In this case it is possible to use something like ``{m: +m.value for m in enum}`` as mapping:: + + >>> class LowercaseRole(str, Enum): + ... ADMIN = "admin" + ... EDITOR = "editor" + ... GUEST = "guest" + + >>> conn.execute( + ... "CREATE TYPE lowercase_role AS ENUM ('admin', 'editor', 'guest')") + + >>> info = EnumInfo.fetch(conn, "lowercase_role") + >>> register_enum( + ... info, conn, LowercaseRole, mapping={m: m.value for m in LowercaseRole}) + + >>> conn.execute("SELECT 'editor'::lowercase_role").fetchone()[0] + <LowercaseRole.EDITOR: 'editor'> diff --git a/docs/basic/copy.rst b/docs/basic/copy.rst new file mode 100644 index 0000000..2bb4498 --- /dev/null +++ b/docs/basic/copy.rst @@ -0,0 +1,212 @@ +.. currentmodule:: psycopg + +.. index:: + pair: COPY; SQL command + +.. 
_copy: + +Using COPY TO and COPY FROM +=========================== + +Psycopg allows to operate with `PostgreSQL COPY protocol`__. :sql:`COPY` is +one of the most efficient ways to load data into the database (and to modify +it, with some SQL creativity). + +.. __: https://www.postgresql.org/docs/current/sql-copy.html + +Copy is supported using the `Cursor.copy()` method, passing it a query of the +form :sql:`COPY ... FROM STDIN` or :sql:`COPY ... TO STDOUT`, and managing the +resulting `Copy` object in a `!with` block: + +.. code:: python + + with cursor.copy("COPY table_name (col1, col2) FROM STDIN") as copy: + # pass data to the 'copy' object using write()/write_row() + +You can compose a COPY statement dynamically by using objects from the +`psycopg.sql` module: + +.. code:: python + + with cursor.copy( + sql.SQL("COPY {} TO STDOUT").format(sql.Identifier("table_name")) + ) as copy: + # read data from the 'copy' object using read()/read_row() + +.. versionchanged:: 3.1 + + You can also pass parameters to `!copy()`, like in `~Cursor.execute()`: + + .. code:: python + + with cur.copy("COPY (SELECT * FROM table_name LIMIT %s) TO STDOUT", (3,)) as copy: + # expect no more than three records + +The connection is subject to the usual transaction behaviour, so, unless the +connection is in autocommit, at the end of the COPY operation you will still +have to commit the pending changes and you can still roll them back. See +:ref:`transactions` for details. + + +.. _copy-in-row: + +Writing data row-by-row +----------------------- + +Using a copy operation you can load data into the database from any Python +iterable (a list of tuples, or any iterable of sequences): the Python values +are adapted as they would be in normal querying. To perform such operation use +a :sql:`COPY ... FROM STDIN` with `Cursor.copy()` and use `~Copy.write_row()` +on the resulting object in a `!with` block. On exiting the block the +operation will be concluded: + +.. code:: python + + records = [(10, 20, "hello"), (40, None, "world")] + + with cursor.copy("COPY sample (col1, col2, col3) FROM STDIN") as copy: + for record in records: + copy.write_row(record) + +If an exception is raised inside the block, the operation is interrupted and +the records inserted so far are discarded. + +In order to read or write from `!Copy` row-by-row you must not specify +:sql:`COPY` options such as :sql:`FORMAT CSV`, :sql:`DELIMITER`, :sql:`NULL`: +please leave these details alone, thank you :) + + +.. _copy-out-row: + +Reading data row-by-row +----------------------- + +You can also do the opposite, reading rows out of a :sql:`COPY ... TO STDOUT` +operation, by iterating on `~Copy.rows()`. However this is not something you +may want to do normally: usually the normal query process will be easier to +use. + +PostgreSQL, currently, doesn't give complete type information on :sql:`COPY +TO`, so the rows returned will have unparsed data, as strings or bytes, +according to the format. + +.. code:: python + + with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy: + for row in copy.rows(): + print(row) # return unparsed data: ('10', '2046-12-24') + +You can improve the results by using `~Copy.set_types()` before reading, but +you have to specify them yourself. + +.. code:: python + + with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy: + copy.set_types(["int4", "date"]) + for row in copy.rows(): + print(row) # (10, datetime.date(2046, 12, 24)) + + +.. 
_copy-block: + +Copying block-by-block +---------------------- + +If data is already formatted in a way suitable for copy (for instance because +it is coming from a file resulting from a previous `COPY TO` operation) it can +be loaded into the database using `Copy.write()` instead. + +.. code:: python + + with open("data", "r") as f: + with cursor.copy("COPY data FROM STDIN") as copy: + while data := f.read(BLOCK_SIZE): + copy.write(data) + +In this case you can use any :sql:`COPY` option and format, as long as the +input data is compatible with what the operation in `!copy()` expects. Data +can be passed as `!str`, if the copy is in :sql:`FORMAT TEXT`, or as `!bytes`, +which works with both :sql:`FORMAT TEXT` and :sql:`FORMAT BINARY`. + +In order to produce data in :sql:`COPY` format you can use a :sql:`COPY ... TO +STDOUT` statement and iterate over the resulting `Copy` object, which will +produce a stream of `!bytes` objects: + +.. code:: python + + with open("data.out", "wb") as f: + with cursor.copy("COPY table_name TO STDOUT") as copy: + for data in copy: + f.write(data) + + +.. _copy-binary: + +Binary copy +----------- + +Binary copy is supported by specifying :sql:`FORMAT BINARY` in the :sql:`COPY` +statement. In order to import binary data using `~Copy.write_row()`, all the +types passed to the database must have a binary dumper registered; this is not +necessary if the data is copied :ref:`block-by-block <copy-block>` using +`~Copy.write()`. + +.. warning:: + + PostgreSQL is particularly finicky when loading data in binary mode and + will apply **no cast rules**. This means, for example, that passing the + value 100 to an `integer` column **will fail**, because Psycopg will pass + it as a `smallint` value, and the server will reject it because its size + doesn't match what expected. + + You can work around the problem using the `~Copy.set_types()` method of + the `!Copy` object and specifying carefully the types to load. + +.. seealso:: See :ref:`binary-data` for further info about binary querying. + + +.. _copy-async: + +Asynchronous copy support +------------------------- + +Asynchronous operations are supported using the same patterns as above, using +the objects obtained by an `AsyncConnection`. For instance, if `!f` is an +object supporting an asynchronous `!read()` method returning :sql:`COPY` data, +a fully-async copy operation could be: + +.. code:: python + + async with cursor.copy("COPY data FROM STDIN") as copy: + while data := await f.read(): + await copy.write(data) + +The `AsyncCopy` object documentation describes the signature of the +asynchronous methods and the differences from its sync `Copy` counterpart. + +.. seealso:: See :ref:`async` for further info about using async objects. + + +Example: copying a table across servers +--------------------------------------- + +In order to copy a table, or a portion of a table, across servers, you can use +two COPY operations on two different connections, reading from the first and +writing to the second. + +.. code:: python + + with psycopg.connect(dsn_src) as conn1, psycopg.connect(dsn_tgt) as conn2: + with conn1.cursor().copy("COPY src TO STDOUT (FORMAT BINARY)") as copy1: + with conn2.cursor().copy("COPY tgt FROM STDIN (FORMAT BINARY)") as copy2: + for data in copy1: + copy2.write(data) + +Using :sql:`FORMAT BINARY` usually gives a performance boost, but it only +works if the source and target schema are *perfectly identical*. 
If the tables +are only *compatible* (for example, if you are copying an :sql:`integer` field +into a :sql:`bigint` destination field) you should omit the `BINARY` option and +perform a text-based copy. See :ref:`copy-binary` for details. + +The same pattern can be adapted to use :ref:`async objects <async>` in order +to perform an :ref:`async copy <copy-async>`. diff --git a/docs/basic/from_pg2.rst b/docs/basic/from_pg2.rst new file mode 100644 index 0000000..0692049 --- /dev/null +++ b/docs/basic/from_pg2.rst @@ -0,0 +1,359 @@ +.. index:: + pair: psycopg2; Differences + +.. currentmodule:: psycopg + +.. _from-psycopg2: + + +Differences from `!psycopg2` +============================ + +Psycopg 3 uses the common DBAPI structure of many other database adapters and +tries to behave as close as possible to `!psycopg2`. There are however a few +differences to be aware of. + +.. tip:: + Most of the times, the workarounds suggested here will work with both + Psycopg 2 and 3, which could be useful if you are porting a program or + writing a program that should work with both Psycopg 2 and 3. + + +.. _server-side-binding: + +Server-side binding +------------------- + +Psycopg 3 sends the query and the parameters to the server separately, instead +of merging them on the client side. Server-side binding works for normal +:sql:`SELECT` and data manipulation statements (:sql:`INSERT`, :sql:`UPDATE`, +:sql:`DELETE`), but it doesn't work with many other statements. For instance, +it doesn't work with :sql:`SET` or with :sql:`NOTIFY`:: + + >>> conn.execute("SET TimeZone TO %s", ["UTC"]) + Traceback (most recent call last): + ... + psycopg.errors.SyntaxError: syntax error at or near "$1" + LINE 1: SET TimeZone TO $1 + ^ + + >>> conn.execute("NOTIFY %s, %s", ["chan", 42]) + Traceback (most recent call last): + ... + psycopg.errors.SyntaxError: syntax error at or near "$1" + LINE 1: NOTIFY $1, $2 + ^ + +and with any data definition statement:: + + >>> conn.execute("CREATE TABLE foo (id int DEFAULT %s)", [42]) + Traceback (most recent call last): + ... + psycopg.errors.UndefinedParameter: there is no parameter $1 + LINE 1: CREATE TABLE foo (id int DEFAULT $1) + ^ + +Sometimes, PostgreSQL offers an alternative: for instance the `set_config()`__ +function can be used instead of the :sql:`SET` statement, the `pg_notify()`__ +function can be used instead of :sql:`NOTIFY`:: + + >>> conn.execute("SELECT set_config('TimeZone', %s, false)", ["UTC"]) + + >>> conn.execute("SELECT pg_notify(%s, %s)", ["chan", "42"]) + +.. __: https://www.postgresql.org/docs/current/functions-admin.html + #FUNCTIONS-ADMIN-SET + +.. __: https://www.postgresql.org/docs/current/sql-notify.html + #id-1.9.3.157.7.5 + +If this is not possible, you must merge the query and the parameter on the +client side. You can do so using the `psycopg.sql` objects:: + + >>> from psycopg import sql + + >>> cur.execute(sql.SQL("CREATE TABLE foo (id int DEFAULT {})").format(42)) + +or creating a :ref:`client-side binding cursor <client-side-binding-cursors>` +such as `ClientCursor`:: + + >>> cur = ClientCursor(conn) + >>> cur.execute("CREATE TABLE foo (id int DEFAULT %s)", [42]) + +If you need `!ClientCursor` often, you can set the `Connection.cursor_factory` +to have them created by default by `Connection.cursor()`. This way, Psycopg 3 +will behave largely the same way of Psycopg 2. + +Note that, both server-side and client-side, you can only specify **values** +as parameters (i.e. *the strings that go in single quotes*). 
If you need to
+parametrize different parts of a statement (such as a table name), you must
+use the `psycopg.sql` module::
+
+    >>> from psycopg import sql
+
+    # This will quote the user and the password using the right quotes
+    # e.g.: ALTER USER "foo" SET PASSWORD 'bar'
+    >>> conn.execute(
+    ...     sql.SQL("ALTER USER {} SET PASSWORD {}")
+    ...     .format(sql.Identifier(username), password))
+
+
+.. _multi-statements:
+
+Multiple statements in the same query
+-------------------------------------
+
+As a consequence of using :ref:`server-side bindings <server-side-binding>`,
+when parameters are used, it is not possible to execute several statements in
+the same `!execute()` call, separating them by semicolon::
+
+    >>> conn.execute(
+    ...     "INSERT INTO foo VALUES (%s); INSERT INTO foo VALUES (%s)",
+    ...     (10, 20))
+    Traceback (most recent call last):
+    ...
+    psycopg.errors.SyntaxError: cannot insert multiple commands into a prepared statement
+
+One obvious way to work around the problem is to use several `!execute()`
+calls.
+
+**There is no such limitation if no parameters are used**. As a consequence,
+you can compose multiple statements on the client side and run them all in
+the same `!execute()` call, using the `psycopg.sql` objects::
+
+    >>> from psycopg import sql
+    >>> conn.execute(
+    ...     sql.SQL("INSERT INTO foo VALUES ({}); INSERT INTO foo VALUES ({})")
+    ...     .format(10, 20))
+
+or a :ref:`client-side binding cursor <client-side-binding-cursors>`::
+
+    >>> cur = psycopg.ClientCursor(conn)
+    >>> cur.execute(
+    ...     "INSERT INTO foo VALUES (%s); INSERT INTO foo VALUES (%s)",
+    ...     (10, 20))
+
+.. warning::
+
+    If a statement must be executed outside a transaction (such as
+    :sql:`CREATE DATABASE`), it cannot be executed in batch with other
+    statements, even if the connection is in autocommit mode::
+
+        >>> conn.autocommit = True
+        >>> conn.execute("CREATE DATABASE foo; SELECT 1")
+        Traceback (most recent call last):
+        ...
+        psycopg.errors.ActiveSqlTransaction: CREATE DATABASE cannot run inside a transaction block
+
+    This happens because PostgreSQL itself will wrap multiple statements in a
+    transaction. Note that you will experience a different behaviour in
+    :program:`psql` (:program:`psql` will split the queries on semicolons and
+    send them to the server separately).
+
+    This is not new in Psycopg 3: the same limitation is present in
+    `!psycopg2` too.
+
+
+.. _multi-results:
+
+Multiple results returned from multiple statements
+--------------------------------------------------
+
+If more than one statement returning results is executed in psycopg2, only the
+result of the last statement is returned::
+
+    >>> cur_pg2.execute("SELECT 1; SELECT 2")
+    >>> cur_pg2.fetchone()
+    (2,)
+
+In Psycopg 3 instead, all the results are available. After running the query,
+the first result will be readily available in the cursor and can be consumed
+using the usual `!fetch*()` methods. In order to access the following
+results, you can use the `Cursor.nextset()` method::
+
+    >>> cur_pg3.execute("SELECT 1; SELECT 2")
+    >>> cur_pg3.fetchone()
+    (1,)
+
+    >>> cur_pg3.nextset()
+    True
+    >>> cur_pg3.fetchone()
+    (2,)
+
+    >>> cur_pg3.nextset()
+    None  # no more results
+
+Remember though that you cannot use server-side bindings to :ref:`execute more
+than one statement in the same query <multi-statements>`.
+
+
+.. 
_difference-cast-rules: + +Different cast rules +-------------------- + +In rare cases, especially around variadic functions, PostgreSQL might fail to +find a function candidate for the given data types:: + + >>> conn.execute("SELECT json_build_array(%s, %s)", ["foo", "bar"]) + Traceback (most recent call last): + ... + psycopg.errors.IndeterminateDatatype: could not determine data type of parameter $1 + +This can be worked around specifying the argument types explicitly via a cast:: + + >>> conn.execute("SELECT json_build_array(%s::text, %s::text)", ["foo", "bar"]) + + +.. _in-and-tuple: + +You cannot use ``IN %s`` with a tuple +------------------------------------- + +``IN`` cannot be used with a tuple as single parameter, as was possible with +``psycopg2``:: + + >>> conn.execute("SELECT * FROM foo WHERE id IN %s", [(10,20,30)]) + Traceback (most recent call last): + ... + psycopg.errors.SyntaxError: syntax error at or near "$1" + LINE 1: SELECT * FROM foo WHERE id IN $1 + ^ + +What you can do is to use the `= ANY()`__ construct and pass the candidate +values as a list instead of a tuple, which will be adapted to a PostgreSQL +array:: + + >>> conn.execute("SELECT * FROM foo WHERE id = ANY(%s)", [[10,20,30]]) + +Note that `ANY()` can be used with `!psycopg2` too, and has the advantage of +accepting an empty list of values too as argument, which is not supported by +the :sql:`IN` operator instead. + +.. __: https://www.postgresql.org/docs/current/functions-comparisons.html + #id-1.5.8.30.16 + + +.. _diff-adapt: + +Different adaptation system +--------------------------- + +The adaptation system has been completely rewritten, in order to address +server-side parameters adaptation, but also to consider performance, +flexibility, ease of customization. + +The default behaviour with builtin data should be :ref:`what you would expect +<types-adaptation>`. If you have customised the way to adapt data, or if you +are managing your own extension types, you should look at the :ref:`new +adaptation system <adaptation>`. + +.. seealso:: + + - :ref:`types-adaptation` for the basic behaviour. + - :ref:`adaptation` for more advanced use. + + +.. _diff-copy: + +Copy is no longer file-based +---------------------------- + +`!psycopg2` exposes :ref:`a few copy methods <pg2:copy>` to interact with +PostgreSQL :sql:`COPY`. Their file-based interface doesn't make it easy to load +dynamically-generated data into a database. + +There is now a single `~Cursor.copy()` method, which is similar to +`!psycopg2` `!copy_expert()` in accepting a free-form :sql:`COPY` command and +returns an object to read/write data, block-wise or record-wise. The different +usage pattern also enables :sql:`COPY` to be used in async interactions. + +.. seealso:: See :ref:`copy` for the details. + + +.. _diff-with: + +`!with` connection +------------------ + +In `!psycopg2`, using the syntax :ref:`with connection <pg2:with>`, +only the transaction is closed, not the connection. This behaviour is +surprising for people used to several other Python classes wrapping resources, +such as files. + +In Psycopg 3, using :ref:`with connection <with-connection>` will close the +connection at the end of the `!with` block, making handling the connection +resources more familiar. + +In order to manage transactions as blocks you can use the +`Connection.transaction()` method, which allows for finer control, for +instance to use nested transactions. + +.. seealso:: See :ref:`transaction-context` for details. + + +.. 
_diff-callproc: + +`!callproc()` is gone +--------------------- + +`cursor.callproc()` is not implemented. The method has a simplistic semantic +which doesn't account for PostgreSQL positional parameters, procedures, +set-returning functions... Use a normal `~Cursor.execute()` with :sql:`SELECT +function_name(...)` or :sql:`CALL procedure_name(...)` instead. + + +.. _diff-client-encoding: + +`!client_encoding` is gone +-------------------------- + +Psycopg automatically uses the database client encoding to decode data to +Unicode strings. Use `ConnectionInfo.encoding` if you need to read the +encoding. You can select an encoding at connection time using the +`!client_encoding` connection parameter and you can change the encoding of a +connection by running a :sql:`SET client_encoding` statement... But why would +you? + + +.. _infinity-datetime: + +No default infinity dates handling +---------------------------------- + +PostgreSQL can represent a much wider range of dates and timestamps than +Python. While Python dates are limited to the years between 1 and 9999 +(represented by constants such as `datetime.date.min` and +`~datetime.date.max`), PostgreSQL dates extend to BC dates and past the year +10K. Furthermore PostgreSQL can also represent symbolic dates "infinity", in +both directions. + +In psycopg2, by default, `infinity dates and timestamps map to 'date.max'`__ +and similar constants. This has the problem of creating a non-bijective +mapping (two Postgres dates, infinity and 9999-12-31, both map to the same +Python date). There is also the perversity that valid Postgres dates, greater +than Python `!date.max` but arguably lesser than infinity, will still +overflow. + +In Psycopg 3, every date greater than year 9999 will overflow, including +infinity. If you would like to customize this mapping (for instance flattening +every date past Y10K on `!date.max`) you can subclass and adapt the +appropriate loaders: take a look at :ref:`this example +<adapt-example-inf-date>` to see how. + +.. __: https://www.psycopg.org/docs/usage.html#infinite-dates-handling + + +.. _whats-new: + +What's new in Psycopg 3 +----------------------- + +- :ref:`Asynchronous support <async>` +- :ref:`Server-side parameters binding <server-side-binding>` +- :ref:`Prepared statements <prepared-statements>` +- :ref:`Binary communication <binary-data>` +- :ref:`Python-based COPY support <copy>` +- :ref:`Support for static typing <static-typing>` +- :ref:`A redesigned connection pool <connection-pools>` +- :ref:`Direct access to the libpq functionalities <psycopg.pq>` diff --git a/docs/basic/index.rst b/docs/basic/index.rst new file mode 100644 index 0000000..bf9e27d --- /dev/null +++ b/docs/basic/index.rst @@ -0,0 +1,26 @@ +.. _basic: + +Getting started with Psycopg 3 +============================== + +This section of the documentation will explain :ref:`how to install Psycopg +<installation>` and how to perform normal activities such as :ref:`querying +the database <usage>` or :ref:`loading data using COPY <copy>`. + +.. important:: + + If you are familiar with psycopg2 please take a look at + :ref:`from-psycopg2` to see what is changed. + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + install + usage + params + adapt + pgtypes + transactions + copy + from_pg2 diff --git a/docs/basic/install.rst b/docs/basic/install.rst new file mode 100644 index 0000000..8e1dc6d --- /dev/null +++ b/docs/basic/install.rst @@ -0,0 +1,172 @@ +.. 
_installation: + +Installation +============ + +In short, if you use a :ref:`supported system<supported-systems>`:: + + pip install --upgrade pip # upgrade pip to at least 20.3 + pip install "psycopg[binary]" + +and you should be :ref:`ready to start <module-usage>`. Read further for +alternative ways to install. + + +.. _supported-systems: + +Supported systems +----------------- + +The Psycopg version documented here has *official and tested* support for: + +- Python: from version 3.7 to 3.11 + + - Python 3.6 supported before Psycopg 3.1 + +- PostgreSQL: from version 10 to 15 +- OS: Linux, macOS, Windows + +The tests to verify the supported systems run in `Github workflows`__: +anything that is not tested there is not officially supported. This includes: + +.. __: https://github.com/psycopg/psycopg/actions + +- Unofficial Python distributions such as Conda; +- Alternative PostgreSQL implementation; +- macOS hardware and releases not available on Github workflows. + +If you use an unsupported system, things might work (because, for instance, the +database may use the same wire protocol as PostgreSQL) but we cannot guarantee +the correct working or a smooth ride. + + +.. _binary-install: + +Binary installation +------------------- + +The quickest way to start developing with Psycopg 3 is to install the binary +packages by running:: + + pip install "psycopg[binary]" + +This will install a self-contained package with all the libraries needed. +**You will need pip 20.3 at least**: please run ``pip install --upgrade pip`` +to update it beforehand. + +The above package should work in most situations. It **will not work** in +some cases though. + +If your platform is not supported you should proceed to a :ref:`local +installation <local-installation>` or a :ref:`pure Python installation +<pure-python-installation>`. + +.. seealso:: + + Did Psycopg 3 install ok? Great! You can now move on to the :ref:`basic + module usage <module-usage>` to learn how it works. + + Keep on reading if the above method didn't work and you need a different + way to install Psycopg 3. + + For further information about the differences between the packages see + :ref:`pq-impl`. + + +.. _local-installation: + +Local installation +------------------ + +A "Local installation" results in a performing and maintainable library. The +library will include the speed-up C module and will be linked to the system +libraries (``libpq``, ``libssl``...) so that system upgrade of libraries will +upgrade the libraries used by Psycopg 3 too. This is the preferred way to +install Psycopg for a production site. + +In order to perform a local installation you need some prerequisites: + +- a C compiler, +- Python development headers (e.g. the ``python3-dev`` package). +- PostgreSQL client development headers (e.g. the ``libpq-dev`` package). +- The :program:`pg_config` program available in the :envvar:`PATH`. + +You **must be able** to troubleshoot an extension build, for instance you must +be able to read your compiler's error message. If you are not, please don't +try this and follow the `binary installation`_ instead. + +If your build prerequisites are in place you can run:: + + pip install "psycopg[c]" + + +.. _pure-python-installation: + +Pure Python installation +------------------------ + +If you simply install:: + + pip install psycopg + +without ``[c]`` or ``[binary]`` extras you will obtain a pure Python +implementation. 
This is particularly handy to debug and hack, but it still +requires the system libpq to operate (which will be imported dynamically via +`ctypes`). + +In order to use the pure Python installation you will need the ``libpq`` +installed in the system: for instance on Debian system you will probably +need:: + + sudo apt install libpq5 + +.. note:: + + The ``libpq`` is the client library used by :program:`psql`, the + PostgreSQL command line client, to connect to the database. On most + systems, installing :program:`psql` will install the ``libpq`` too as a + dependency. + +If you are not able to fulfill this requirement please follow the `binary +installation`_. + + +.. _pool-installation: + +Installing the connection pool +------------------------------ + +The :ref:`Psycopg connection pools <connection-pools>` are distributed in a +separate package from the `!psycopg` package itself, in order to allow a +different release cycle. + +In order to use the pool you must install the ``pool`` extra, using ``pip +install "psycopg[pool]"``, or install the `psycopg_pool` package separately, +which would allow to specify the release to install more precisely. + + +Handling dependencies +--------------------- + +If you need to specify your project dependencies (for instance in a +``requirements.txt`` file, ``setup.py``, ``pyproject.toml`` dependencies...) +you should probably specify one of the following: + +- If your project is a library, add a dependency on ``psycopg``. This will + make sure that your library will have the ``psycopg`` package with the right + interface and leaves the possibility of choosing a specific implementation + to the end user of your library. + +- If your project is a final application (e.g. a service running on a server) + you can require a specific implementation, for instance ``psycopg[c]``, + after you have made sure that the prerequisites are met (e.g. the depending + libraries and tools are installed in the host machine). + +In both cases you can specify which version of Psycopg to use using +`requirement specifiers`__. + +.. __: https://pip.pypa.io/en/stable/cli/pip_install/#requirement-specifiers + +If you want to make sure that a specific implementation is used you can +specify the :envvar:`PSYCOPG_IMPL` environment variable: importing the library +will fail if the implementation specified is not available. See :ref:`pq-impl`. diff --git a/docs/basic/params.rst b/docs/basic/params.rst new file mode 100644 index 0000000..a733f07 --- /dev/null +++ b/docs/basic/params.rst @@ -0,0 +1,242 @@ +.. currentmodule:: psycopg + +.. index:: + pair: Query; Parameters + +.. _query-parameters: + +Passing parameters to SQL queries +================================= + +Most of the times, writing a program you will have to mix bits of SQL +statements with values provided by the rest of the program: + +.. code:: + + SELECT some, fields FROM some_table WHERE id = ... + +:sql:`id` equals what? Probably you will have a Python value you are looking +for. + + +`!execute()` arguments +---------------------- + +Passing parameters to a SQL statement happens in functions such as +`Cursor.execute()` by using ``%s`` placeholders in the SQL statement, and +passing a sequence of values as the second argument of the function. For +example the Python function call: + +.. code:: python + + cur.execute(""" + INSERT INTO some_table (id, created_at, last_name) + VALUES (%s, %s, %s); + """, + (10, datetime.date(2020, 11, 18), "O'Reilly")) + +is *roughly* equivalent to the SQL command: + +.. 
code-block:: sql + + INSERT INTO some_table (id, created_at, last_name) + VALUES (10, '2020-11-18', 'O''Reilly'); + +Note that the parameters will not be really merged to the query: query and the +parameters are sent to the server separately: see :ref:`server-side-binding` +for details. + +Named arguments are supported too using :samp:`%({name})s` placeholders in the +query and specifying the values into a mapping. Using named arguments allows +to specify the values in any order and to repeat the same value in several +places in the query:: + + cur.execute(""" + INSERT INTO some_table (id, created_at, updated_at, last_name) + VALUES (%(id)s, %(created)s, %(created)s, %(name)s); + """, + {'id': 10, 'name': "O'Reilly", 'created': datetime.date(2020, 11, 18)}) + +Using characters ``%``, ``(``, ``)`` in the argument names is not supported. + +When parameters are used, in order to include a literal ``%`` in the query you +can use the ``%%`` string:: + + cur.execute("SELECT (%s % 2) = 0 AS even", (10,)) # WRONG + cur.execute("SELECT (%s %% 2) = 0 AS even", (10,)) # correct + +While the mechanism resembles regular Python strings manipulation, there are a +few subtle differences you should care about when passing parameters to a +query. + +- The Python string operator ``%`` *must not be used*: the `~cursor.execute()` + method accepts a tuple or dictionary of values as second parameter. + |sql-warn|__: + + .. |sql-warn| replace:: **Never** use ``%`` or ``+`` to merge values + into queries + + .. code:: python + + cur.execute("INSERT INTO numbers VALUES (%s, %s)" % (10, 20)) # WRONG + cur.execute("INSERT INTO numbers VALUES (%s, %s)", (10, 20)) # correct + + .. __: sql-injection_ + +- For positional variables binding, *the second argument must always be a + sequence*, even if it contains a single variable (remember that Python + requires a comma to create a single element tuple):: + + cur.execute("INSERT INTO foo VALUES (%s)", "bar") # WRONG + cur.execute("INSERT INTO foo VALUES (%s)", ("bar")) # WRONG + cur.execute("INSERT INTO foo VALUES (%s)", ("bar",)) # correct + cur.execute("INSERT INTO foo VALUES (%s)", ["bar"]) # correct + +- The placeholder *must not be quoted*:: + + cur.execute("INSERT INTO numbers VALUES ('%s')", ("Hello",)) # WRONG + cur.execute("INSERT INTO numbers VALUES (%s)", ("Hello",)) # correct + +- The variables placeholder *must always be a* ``%s``, even if a different + placeholder (such as a ``%d`` for integers or ``%f`` for floats) may look + more appropriate for the type. You may find other placeholders used in + Psycopg queries (``%b`` and ``%t``) but they are not related to the + type of the argument: see :ref:`binary-data` if you want to read more:: + + cur.execute("INSERT INTO numbers VALUES (%d)", (10,)) # WRONG + cur.execute("INSERT INTO numbers VALUES (%s)", (10,)) # correct + +- Only query values should be bound via this method: it shouldn't be used to + merge table or field names to the query. If you need to generate SQL queries + dynamically (for instance choosing a table name at runtime) you can use the + functionalities provided in the `psycopg.sql` module:: + + cur.execute("INSERT INTO %s VALUES (%s)", ('numbers', 10)) # WRONG + cur.execute( # correct + SQL("INSERT INTO {} VALUES (%s)").format(Identifier('numbers')), + (10,)) + + +.. index:: Security, SQL injection + +.. _sql-injection: + +Danger: SQL injection +--------------------- + +The SQL representation of many data types is often different from their Python +string representation. 
The typical example is with single quotes in strings: +in SQL single quotes are used as string literal delimiters, so the ones +appearing inside the string itself must be escaped, whereas in Python single +quotes can be left unescaped if the string is delimited by double quotes. + +Because of the difference, sometimes subtle, between the data types +representations, a naïve approach to query strings composition, such as using +Python strings concatenation, is a recipe for *terrible* problems:: + + SQL = "INSERT INTO authors (name) VALUES ('%s')" # NEVER DO THIS + data = ("O'Reilly", ) + cur.execute(SQL % data) # THIS WILL FAIL MISERABLY + # SyntaxError: syntax error at or near "Reilly" + +If the variables containing the data to send to the database come from an +untrusted source (such as data coming from a form on a web site) an attacker +could easily craft a malformed string, either gaining access to unauthorized +data or performing destructive operations on the database. This form of attack +is called `SQL injection`_ and is known to be one of the most widespread forms +of attack on database systems. Before continuing, please print `this page`__ +as a memo and hang it onto your desk. + +.. _SQL injection: https://en.wikipedia.org/wiki/SQL_injection +.. __: https://xkcd.com/327/ + +Psycopg can :ref:`automatically convert Python objects to SQL +values<types-adaptation>`: using this feature your code will be more robust +and reliable. We must stress this point: + +.. warning:: + + - Don't manually merge values to a query: hackers from a foreign country + will break into your computer and steal not only your disks, but also + your cds, leaving you only with the three most embarrassing records you + ever bought. On cassette tapes. + + - If you use the ``%`` operator to merge values to a query, con artists + will seduce your cat, who will run away taking your credit card + and your sunglasses with them. + + - If you use ``+`` to merge a textual value to a string, bad guys in + balaclava will find their way to your fridge, drink all your beer, and + leave your toilet seat up and your toilet paper in the wrong orientation. + + - You don't want to manually merge values to a query: :ref:`use the + provided methods <query-parameters>` instead. + +The correct way to pass variables in a SQL command is using the second +argument of the `Cursor.execute()` method:: + + SQL = "INSERT INTO authors (name) VALUES (%s)" # Note: no quotes + data = ("O'Reilly", ) + cur.execute(SQL, data) # Note: no % operator + +.. note:: + + Python static code checkers are not quite there yet, but, in the future, + it will be possible to check your code for improper use of string + expressions in queries. See :ref:`literal-string` for details. + +.. seealso:: + + Now that you know how to pass parameters to queries, you can take a look + at :ref:`how Psycopg converts data types <types-adaptation>`. + + +.. index:: + pair: Binary; Parameters + +.. _binary-data: + +Binary parameters and results +----------------------------- + +PostgreSQL has two different ways to transmit data between client and server: +`~psycopg.pq.Format.TEXT`, always available, and `~psycopg.pq.Format.BINARY`, +available most of the times but not always. Usually the binary format is more +efficient to use. + +Psycopg can support both formats for each data type. Whenever a value +is passed to a query using the normal ``%s`` placeholder, the best format +available is chosen (often, but not always, the binary format is picked as the +best choice). 
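+
+For example, with the generic placeholder the choice is made for you,
+value by value (a minimal sketch: the ``events`` table used here is
+hypothetical):
+
+.. code:: python
+
+    # Psycopg picks a suitable format for each parameter: the str will
+    # likely travel as text, the bytes payload likely in binary format.
+    cur.execute(
+        "INSERT INTO events (name, payload) VALUES (%s, %s)",
+        ("startup", b"\x00\x01\x02"))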
+ +If you have a reason to select explicitly the binary format or the text format +for a value you can use respectively a ``%b`` placeholder or a ``%t`` +placeholder instead of the normal ``%s``. `~Cursor.execute()` will fail if a +`~psycopg.adapt.Dumper` for the right data type and format is not available. + +The same two formats, text or binary, are used by PostgreSQL to return data +from a query to the client. Unlike with parameters, where you can choose the +format value-by-value, all the columns returned by a query will have the same +format. Every type returned by the query should have a `~psycopg.adapt.Loader` +configured, otherwise the data will be returned as unparsed `!str` (for text +results) or buffer (for binary results). + +.. note:: + The `pg_type`_ table defines which format is supported for each PostgreSQL + data type. Text input/output is managed by the functions declared in the + ``typinput`` and ``typoutput`` fields (always present), binary + input/output is managed by the ``typsend`` and ``typreceive`` (which are + optional). + + .. _pg_type: https://www.postgresql.org/docs/current/catalog-pg-type.html + +Because not every PostgreSQL type supports binary output, by default, the data +will be returned in text format. In order to return data in binary format you +can create the cursor using `Connection.cursor`\ `!(binary=True)` or execute +the query using `Cursor.execute`\ `!(binary=True)`. A case in which +requesting binary results is a clear winner is when you have large binary data +in the database, such as images:: + + cur.execute( + "SELECT image_data FROM images WHERE id = %s", [image_id], binary=True) + data = cur.fetchone()[0] diff --git a/docs/basic/pgtypes.rst b/docs/basic/pgtypes.rst new file mode 100644 index 0000000..14ee5be --- /dev/null +++ b/docs/basic/pgtypes.rst @@ -0,0 +1,389 @@ +.. currentmodule:: psycopg + +.. index:: + single: Adaptation + pair: Objects; Adaptation + single: Data types; Adaptation + +.. _extra-adaptation: + +Adapting other PostgreSQL types +=============================== + +PostgreSQL offers other data types which don't map to native Python types. +Psycopg offers wrappers and conversion functions to allow their use. + + +.. index:: + pair: Composite types; Data types + pair: tuple; Adaptation + pair: namedtuple; Adaptation + +.. _adapt-composite: + +Composite types casting +----------------------- + +Psycopg can adapt PostgreSQL composite types (either created with the |CREATE +TYPE|_ command or implicitly defined after a table row type) to and from +Python tuples, `~collections.namedtuple`, or any other suitable object +configured. + +.. |CREATE TYPE| replace:: :sql:`CREATE TYPE` +.. _CREATE TYPE: https://www.postgresql.org/docs/current/static/sql-createtype.html + +Before using a composite type it is necessary to get information about it +using the `~psycopg.types.composite.CompositeInfo` class and to register it +using `~psycopg.types.composite.register_composite()`. + +.. autoclass:: psycopg.types.composite.CompositeInfo + + `!CompositeInfo` is a `~psycopg.types.TypeInfo` subclass: check its + documentation for the generic usage, especially the + `~psycopg.types.TypeInfo.fetch()` method. + + .. attribute:: python_type + + After `register_composite()` is called, it will contain the python type + mapping to the registered composite. + +.. autofunction:: psycopg.types.composite.register_composite + + After registering, fetching data of the registered composite will invoke + `!factory` to create corresponding Python objects. 
+ + If no factory is specified, a `~collection.namedtuple` is created and used + to return data. + + If the `!factory` is a type (and not a generic callable), then dumpers for + that type are created and registered too, so that passing objects of that + type to a query will adapt them to the registered type. + +Example:: + + >>> from psycopg.types.composite import CompositeInfo, register_composite + + >>> conn.execute("CREATE TYPE card AS (value int, suit text)") + + >>> info = CompositeInfo.fetch(conn, "card") + >>> register_composite(info, conn) + + >>> my_card = info.python_type(8, "hearts") + >>> my_card + card(value=8, suit='hearts') + + >>> conn.execute( + ... "SELECT pg_typeof(%(card)s), (%(card)s).suit", {"card": my_card} + ... ).fetchone() + ('card', 'hearts') + + >>> conn.execute("SELECT (%s, %s)::card", [1, "spades"]).fetchone()[0] + card(value=1, suit='spades') + + +Nested composite types are handled as expected, provided that the type of the +composite components are registered as well:: + + >>> conn.execute("CREATE TYPE card_back AS (face card, back text)") + + >>> info2 = CompositeInfo.fetch(conn, "card_back") + >>> register_composite(info2, conn) + + >>> conn.execute("SELECT ((8, 'hearts'), 'blue')::card_back").fetchone()[0] + card_back(face=card(value=8, suit='hearts'), back='blue') + + +.. index:: + pair: range; Data types + +.. _adapt-range: + +Range adaptation +---------------- + +PostgreSQL `range types`__ are a family of data types representing a range of +values between two elements. The type of the element is called the range +*subtype*. PostgreSQL offers a few built-in range types and allows the +definition of custom ones. + +.. __: https://www.postgresql.org/docs/current/rangetypes.html + +All the PostgreSQL range types are loaded as the `~psycopg.types.range.Range` +Python type, which is a `~typing.Generic` type and can hold bounds of +different types. + +.. autoclass:: psycopg.types.range.Range + + This Python type is only used to pass and retrieve range values to and + from PostgreSQL and doesn't attempt to replicate the PostgreSQL range + features: it doesn't perform normalization and doesn't implement all the + operators__ supported by the database. + + PostgreSQL will perform normalisation on `!Range` objects used as query + parameters, so, when they are fetched back, they will be found in the + normal form (for instance ranges on integers will have `[)` bounds). + + .. __: https://www.postgresql.org/docs/current/static/functions-range.html#RANGE-OPERATORS-TABLE + + `!Range` objects are immutable, hashable, and support the `!in` operator + (checking if an element is within the range). They can be tested for + equivalence. Empty ranges evaluate to `!False` in a boolean context, + nonempty ones evaluate to `!True`. + + `!Range` objects have the following attributes: + + .. autoattribute:: isempty + .. autoattribute:: lower + .. autoattribute:: upper + .. autoattribute:: lower_inc + .. autoattribute:: upper_inc + .. autoattribute:: lower_inf + .. autoattribute:: upper_inf + +The built-in range objects are adapted automatically: if a `!Range` objects +contains `~datetime.date` bounds, it is dumped using the :sql:`daterange` OID, +and of course :sql:`daterange` values are loaded back as `!Range[date]`. + +If you create your own range type you can use `~psycopg.types.range.RangeInfo` +and `~psycopg.types.range.register_range()` to associate the range type with +its subtype and make it work like the builtin ones. + +.. 
autoclass:: psycopg.types.range.RangeInfo + + `!RangeInfo` is a `~psycopg.types.TypeInfo` subclass: check its + documentation for generic details, especially the + `~psycopg.types.TypeInfo.fetch()` method. + +.. autofunction:: psycopg.types.range.register_range + +Example:: + + >>> from psycopg.types.range import Range, RangeInfo, register_range + + >>> conn.execute("CREATE TYPE strrange AS RANGE (SUBTYPE = text)") + >>> info = RangeInfo.fetch(conn, "strrange") + >>> register_range(info, conn) + + >>> conn.execute("SELECT pg_typeof(%s)", [Range("a", "z")]).fetchone()[0] + 'strrange' + + >>> conn.execute("SELECT '[a,z]'::strrange").fetchone()[0] + Range('a', 'z', '[]') + + +.. index:: + pair: range; Data types + +.. _adapt-multirange: + +Multirange adaptation +--------------------- + +Since PostgreSQL 14, every range type is associated with a multirange__, a +type representing a disjoint set of ranges. A multirange is +automatically available for every range, built-in and user-defined. + +.. __: https://www.postgresql.org/docs/current/rangetypes.html + +All the PostgreSQL range types are loaded as the +`~psycopg.types.multirange.Multirange` Python type, which is a mutable +sequence of `~psycopg.types.range.Range` elements. + +.. autoclass:: psycopg.types.multirange.Multirange + + This Python type is only used to pass and retrieve multirange values to + and from PostgreSQL and doesn't attempt to replicate the PostgreSQL + multirange features: overlapping items are not merged, empty ranges are + not discarded, the items are not ordered, the behaviour of `multirange + operators`__ is not replicated in Python. + + PostgreSQL will perform normalisation on `!Multirange` objects used as + query parameters, so, when they are fetched back, they will be found + ordered, with overlapping ranges merged, etc. + + .. __: https://www.postgresql.org/docs/current/static/functions-range.html#MULTIRANGE-OPERATORS-TABLE + + `!Multirange` objects are a `~collections.abc.MutableSequence` and are + totally ordered: they behave pretty much like a list of `!Range`. Like + Range, they are `~typing.Generic` on the subtype of their range, so you + can declare a variable to be `!Multirange[date]` and mypy will complain if + you try to add it a `Range[Decimal]`. + +Like for `~psycopg.types.range.Range`, built-in multirange objects are adapted +automatically: if a `!Multirange` object contains `!Range` with +`~datetime.date` bounds, it is dumped using the :sql:`datemultirange` OID, and +:sql:`datemultirange` values are loaded back as `!Multirange[date]`. + +If you have created your own range type you can use +`~psycopg.types.multirange.MultirangeInfo` and +`~psycopg.types.multirange.register_multirange()` to associate the resulting +multirange type with its subtype and make it work like the builtin ones. + +.. autoclass:: psycopg.types.multirange.MultirangeInfo + + `!MultirangeInfo` is a `~psycopg.types.TypeInfo` subclass: check its + documentation for generic details, especially the + `~psycopg.types.TypeInfo.fetch()` method. + +.. autofunction:: psycopg.types.multirange.register_multirange + +Example:: + + >>> from psycopg.types.multirange import \ + ... Multirange, MultirangeInfo, register_multirange + >>> from psycopg.types.range import Range + + >>> conn.execute("CREATE TYPE strrange AS RANGE (SUBTYPE = text)") + >>> info = MultirangeInfo.fetch(conn, "strmultirange") + >>> register_multirange(info, conn) + + >>> rec = conn.execute( + ... "SELECT pg_typeof(%(mr)s), %(mr)s", + ... 
{"mr": Multirange([Range("a", "q"), Range("l", "z")])}).fetchone() + + >>> rec[0] + 'strmultirange' + >>> rec[1] + Multirange([Range('a', 'z', '[)')]) + + +.. index:: + pair: hstore; Data types + pair: dict; Adaptation + +.. _adapt-hstore: + +Hstore adaptation +----------------- + +The |hstore|_ data type is a key-value store embedded in PostgreSQL. It +supports GiST or GIN indexes allowing search by keys or key/value pairs as +well as regular BTree indexes for equality, uniqueness etc. + +.. |hstore| replace:: :sql:`hstore` +.. _hstore: https://www.postgresql.org/docs/current/static/hstore.html + +Psycopg can convert Python `!dict` objects to and from |hstore| structures. +Only dictionaries with string keys and values are supported. `!None` is also +allowed as value but not as a key. + +In order to use the |hstore| data type it is necessary to load it in a +database using: + +.. code:: none + + =# CREATE EXTENSION hstore; + +Because |hstore| is distributed as a contrib module, its oid is not well +known, so it is necessary to use `!TypeInfo`\.\ +`~psycopg.types.TypeInfo.fetch()` to query the database and get its oid. The +resulting object can be passed to +`~psycopg.types.hstore.register_hstore()` to configure dumping `!dict` to +|hstore| and parsing |hstore| back to `!dict`, in the context where the +adapter is registered. + +.. autofunction:: psycopg.types.hstore.register_hstore + +Example:: + + >>> from psycopg.types import TypeInfo + >>> from psycopg.types.hstore import register_hstore + + >>> info = TypeInfo.fetch(conn, "hstore") + >>> register_hstore(info, conn) + + >>> conn.execute("SELECT pg_typeof(%s)", [{"a": "b"}]).fetchone()[0] + 'hstore' + + >>> conn.execute("SELECT 'foo => bar'::hstore").fetchone()[0] + {'foo': 'bar'} + + +.. index:: + pair: geometry; Data types + single: PostGIS; Data types + +.. _adapt-shapely: + +Geometry adaptation using Shapely +--------------------------------- + +When using the PostGIS_ extension, it can be useful to retrieve geometry_ +values and have them automatically converted to Shapely_ instances. Likewise, +you may want to store such instances in the database and have the conversion +happen automatically. + +.. warning:: + Psycopg doesn't have a dependency on the ``shapely`` package: you should + install the library as an additional dependency of your project. + +.. warning:: + This module is experimental and might be changed in the future according + to users' feedback. + +.. _PostGIS: https://postgis.net/ +.. _geometry: https://postgis.net/docs/geometry.html +.. _Shapely: https://github.com/Toblerity/Shapely +.. _shape: https://shapely.readthedocs.io/en/stable/manual.html#shapely.geometry.shape + +Since PostgGIS is an extension, the :sql:`geometry` type oid is not well +known, so it is necessary to use `!TypeInfo`\.\ +`~psycopg.types.TypeInfo.fetch()` to query the database and find it. The +resulting object can be passed to `~psycopg.types.shapely.register_shapely()` +to configure dumping `shape`_ instances to :sql:`geometry` columns and parsing +:sql:`geometry` data back to `!shape` instances, in the context where the +adapters are registered. + +.. function:: psycopg.types.shapely.register_shapely + + Register Shapely dumper and loaders. + + After invoking this function on an adapter, the queries retrieving + PostGIS geometry objects will return Shapely's shape object instances + both in text and binary mode. + + Similarly, shape objects can be sent to the database. + + This requires the Shapely library to be installed. 
+ + :param info: The object with the information about the geometry type. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + +Example:: + + >>> from psycopg.types import TypeInfo + >>> from psycopg.types.shapely import register_shapely + >>> from shapely.geometry import Point + + >>> info = TypeInfo.fetch(conn, "geometry") + >>> register_shapely(info, conn) + + >>> conn.execute("SELECT pg_typeof(%s)", [Point(1.2, 3.4)]).fetchone()[0] + 'geometry' + + >>> conn.execute(""" + ... SELECT ST_GeomFromGeoJSON('{ + ... "type":"Point", + ... "coordinates":[-48.23456,20.12345]}') + ... """).fetchone()[0] + <shapely.geometry.multipolygon.MultiPolygon object at 0x7fb131f3cd90> + +Notice that, if the geometry adapters are registered on a specific object (a +connection or cursor), other connections and cursors will be unaffected:: + + >>> conn2 = psycopg.connect(CONN_STR) + >>> conn2.execute(""" + ... SELECT ST_GeomFromGeoJSON('{ + ... "type":"Point", + ... "coordinates":[-48.23456,20.12345]}') + ... """).fetchone()[0] + '0101000020E61000009279E40F061E48C0F2B0506B9A1F3440' + diff --git a/docs/basic/transactions.rst b/docs/basic/transactions.rst new file mode 100644 index 0000000..b976046 --- /dev/null +++ b/docs/basic/transactions.rst @@ -0,0 +1,388 @@ +.. currentmodule:: psycopg + +.. index:: Transactions management +.. index:: InFailedSqlTransaction +.. index:: idle in transaction + +.. _transactions: + +Transactions management +======================= + +Psycopg has a behaviour that may seem surprising compared to +:program:`psql`: by default, any database operation will start a new +transaction. As a consequence, changes made by any cursor of the connection +will not be visible until `Connection.commit()` is called, and will be +discarded by `Connection.rollback()`. The following operation on the same +connection will start a new transaction. + +If a database operation fails, the server will refuse further commands, until +a `~rollback()` is called. + +If the cursor is closed with a transaction open, no COMMIT command is sent to +the server, which will then discard the connection. Certain middleware (such +as PgBouncer) will also discard a connection left in transaction state, so, if +possible you will want to commit or rollback a connection before finishing +working with it. + +An example of what will happen, the first time you will use Psycopg (and to be +disappointed by it), is likely: + +.. code:: python + + conn = psycopg.connect() + + # Creating a cursor doesn't start a transaction or affect the connection + # in any way. + cur = conn.cursor() + + cur.execute("SELECT count(*) FROM my_table") + # This function call executes: + # - BEGIN + # - SELECT count(*) FROM my_table + # So now a transaction has started. + + # If your program spends a long time in this state, the server will keep + # a connection "idle in transaction", which is likely something undesired + + cur.execute("INSERT INTO data VALUES (%s)", ("Hello",)) + # This statement is executed inside the transaction + + conn.close() + # No COMMIT was sent: the INSERT was discarded. + +There are a few things going wrong here, let's see how they can be improved. 
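The most immediate improvement is to make the changes persistent before
closing the connection. A minimal sketch (assuming `!psycopg` is imported and
the same hypothetical ``data`` table as above):

.. code:: python

    conn = psycopg.connect()
    cur = conn.cursor()

    cur.execute("INSERT INTO data VALUES (%s)", ("Hello",))
    # The INSERT implicitly started a transaction, which is still open here.

    conn.commit()   # make the INSERT persistent
    conn.close()    # no transaction is left pending on the server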
One obvious problem with that first run is that, firing up :program:`psql`,
you will see no new record in the table ``data``. Calling `!conn.commit()`
before closing the connection, as sketched above, fixes the problem. Even more
conveniently, if you use the :ref:`connection context <with-connection>`,
Psycopg will commit the connection at the end of the block (or roll it back if
the block is exited with an exception).

The code modified using a connection context will result in the following
sequence of database statements:

.. code-block:: python
    :emphasize-lines: 1

    with psycopg.connect() as conn:

        cur = conn.cursor()

        cur.execute("SELECT count(*) FROM my_table")
        # This function call executes:
        # - BEGIN
        # - SELECT count(*) FROM my_table
        # So now a transaction has started.

        cur.execute("INSERT INTO data VALUES (%s)", ("Hello",))
        # This statement is executed inside the transaction

        # No exception at the end of the block:
        # COMMIT is executed.

This way we don't have to remember to call either `!close()` or `!commit()`,
and the database operations actually have a persistent effect. The code might
still do something you don't expect: keep a transaction open from the first
operation to the connection closure. You can have finer control over the
transactions using an :ref:`autocommit transaction <autocommit>` and/or
:ref:`transaction contexts <transaction-context>`.

.. warning::

    By default even a simple :sql:`SELECT` will start a transaction: in
    long-running programs, if no further action is taken, the session will
    remain *idle in transaction*, an undesirable condition for several
    reasons (locks are held by the session, tables bloat...). For long-lived
    scripts, either make sure to terminate a transaction as soon as possible
    or use an `~Connection.autocommit` connection.

.. hint::

    If a database operation fails with an error message such as
    *InFailedSqlTransaction: current transaction is aborted, commands ignored
    until end of transaction block*, it means that **a previous operation
    failed** and the database session is in a state of error. You need to call
    `~Connection.rollback()` if you want to keep on using the same connection.


.. _autocommit:

Autocommit transactions
-----------------------

The manual commit requirement can be suspended using `~Connection.autocommit`,
either as a connection attribute or as a `~psycopg.Connection.connect()`
parameter. This may be required to run operations that cannot be executed
inside a transaction, such as :sql:`CREATE DATABASE`, :sql:`VACUUM`,
:sql:`CALL` on `stored procedures`__ using transaction control.

.. __: https://www.postgresql.org/docs/current/xproc.html

With an autocommit transaction, the above sequence of operations results in:

.. code-block:: python
    :emphasize-lines: 1

    with psycopg.connect(autocommit=True) as conn:

        cur = conn.cursor()

        cur.execute("SELECT count(*) FROM my_table")
        # This function call now only executes:
        # - SELECT count(*) FROM my_table
        # and no transaction starts.

        cur.execute("INSERT INTO data VALUES (%s)", ("Hello",))
        # The result of this statement is persisted immediately by the database

        # The connection is closed at the end of the block but, because it is
        # not in a transaction state, no COMMIT is executed.

An autocommit transaction behaves more like what someone coming from
:program:`psql` would expect.
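Autocommit can also be enabled after connecting, by setting the connection
attribute; a minimal sketch (again assuming the hypothetical ``data`` table;
the attribute can only be changed while no transaction is active):

.. code:: python

    conn = psycopg.connect()    # the connection starts in non-autocommit mode
    conn.autocommit = True      # no transaction is active yet, so this is allowed

    conn.execute("INSERT INTO data VALUES (%s)", ("Hello",))
    # The INSERT is persisted immediately: no explicit COMMIT is needed.

    conn.close()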
This has a beneficial performance effect, because less queries +are sent and less operations are performed by the database. The statements, +however, are not executed in an atomic transaction; if you need to execute +certain operations inside a transaction, you can achieve that with an +autocommit connection too, using an explicit :ref:`transaction block +<transaction-context>`. + + +.. _transaction-context: + +Transaction contexts +-------------------- + +A more transparent way to make sure that transactions are finalised at the +right time is to use `!with` `Connection.transaction()` to create a +transaction context. When the context is entered, a transaction is started; +when leaving the context the transaction is committed, or it is rolled back if +an exception is raised inside the block. + +Continuing the example above, if you want to use an autocommit connection but +still wrap selected groups of commands inside an atomic transaction, you can +use a `!transaction()` context: + +.. code-block:: python + :emphasize-lines: 8 + + with psycopg.connect(autocommit=True) as conn: + + cur = conn.cursor() + + cur.execute("SELECT count(*) FROM my_table") + # The connection is autocommit, so no BEGIN executed. + + with conn.transaction(): + # BEGIN is executed, a transaction started + + cur.execute("INSERT INTO data VALUES (%s)", ("Hello",)) + cur.execute("INSERT INTO times VALUES (now())") + # These two operation run atomically in the same transaction + + # COMMIT is executed at the end of the block. + # The connection is in idle state again. + + # The connection is closed at the end of the block. + + +Note that connection blocks can also be used with non-autocommit connections: +in this case you still need to pay attention to eventual transactions started +automatically. If an operation starts an implicit transaction, a +`!transaction()` block will only manage :ref:`a savepoint sub-transaction +<nested-transactions>`, leaving the caller to deal with the main transaction, +as explained in :ref:`transactions`: + +.. code:: python + + conn = psycopg.connect() + + cur = conn.cursor() + + cur.execute("SELECT count(*) FROM my_table") + # This function call executes: + # - BEGIN + # - SELECT count(*) FROM my_table + # So now a transaction has started. + + with conn.transaction(): + # The block starts with a transaction already open, so it will execute + # - SAVEPOINT + + cur.execute("INSERT INTO data VALUES (%s)", ("Hello",)) + + # The block was executing a sub-transaction so on exit it will only run: + # - RELEASE SAVEPOINT + # The transaction is still on. + + conn.close() + # No COMMIT was sent: the INSERT was discarded. + +If a `!transaction()` block starts when no transaction is active then it will +manage a proper transaction. In essence, a transaction context tries to leave +a connection in the state it found it, and leaves you to deal with the wider +context. + +.. hint:: + The interaction between non-autocommit transactions and transaction + contexts is probably surprising. Although the non-autocommit default is + what's demanded by the DBAPI, the personal preference of several experienced + developers is to: + + - use a connection block: ``with psycopg.connect(...) as conn``; + - use an autocommit connection, either passing `!autocommit=True` as + `!connect()` parameter or setting the attribute ``conn.autocommit = + True``; + - use `!with conn.transaction()` blocks to manage transactions only where + needed. + + +.. 
_nested-transactions: + +Nested transactions +^^^^^^^^^^^^^^^^^^^ + +Transaction blocks can be also nested (internal transaction blocks are +implemented using SAVEPOINT__): an exception raised inside an inner block +has a chance of being handled and not completely fail outer operations. The +following is an example where a series of operations interact with the +database: operations are allowed to fail; at the end we also want to store the +number of operations successfully processed. + +.. __: https://www.postgresql.org/docs/current/sql-savepoint.html + +.. code:: python + + with conn.transaction() as tx1: + num_ok = 0 + for operation in operations: + try: + with conn.transaction() as tx2: + unreliable_operation(conn, operation) + except Exception: + logger.exception(f"{operation} failed") + else: + num_ok += 1 + + save_number_of_successes(conn, num_ok) + +If `!unreliable_operation()` causes an error, including an operation causing a +database error, all its changes will be reverted. The exception bubbles up +outside the block: in the example it is intercepted by the `!try` so that the +loop can complete. The outermost block is unaffected (unless other errors +happen there). + +You can also write code to explicitly roll back any currently active +transaction block, by raising the `Rollback` exception. The exception "jumps" +to the end of a transaction block, rolling back its transaction but allowing +the program execution to continue from there. By default the exception rolls +back the innermost transaction block, but any current block can be specified +as the target. In the following example, a hypothetical `!CancelCommand` +may stop the processing and cancel any operation previously performed, +but not entirely committed yet. + +.. code:: python + + from psycopg import Rollback + + with conn.transaction() as outer_tx: + for command in commands(): + with conn.transaction() as inner_tx: + if isinstance(command, CancelCommand): + raise Rollback(outer_tx) + process_command(command) + + # If `Rollback` is raised, it would propagate only up to this block, + # and the program would continue from here with no exception. + + +.. _transaction-characteristics: + +Transaction characteristics +--------------------------- + +You can set `transaction parameters`__ for the transactions that Psycopg +handles. They affect the transactions started implicitly by non-autocommit +transactions and the ones started explicitly by `Connection.transaction()` for +both autocommit and non-autocommit transactions. Leaving these parameters as +`!None` will use the server's default behaviour (which is controlled +by server settings such as default_transaction_isolation__). + +.. __: https://www.postgresql.org/docs/current/sql-set-transaction.html +.. __: https://www.postgresql.org/docs/current/runtime-config-client.html + #GUC-DEFAULT-TRANSACTION-ISOLATION + +In order to set these parameters you can use the connection attributes +`~Connection.isolation_level`, `~Connection.read_only`, +`~Connection.deferrable`. For async connections you must use the equivalent +`~AsyncConnection.set_isolation_level()` method and similar. The parameters +can only be changed if there isn't a transaction already active on the +connection. + +.. warning:: + + Applications running at `~IsolationLevel.REPEATABLE_READ` or + `~IsolationLevel.SERIALIZABLE` isolation level are exposed to serialization + failures. 
`In certain concurrent update cases`__, PostgreSQL will raise an
    exception looking like::

        psycopg.errors.SerializationFailure: could not serialize access
        due to concurrent update

    In this case the application must be prepared to repeat the operation that
    caused the exception.

    .. __: https://www.postgresql.org/docs/current/transaction-iso.html
        #XACT-REPEATABLE-READ


.. index::
    pair: Two-phase commit; Transaction

.. _two-phase-commit:

Two-Phase Commit protocol support
---------------------------------

.. versionadded:: 3.1

Psycopg exposes the two-phase commit features available in PostgreSQL,
implementing the `two-phase commit extensions`__ proposed by the DBAPI.

The DBAPI model of two-phase commit is inspired by the `XA specification`__,
according to which transaction IDs are formed from three components:

- a format ID (non-negative 32-bit integer)
- a global transaction ID (string not longer than 64 bytes)
- a branch qualifier (string not longer than 64 bytes)

For a particular global transaction, the first two components will be the same
for all the resources. Every resource will be assigned a different branch
qualifier.

According to the DBAPI specification, a transaction ID is created using the
`Connection.xid()` method. Once you have a transaction ID, a distributed
transaction can be started with `Connection.tpc_begin()`, prepared using
`~Connection.tpc_prepare()` and completed using `~Connection.tpc_commit()` or
`~Connection.tpc_rollback()`. Transaction IDs can also be retrieved from the
database using `~Connection.tpc_recover()` and completed using the above
`!tpc_commit()` and `!tpc_rollback()`.

PostgreSQL doesn't follow the XA standard though, and the ID for a PostgreSQL
prepared transaction can be any string up to 200 characters long. Psycopg's
`Xid` objects can represent both XA-style transaction IDs (such as the ones
created by the `!xid()` method) and PostgreSQL transaction IDs identified by
an unparsed string.

The format in which the Xids are converted into strings passed to the
database is the same employed by the `PostgreSQL JDBC driver`__: this should
allow interoperation between tools written in Python and in Java. For example,
a recovery tool written in Python would be able to recognize the components of
transactions produced by a Java program.

For further details see the documentation for the :ref:`tpc-methods`.

.. __: https://www.python.org/dev/peps/pep-0249/#optional-two-phase-commit-extensions
.. __: https://publications.opengroup.org/c193
.. __: https://jdbc.postgresql.org/ diff --git a/docs/basic/usage.rst b/docs/basic/usage.rst new file mode 100644 index 0000000..6c69fe8 --- /dev/null +++ b/docs/basic/usage.rst @@ -0,0 +1,232 @@ +.. currentmodule:: psycopg

.. _module-usage:

Basic module usage
==================

The basic Psycopg usage is common to all the database adapters implementing
the `DB-API`__ protocol. Other database adapters, such as the builtin
`sqlite3` or `psycopg2`, have roughly the same pattern of interaction.

.. __: https://www.python.org/dev/peps/pep-0249/


.. index::
    pair: Example; Usage

.. _usage:

Main objects in Psycopg 3
-------------------------

Here is an interactive session showing some of the basic commands:

.. 
code:: python + + # Note: the module name is psycopg, not psycopg3 + import psycopg + + # Connect to an existing database + with psycopg.connect("dbname=test user=postgres") as conn: + + # Open a cursor to perform database operations + with conn.cursor() as cur: + + # Execute a command: this creates a new table + cur.execute(""" + CREATE TABLE test ( + id serial PRIMARY KEY, + num integer, + data text) + """) + + # Pass data to fill a query placeholders and let Psycopg perform + # the correct conversion (no SQL injections!) + cur.execute( + "INSERT INTO test (num, data) VALUES (%s, %s)", + (100, "abc'def")) + + # Query the database and obtain data as Python objects. + cur.execute("SELECT * FROM test") + cur.fetchone() + # will return (1, 100, "abc'def") + + # You can use `cur.fetchmany()`, `cur.fetchall()` to return a list + # of several records, or even iterate on the cursor + for record in cur: + print(record) + + # Make the changes to the database persistent + conn.commit() + + +In the example you can see some of the main objects and methods and how they +relate to each other: + +- The function `~Connection.connect()` creates a new database session and + returns a new `Connection` instance. `AsyncConnection.connect()` + creates an `asyncio` connection instead. + +- The `~Connection` class encapsulates a database session. It allows to: + + - create new `~Cursor` instances using the `~Connection.cursor()` method to + execute database commands and queries, + + - terminate transactions using the methods `~Connection.commit()` or + `~Connection.rollback()`. + +- The class `~Cursor` allows interaction with the database: + + - send commands to the database using methods such as `~Cursor.execute()` + and `~Cursor.executemany()`, + + - retrieve data from the database, iterating on the cursor or using methods + such as `~Cursor.fetchone()`, `~Cursor.fetchmany()`, `~Cursor.fetchall()`. + +- Using these objects as context managers (i.e. using `!with`) will make sure + to close them and free their resources at the end of the block (notice that + :ref:`this is different from psycopg2 <diff-with>`). + + +.. seealso:: + + A few important topics you will have to deal with are: + + - :ref:`query-parameters`. + - :ref:`types-adaptation`. + - :ref:`transactions`. + + +Shortcuts +--------- + +The pattern above is familiar to `!psycopg2` users. However, Psycopg 3 also +exposes a few simple extensions which make the above pattern leaner: + +- the `Connection` objects exposes an `~Connection.execute()` method, + equivalent to creating a cursor, calling its `~Cursor.execute()` method, and + returning it. + + .. code:: + + # In Psycopg 2 + cur = conn.cursor() + cur.execute(...) + + # In Psycopg 3 + cur = conn.execute(...) + +- The `Cursor.execute()` method returns `!self`. This means that you can chain + a fetch operation, such as `~Cursor.fetchone()`, to the `!execute()` call: + + .. code:: + + # In Psycopg 2 + cur.execute(...) + record = cur.fetchone() + + cur.execute(...) + for record in cur: + ... + + # In Psycopg 3 + record = cur.execute(...).fetchone() + + for record in cur.execute(...): + ... + +Using them together, in simple cases, you can go from creating a connection to +using a result in a single expression: + +.. code:: + + print(psycopg.connect(DSN).execute("SELECT now()").fetchone()[0]) + # 2042-07-12 18:15:10.706497+01:00 + + +.. index:: + pair: Connection; `!with` + +.. _with-connection: + +Connection context +------------------ + +Psycopg 3 `Connection` can be used as a context manager: + +.. 
code:: python

    with psycopg.connect() as conn:
        ... # use the connection

    # the connection is now closed

When the block is exited, if there is a transaction open, it will be
committed. If an exception is raised within the block the transaction is
rolled back. In both cases the connection is closed. It is roughly the
equivalent of:

.. code:: python

    conn = psycopg.connect()
    try:
        ... # use the connection
    except BaseException:
        conn.rollback()
    else:
        conn.commit()
    finally:
        conn.close()

.. note::
    This behaviour is not what `!psycopg2` does: in `!psycopg2` :ref:`there is
    no final close() <pg2:with>` and the connection can be used in several
    `!with` statements to manage different transactions. This behaviour has
    been considered non-standard and surprising so it has been replaced by the
    more explicit `~Connection.transaction()` block.

Note that, while the above pattern is what most people would use, `connect()`
doesn't enter a block itself, but returns an "un-entered" connection, so that
it is still possible to use a connection regardless of the code scope and the
developer is free to use (and responsible for calling) `~Connection.commit()`,
`~Connection.rollback()`, `~Connection.close()` as and where needed.

.. warning::
    If a connection is just left to go out of scope, the way it will behave
    with or without the use of a `!with` block is different:

    - if the connection is used without a `!with` block, the server will find
      a connection closed INTRANS and roll back the current transaction;

    - if the connection is used with a `!with` block, there will be an
      explicit COMMIT and the operations will be finalised.

    You should use a `!with` block when your intention is just to execute a
    set of operations and then commit the result, which is the most usual
    thing to do with a connection. If your connection life cycle and
    transaction pattern are different, and you want more control over them,
    using the connection without `!with` might be more convenient.

    See :ref:`transactions` for more information.

`AsyncConnection` can also be used as a context manager, using ``async with``,
but be careful about its quirkiness: see :ref:`async-with` for details.


Adapting psycopg to your program
--------------------------------

The above :ref:`pattern of use <usage>` only shows the default behaviour of
the adapter. Psycopg can be customised in several ways, to allow the smoothest
integration between your Python program and your PostgreSQL database:

- If your program is concurrent and based on `asyncio` instead of on
  threads/processes, you can use :ref:`async connections and cursors <async>`.

- If you want to customise the objects that the cursor returns, instead of
  receiving tuples, you can specify your :ref:`row factories <row-factories>`.

- If you want to customise how Python values and PostgreSQL types are mapped
  into each other, besides the :ref:`basic type mapping <types-adaptation>`,
  you can :ref:`configure your types <adaptation>`. diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..a20894b --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,110 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. 
For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + +import sys +from pathlib import Path + +import psycopg + +docs_dir = Path(__file__).parent +sys.path.append(str(docs_dir / "lib")) + + +# -- Project information ----------------------------------------------------- + +project = "psycopg" +copyright = "2020, Daniele Varrazzo and The Psycopg Team" +author = "Daniele Varrazzo" +release = psycopg.__version__ + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx_autodoc_typehints", + "sql_role", + "ticket_role", + "pg3_docs", + "libpq_docs", +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"] + + +# -- Options for HTML output ------------------------------------------------- + +# The announcement may be in the website but not shipped with the docs +ann_file = docs_dir / "../../templates/docs3-announcement.html" +if ann_file.exists(): + with ann_file.open() as f: + announcement = f.read() +else: + announcement = "" + +html_css_files = ["psycopg.css"] + +# The name of the Pygments (syntax highlighting) style to use. +# Some that I've check don't suck: +# default lovelace tango algol_nu +# list: from pygments.styles import STYLE_MAP; print(sorted(STYLE_MAP.keys())) +pygments_style = "tango" + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "furo" +html_show_sphinx = True +html_show_sourcelink = False +html_theme_options = { + "announcement": announcement, + "sidebar_hide_name": False, + "light_logo": "psycopg.svg", + "dark_logo": "psycopg.svg", + "light_css_variables": { + "admonition-font-size": "1rem", + }, +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# The reST default role (used for this markup: `text`) to use for all documents. 
+default_role = "obj" + +intersphinx_mapping = { + "py": ("https://docs.python.org/3", None), + "pg2": ("https://www.psycopg.org/docs/", None), +} + +autodoc_member_order = "bysource" + +# PostgreSQL docs version to link libpq functions to +libpq_docs_version = "14" + +# Where to point on :ticket: role +ticket_url = "https://github.com/psycopg/psycopg/issues/%s" diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..916eeb0 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,52 @@ +=================================================== +Psycopg 3 -- PostgreSQL database adapter for Python +=================================================== + +Psycopg 3 is a newly designed PostgreSQL_ database adapter for the Python_ +programming language. + +Psycopg 3 presents a familiar interface for everyone who has used +`Psycopg 2`_ or any other `DB-API 2.0`_ database adapter, but allows to use +more modern PostgreSQL and Python features, such as: + +- :ref:`Asynchronous support <async>` +- :ref:`COPY support from Python objects <copy>` +- :ref:`A redesigned connection pool <connection-pools>` +- :ref:`Support for static typing <static-typing>` +- :ref:`Server-side parameters binding <server-side-binding>` +- :ref:`Prepared statements <prepared-statements>` +- :ref:`Statements pipeline <pipeline-mode>` +- :ref:`Binary communication <binary-data>` +- :ref:`Direct access to the libpq functionalities <psycopg.pq>` + +.. _Python: https://www.python.org/ +.. _PostgreSQL: https://www.postgresql.org/ +.. _Psycopg 2: https://www.psycopg.org/docs/ +.. _DB-API 2.0: https://www.python.org/dev/peps/pep-0249/ + + +Documentation +============= + +.. toctree:: + :maxdepth: 2 + + basic/index + advanced/index + api/index + +Release notes +------------- + +.. toctree:: + :maxdepth: 1 + + news + news_pool + + +Indices and tables +------------------ + +* :ref:`genindex` +* :ref:`modindex` diff --git a/docs/lib/libpq_docs.py b/docs/lib/libpq_docs.py new file mode 100644 index 0000000..b8e01f0 --- /dev/null +++ b/docs/lib/libpq_docs.py @@ -0,0 +1,182 @@ +""" +Sphinx plugin to link to the libpq documentation. + +Add the ``:pq:`` role, to create a link to a libpq function, e.g. 
:: + + :pq:`PQlibVersion()` + +will link to:: + + https://www.postgresql.org/docs/current/libpq-misc.html #LIBPQ-PQLIBVERSION + +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import logging +import urllib.request +from pathlib import Path +from functools import lru_cache +from html.parser import HTMLParser + +from docutils import nodes, utils +from docutils.parsers.rst import roles + +logger = logging.getLogger("sphinx.libpq_docs") + + +class LibpqParser(HTMLParser): + def __init__(self, data, version="current"): + super().__init__() + self.data = data + self.version = version + + self.section_id = None + self.varlist_id = None + self.in_term = False + self.in_func = False + + def handle_starttag(self, tag, attrs): + if tag == "sect1": + self.handle_sect1(tag, attrs) + elif tag == "varlistentry": + self.handle_varlistentry(tag, attrs) + elif tag == "term": + self.in_term = True + elif tag == "function": + self.in_func = True + + def handle_endtag(self, tag): + if tag == "term": + self.in_term = False + elif tag == "function": + self.in_func = False + + def handle_data(self, data): + if not (self.in_term and self.in_func): + return + + self.add_function(data) + + def handle_sect1(self, tag, attrs): + attrs = dict(attrs) + if "id" in attrs: + self.section_id = attrs["id"] + + def handle_varlistentry(self, tag, attrs): + attrs = dict(attrs) + if "id" in attrs: + self.varlist_id = attrs["id"] + + def add_function(self, func_name): + self.data[func_name] = self.get_func_url() + + def get_func_url(self): + assert self.section_id, "<sect1> tag not found" + assert self.varlist_id, "<varlistentry> tag not found" + return self._url_pattern.format( + version=self.version, + section=self.section_id, + func_id=self.varlist_id.upper(), + ) + + _url_pattern = "https://www.postgresql.org/docs/{version}/{section}.html#{func_id}" + + +class LibpqReader: + # must be set before using the rest of the class. 
+ app = None + + _url_pattern = ( + "https://raw.githubusercontent.com/postgres/postgres/REL_{ver}_STABLE" + "/doc/src/sgml/libpq.sgml" + ) + + data = None + + def get_url(self, func): + if not self.data: + self.parse() + + return self.data[func] + + def parse(self): + if not self.local_file.exists(): + self.download() + + logger.info("parsing libpq docs from %s", self.local_file) + self.data = {} + parser = LibpqParser(self.data, version=self.version) + with self.local_file.open("r") as f: + parser.feed(f.read()) + + def download(self): + filename = os.environ.get("LIBPQ_DOCS_FILE") + if filename: + logger.info("reading postgres libpq docs from %s", filename) + with open(filename, "rb") as f: + data = f.read() + else: + logger.info("downloading postgres libpq docs from %s", self.sgml_url) + data = urllib.request.urlopen(self.sgml_url).read() + + with self.local_file.open("wb") as f: + f.write(data) + + @property + def local_file(self): + return Path(self.app.doctreedir) / f"libpq-{self.version}.sgml" + + @property + def sgml_url(self): + return self._url_pattern.format(ver=self.version) + + @property + def version(self): + return self.app.config.libpq_docs_version + + +@lru_cache() +def get_reader(): + return LibpqReader() + + +def pq_role(name, rawtext, text, lineno, inliner, options={}, content=[]): + text = utils.unescape(text) + + reader = get_reader() + if "(" in text: + func, noise = text.split("(", 1) + noise = "(" + noise + + else: + func = text + noise = "" + + try: + url = reader.get_url(func) + except KeyError: + msg = inliner.reporter.warning( + f"function {func} not found in libpq {reader.version} docs" + ) + prb = inliner.problematic(rawtext, rawtext, msg) + return [prb], [msg] + + # For a function f(), include the () in the signature for consistency + # with a normal `thing()` + if noise == "()": + func, noise = func + noise, "" + + the_nodes = [] + the_nodes.append(nodes.reference(func, func, refuri=url)) + if noise: + the_nodes.append(nodes.Text(noise)) + + return [nodes.literal("", "", *the_nodes, **options)], [] + + +def setup(app): + app.add_config_value("libpq_docs_version", "14", "html") + roles.register_local_role("pq", pq_role) + get_reader().app = app diff --git a/docs/lib/pg3_docs.py b/docs/lib/pg3_docs.py new file mode 100644 index 0000000..05a6876 --- /dev/null +++ b/docs/lib/pg3_docs.py @@ -0,0 +1,197 @@ +""" +Customisation for docs generation. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +import logging +import importlib +from typing import Dict +from collections import deque + + +def process_docstring(app, what, name, obj, options, lines): + pass + + +def before_process_signature(app, obj, bound_method): + ann = getattr(obj, "__annotations__", {}) + if "return" in ann: + # Drop "return: None" from the function signatures + if ann["return"] is None: + del ann["return"] + + +def process_signature(app, what, name, obj, options, signature, return_annotation): + pass + + +def setup(app): + app.connect("autodoc-process-docstring", process_docstring) + app.connect("autodoc-process-signature", process_signature) + app.connect("autodoc-before-process-signature", before_process_signature) + + import psycopg # type: ignore + + recover_defined_module( + psycopg, skip_modules=["psycopg._dns", "psycopg.types.shapely"] + ) + monkeypatch_autodoc() + + # Disable warnings in sphinx_autodoc_typehints because it doesn't seem that + # there is a workaround for: "WARNING: Cannot resolve forward reference in + # type annotations" + logger = logging.getLogger("sphinx.sphinx_autodoc_typehints") + logger.setLevel(logging.ERROR) + + +# Classes which may have __module__ overwritten +recovered_classes: Dict[type, str] = {} + + +def recover_defined_module(m, skip_modules=()): + """ + Find the module where classes with __module__ attribute hacked were defined. + + Autodoc will get confused and will fail to inspect attribute docstrings + (e.g. from enums and named tuples). + + Save the classes recovered in `recovered_classes`, to be used by + `monkeypatch_autodoc()`. + + """ + mdir = os.path.split(m.__file__)[0] + for fn in walk_modules(mdir): + assert fn.startswith(mdir) + modname = os.path.splitext(fn[len(mdir) + 1 :])[0].replace("/", ".") + modname = f"{m.__name__}.{modname}" + if modname in skip_modules: + continue + with open(fn) as f: + classnames = re.findall(r"^class\s+([^(:]+)", f.read(), re.M) + for cls in classnames: + cls = deep_import(f"{modname}.{cls}") + if cls.__module__ != modname: + recovered_classes[cls] = modname + + +def monkeypatch_autodoc(): + """ + Patch autodoc in order to use information found by `recover_defined_module`. + """ + from sphinx.ext.autodoc import Documenter, AttributeDocumenter + + orig_doc_get_real_modname = Documenter.get_real_modname + orig_attr_get_real_modname = AttributeDocumenter.get_real_modname + orig_attr_add_content = AttributeDocumenter.add_content + + def fixed_doc_get_real_modname(self): + if self.object in recovered_classes: + return recovered_classes[self.object] + return orig_doc_get_real_modname(self) + + def fixed_attr_get_real_modname(self): + if self.parent in recovered_classes: + return recovered_classes[self.parent] + return orig_attr_get_real_modname(self) + + def fixed_attr_add_content(self, more_content): + """ + Replace a docstring such as:: + + .. py:attribute:: ConnectionInfo.dbname + :module: psycopg + + The database name of the connection. + + :rtype: :py:class:`str` + + into: + + .. py:attribute:: ConnectionInfo.dbname + :type: str + :module: psycopg + + The database name of the connection. + + which creates a more compact representation of a property. 
+ + """ + orig_attr_add_content(self, more_content) + if not isinstance(self.object, property): + return + iret, mret = match_in_lines(r"\s*:rtype: (.*)", self.directive.result) + iatt, matt = match_in_lines(r"\.\.", self.directive.result) + if not (mret and matt): + return + self.directive.result.pop(iret) + self.directive.result.insert( + iatt + 1, + f"{self.indent}:type: {unrest(mret.group(1))}", + source=self.get_sourcename(), + ) + + Documenter.get_real_modname = fixed_doc_get_real_modname + AttributeDocumenter.get_real_modname = fixed_attr_get_real_modname + AttributeDocumenter.add_content = fixed_attr_add_content + + +def match_in_lines(pattern, lines): + """Match a regular expression against a list of strings. + + Return the index of the first matched line and the match object. + None, None if nothing matched. + """ + for i, line in enumerate(lines): + m = re.match(pattern, line) + if m: + return i, m + else: + return None, None + + +def unrest(s): + r"""remove the reST markup from a string + + e.g. :py:data:`~typing.Optional`\[:py:class:`int`] -> Optional[int] + + required because :type: does the types lookup itself apparently. + """ + s = re.sub(r":[^`]*:`~?([^`]*)`", r"\1", s) # drop role + s = re.sub(r"\\(.)", r"\1", s) # drop escape + + # note that ~psycopg.pq.ConnStatus is converted to pq.ConnStatus + # which should be interpreted well if currentmodule is set ok. + s = re.sub(r"(?:typing|psycopg)\.", "", s) # drop unneeded modules + s = re.sub(r"~", "", s) # drop the tilde + + return s + + +def walk_modules(d): + for root, dirs, files in os.walk(d): + for f in files: + if f.endswith(".py"): + yield f"{root}/{f}" + + +def deep_import(name): + parts = deque(name.split(".")) + seen = [] + if not parts: + raise ValueError("name must be a dot-separated name") + + seen.append(parts.popleft()) + thing = importlib.import_module(seen[-1]) + while parts: + attr = parts.popleft() + seen.append(attr) + + if hasattr(thing, attr): + thing = getattr(thing, attr) + else: + thing = importlib.import_module(".".join(seen)) + + return thing diff --git a/docs/lib/sql_role.py b/docs/lib/sql_role.py new file mode 100644 index 0000000..a40c9f4 --- /dev/null +++ b/docs/lib/sql_role.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +""" + sql role + ~~~~~~~~ + + An interpreted text role to style SQL syntax in Psycopg documentation. + + :copyright: Copyright 2010 by Daniele Varrazzo. + :copyright: Copyright 2020 The Psycopg Team. +""" + +from docutils import nodes, utils +from docutils.parsers.rst import roles + + +def sql_role(name, rawtext, text, lineno, inliner, options={}, content=[]): + text = utils.unescape(text) + options["classes"] = ["sql"] + return [nodes.literal(rawtext, text, **options)], [] + + +def setup(app): + roles.register_local_role("sql", sql_role) diff --git a/docs/lib/ticket_role.py b/docs/lib/ticket_role.py new file mode 100644 index 0000000..24ec873 --- /dev/null +++ b/docs/lib/ticket_role.py @@ -0,0 +1,50 @@ +# type: ignore +""" + ticket role + ~~~~~~~~~~~ + + An interpreted text role to link docs to tickets issues. + + :copyright: Copyright 2013 by Daniele Varrazzo. 
+ :copyright: Copyright 2021 The Psycopg Team +""" + +import re +from docutils import nodes, utils +from docutils.parsers.rst import roles + + +def ticket_role(name, rawtext, text, lineno, inliner, options={}, content=[]): + cfg = inliner.document.settings.env.app.config + if cfg.ticket_url is None: + msg = inliner.reporter.warning( + "ticket not configured: please configure ticket_url in conf.py" + ) + prb = inliner.problematic(rawtext, rawtext, msg) + return [prb], [msg] + + rv = [nodes.Text(name + " ")] + tokens = re.findall(r"(#?\d+)|([^\d#]+)", text) + for ticket, noise in tokens: + if ticket: + num = int(ticket.replace("#", "")) + + url = cfg.ticket_url % num + roles.set_classes(options) + node = nodes.reference( + ticket, utils.unescape(ticket), refuri=url, **options + ) + + rv.append(node) + + else: + assert noise + rv.append(nodes.Text(noise)) + + return rv, [] + + +def setup(app): + app.add_config_value("ticket_url", None, "env") + app.add_role("ticket", ticket_role) + app.add_role("tickets", ticket_role) diff --git a/docs/news.rst b/docs/news.rst new file mode 100644 index 0000000..46dfbe7 --- /dev/null +++ b/docs/news.rst @@ -0,0 +1,285 @@ +.. currentmodule:: psycopg + +.. index:: + single: Release notes + single: News + +``psycopg`` release notes +========================= + +Current release +--------------- + +Psycopg 3.1.7 +^^^^^^^^^^^^^ + +- Fix server-side cursors using row factories (:ticket:`#464`). + + +Psycopg 3.1.6 +^^^^^^^^^^^^^ + +- Fix `cursor.copy()` with cursors using row factories (:ticket:`#460`). + + +Psycopg 3.1.5 +^^^^^^^^^^^^^ + +- Fix array loading slowness compared to psycopg2 (:ticket:`#359`). +- Improve performance around network communication (:ticket:`#414`). +- Return `!bytes` instead of `!memoryview` from `pq.Encoding` methods + (:ticket:`#422`). +- Fix `Cursor.rownumber` to return `!None` when the result has no row to fetch + (:ticket:`#437`). +- Avoid error in Pyright caused by aliasing `!TypeAlias` (:ticket:`#439`). +- Fix `Copy.set_types()` used with `varchar` and `name` types (:ticket:`#452`). +- Improve performance using :ref:`row-factories` (:ticket:`#457`). + + +Psycopg 3.1.4 +^^^^^^^^^^^^^ + +- Include :ref:`error classes <sqlstate-exceptions>` defined in PostgreSQL 15. +- Add support for Python 3.11 (:ticket:`#305`). +- Build binary packages with libpq from PostgreSQL 15.0. + + +Psycopg 3.1.3 +^^^^^^^^^^^^^ + +- Restore the state of the connection if `Cursor.stream()` is terminated + prematurely (:ticket:`#382`). +- Fix regression introduced in 3.1 with different named tuples mangling rules + for non-ascii attribute names (:ticket:`#386`). +- Fix handling of queries with escaped percent signs (``%%``) in `ClientCursor` + (:ticket:`#399`). +- Fix possible duplicated BEGIN statements emitted in pipeline mode + (:ticket:`#401`). + + +Psycopg 3.1.2 +^^^^^^^^^^^^^ + +- Fix handling of certain invalid time zones causing problems on Windows + (:ticket:`#371`). +- Fix segfault occurring when a loader fails initialization (:ticket:`#372`). +- Fix invalid SAVEPOINT issued when entering `Connection.transaction()` within + a pipeline using an implicit transaction (:ticket:`#374`). +- Fix queries with repeated named parameters in `ClientCursor` (:ticket:`#378`). +- Distribute macOS arm64 (Apple M1) binary packages (:ticket:`#344`). 
+ + +Psycopg 3.1.1 +^^^^^^^^^^^^^ + +- Work around broken Homebrew installation of the libpq in a non-standard path + (:ticket:`#364`) +- Fix possible "unrecognized service" error in async connection when no port + is specified (:ticket:`#366`). + + +Psycopg 3.1 +----------- + +- Add :ref:`Pipeline mode <pipeline-mode>` (:ticket:`#74`). +- Add :ref:`client-side-binding-cursors` (:ticket:`#101`). +- Add `CockroachDB <https://www.cockroachlabs.com/>`__ support in `psycopg.crdb` + (:ticket:`#313`). +- Add :ref:`Two-Phase Commit <two-phase-commit>` support (:ticket:`#72`). +- Add :ref:`adapt-enum` (:ticket:`#274`). +- Add ``returning`` parameter to `~Cursor.executemany()` to retrieve query + results (:ticket:`#164`). +- `~Cursor.executemany()` performance improved by using batch mode internally + (:ticket:`#145`). +- Add parameters to `~Cursor.copy()`. +- Add :ref:`COPY Writer objects <copy-writers>`. +- Resolve domain names asynchronously in `AsyncConnection.connect()` + (:ticket:`#259`). +- Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`). +- Add ``prepare_threshold`` parameter to `Connection` init (:ticket:`#200`). +- Add ``cursor_factory`` parameter to `Connection` init. +- Add `Error.pgconn` and `Error.pgresult` attributes (:ticket:`#242`). +- Restrict queries to be `~typing.LiteralString` as per :pep:`675` + (:ticket:`#323`). +- Add explicit type cast to values converted by `sql.Literal` (:ticket:`#205`). +- Drop support for Python 3.6. + + +Psycopg 3.0.17 +^^^^^^^^^^^^^^ + +- Fix segfaults on fork on some Linux systems using `ctypes` implementation + (:ticket:`#300`). +- Load bytea as bytes, not memoryview, using `ctypes` implementation. + + +Psycopg 3.0.16 +^^^^^^^^^^^^^^ + +- Fix missing `~Cursor.rowcount` after SHOW (:ticket:`#343`). +- Add scripts to build macOS arm64 packages (:ticket:`#162`). + + +Psycopg 3.0.15 +^^^^^^^^^^^^^^ + +- Fix wrong escaping of unprintable chars in COPY (nonetheless correctly + interpreted by PostgreSQL). +- Restore the connection to usable state after an error in `~Cursor.stream()`. +- Raise `DataError` instead of `OverflowError` loading binary intervals + out-of-range. +- Distribute ``manylinux2014`` wheel packages (:ticket:`#124`). + + +Psycopg 3.0.14 +^^^^^^^^^^^^^^ + +- Raise `DataError` dumping arrays of mixed types (:ticket:`#301`). +- Fix handling of incorrect server results, with blank sqlstate (:ticket:`#303`). +- Fix bad Float4 conversion on ppc64le/musllinux (:ticket:`#304`). + + +Psycopg 3.0.13 +^^^^^^^^^^^^^^ + +- Fix `Cursor.stream()` slowness (:ticket:`#286`). +- Fix oid for lists of integers, which might cause the server choosing + bad plans (:ticket:`#293`). +- Make `Connection.cancel()` on a closed connection a no-op instead of an + error. + + +Psycopg 3.0.12 +^^^^^^^^^^^^^^ + +- Allow `bytearray`/`memoryview` data too as `Copy.write()` input + (:ticket:`#254`). +- Fix dumping `~enum.IntEnum` in text mode, Python implementation. + + +Psycopg 3.0.11 +^^^^^^^^^^^^^^ + +- Fix `DataError` loading arrays with dimensions information (:ticket:`#253`). +- Fix hanging during COPY in case of memory error (:ticket:`#255`). +- Fix error propagation from COPY worker thread (mentioned in :ticket:`#255`). + + +Psycopg 3.0.10 +^^^^^^^^^^^^^^ + +- Leave the connection in working state after interrupting a query with Ctrl-C + (:ticket:`#231`). +- Fix `Cursor.description` after a COPY ... TO STDOUT operation + (:ticket:`#235`). +- Fix building on FreeBSD and likely other BSD flavours (:ticket:`#241`). 
+ + +Psycopg 3.0.9 +^^^^^^^^^^^^^ + +- Set `Error.sqlstate` when an unknown code is received (:ticket:`#225`). +- Add the `!tzdata` package as a dependency on Windows in order to handle time + zones (:ticket:`#223`). + + +Psycopg 3.0.8 +^^^^^^^^^^^^^ + +- Decode connection errors in the ``client_encoding`` specified in the + connection string, if available (:ticket:`#194`). +- Fix possible warnings in objects deletion on interpreter shutdown + (:ticket:`#198`). +- Don't leave connections in ACTIVE state in case of error during COPY ... TO + STDOUT (:ticket:`#203`). + + +Psycopg 3.0.7 +^^^^^^^^^^^^^ + +- Fix crash in `~Cursor.executemany()` with no input sequence + (:ticket:`#179`). +- Fix wrong `~Cursor.rowcount` after an `~Cursor.executemany()` returning no + rows (:ticket:`#178`). + + +Psycopg 3.0.6 +^^^^^^^^^^^^^ + +- Allow to use `Cursor.description` if the connection is closed + (:ticket:`#172`). +- Don't raise exceptions on `ServerCursor.close()` if the connection is closed + (:ticket:`#173`). +- Fail on `Connection.cursor()` if the connection is closed (:ticket:`#174`). +- Raise `ProgrammingError` if out-of-order exit from transaction contexts is + detected (:tickets:`#176, #177`). +- Add `!CHECK_STANDBY` value to `~pq.ConnStatus` enum. + + +Psycopg 3.0.5 +^^^^^^^^^^^^^ + +- Fix possible "Too many open files" OS error, reported on macOS but possible + on other platforms too (:ticket:`#158`). +- Don't clobber exceptions if a transaction block exit with error and rollback + fails (:ticket:`#165`). + + +Psycopg 3.0.4 +^^^^^^^^^^^^^ + +- Allow to use the module with strict strings comparison (:ticket:`#147`). +- Fix segfault on Python 3.6 running in ``-W error`` mode, related to + `!backport.zoneinfo` `ticket #109 + <https://github.com/pganssle/zoneinfo/issues/109>`__. +- Build binary package with libpq versions not affected by `CVE-2021-23222 + <https://www.postgresql.org/support/security/CVE-2021-23222/>`__ + (:ticket:`#149`). + + +Psycopg 3.0.3 +^^^^^^^^^^^^^ + +- Release musllinux binary packages, compatible with Alpine Linux + (:ticket:`#141`). +- Reduce size of binary package by stripping debug symbols (:ticket:`#142`). +- Include typing information in the `!psycopg_binary` package. + + +Psycopg 3.0.2 +^^^^^^^^^^^^^ + +- Fix type hint for `sql.SQL.join()` (:ticket:`#127`). +- Fix type hint for `Connection.notifies()` (:ticket:`#128`). +- Fix call to `MultiRange.__setitem__()` with a non-iterable value and a + slice, now raising a `TypeError` (:ticket:`#129`). +- Fix disable cursors methods after close() (:ticket:`#125`). + + +Psycopg 3.0.1 +^^^^^^^^^^^^^ + +- Fix use of the wrong dumper reusing cursors with the same query but different + parameter types (:ticket:`#112`). + + +Psycopg 3.0 +----------- + +First stable release. Changed from 3.0b1: + +- Add :ref:`adapt-shapely` (:ticket:`#80`). +- Add :ref:`adapt-multirange` (:ticket:`#75`). +- Add `pq.__build_version__` constant. +- Don't use the extended protocol with COPY, (:tickets:`#78, #82`). +- Add ``context`` parameter to `~Connection.connect()` (:ticket:`#83`). +- Fix selection of dumper by oid after `~Copy.set_types()`. +- Drop `!Connection.client_encoding`. Use `ConnectionInfo.encoding` to read + it, and a :sql:`SET` statement to change it. +- Add binary packages for Python 3.10 (:ticket:`#103`). + + +Psycopg 3.0b1 +^^^^^^^^^^^^^ + +- First public release on PyPI. diff --git a/docs/news_pool.rst b/docs/news_pool.rst new file mode 100644 index 0000000..7f212e0 --- /dev/null +++ b/docs/news_pool.rst @@ -0,0 +1,81 @@ +.. 
currentmodule:: psycopg_pool + +.. index:: + single: Release notes + single: News + +``psycopg_pool`` release notes +============================== + +Current release +--------------- + +psycopg_pool 3.1.5 +^^^^^^^^^^^^^^^^^^ + +- Make sure that `!ConnectionPool.check()` refills an empty pool + (:ticket:`#438`). +- Avoid error in Pyright caused by aliasing `!TypeAlias` (:ticket:`#439`). + + +psycopg_pool 3.1.4 +^^^^^^^^^^^^^^^^^^ + +- Fix async pool exhausting connections, happening if the pool is created + before the event loop is started (:ticket:`#219`). + + +psycopg_pool 3.1.3 +^^^^^^^^^^^^^^^^^^ + +- Add support for Python 3.11 (:ticket:`#305`). + + +psycopg_pool 3.1.2 +^^^^^^^^^^^^^^^^^^ + +- Fix possible failure to reconnect after losing connection from the server + (:ticket:`#370`). + + +psycopg_pool 3.1.1 +^^^^^^^^^^^^^^^^^^ + +- Fix race condition on pool creation which might result in the pool not + filling (:ticket:`#230`). + + +psycopg_pool 3.1.0 +------------------ + +- Add :ref:`null-pool` (:ticket:`#148`). +- Add `ConnectionPool.open()` and ``open`` parameter to the pool init + (:ticket:`#151`). +- Drop support for Python 3.6. + + +psycopg_pool 3.0.3 +^^^^^^^^^^^^^^^^^^ + +- Raise `!ValueError` if `ConnectionPool` `!min_size` and `!max_size` are both + set to 0 (instead of hanging). +- Raise `PoolClosed` calling `~ConnectionPool.wait()` on a closed pool. + + +psycopg_pool 3.0.2 +^^^^^^^^^^^^^^^^^^ + +- Remove dependency on the internal `!psycopg._compat` module. + + +psycopg_pool 3.0.1 +^^^^^^^^^^^^^^^^^^ + +- Don't leave connections idle in transaction after calling + `~ConnectionPool.check()` (:ticket:`#144`). + + +psycopg_pool 3.0 +---------------- + +- First release on PyPI. diff --git a/docs/pictures/adapt.drawio b/docs/pictures/adapt.drawio new file mode 100644 index 0000000..75f61ed --- /dev/null +++ b/docs/pictures/adapt.drawio @@ -0,0 +1,107 @@ +<mxfile host="Electron" modified="2021-07-12T13:26:05.192Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/14.6.13 Chrome/89.0.4389.128 Electron/12.0.7 Safari/537.36" etag="kKU1DyIkJcQFc1Rxt__U" compressed="false" version="14.6.13" type="device"> + <diagram id="THISp3X85jFCtBEH0bao" name="Page-1"> + <mxGraphModel dx="675" dy="400" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0"> + <root> + <mxCell id="0" /> + <mxCell id="1" parent="0" /> + <mxCell id="uy255Msn6vtulWmyCIR1-12" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontFamily=Courier New;exitX=1;exitY=0.5;exitDx=0;exitDy=0;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-29" target="uy255Msn6vtulWmyCIR1-11"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="280" y="210" as="sourcePoint" /> + </mxGeometry> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Courier New;exitX=1;exitY=0.5;exitDx=0;exitDy=0;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-30" target="uy255Msn6vtulWmyCIR1-14"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="280" y="320" as="sourcePoint" /> + </mxGeometry> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-39" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-11" 
target="uy255Msn6vtulWmyCIR1-14"> + <mxGeometry relative="1" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-11" value=".adapters" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="330" y="185" width="80" height="20" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-40" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-14" target="uy255Msn6vtulWmyCIR1-27"> + <mxGeometry relative="1" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-14" value=".adapters" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="330" y="285" width="80" height="20" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-28" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontFamily=Courier New;exitX=1;exitY=0.5;exitDx=0;exitDy=0;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-31" target="uy255Msn6vtulWmyCIR1-27"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="280" y="440" as="sourcePoint" /> + </mxGeometry> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-18" value=".cursor()" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="220" y="220" width="80" height="20" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-26" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-19" target="uy255Msn6vtulWmyCIR1-25"> + <mxGeometry relative="1" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-34" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-19" target="uy255Msn6vtulWmyCIR1-29"> + <mxGeometry relative="1" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-19" value="<b>psycopg</b><br><font face="Helvetica">module</font>" style="rounded=1;whiteSpace=wrap;html=1;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="160" y="75" width="120" height="50" as="geometry" /> + </mxCell> + <UserObject label=".connect()" link="../api/connections.html" id="uy255Msn6vtulWmyCIR1-20"> + <mxCell style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="220" y="125" width="80" height="20" as="geometry" /> + </mxCell> + </UserObject> + <mxCell id="uy255Msn6vtulWmyCIR1-21" value=".execute()" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="220" y="320" width="80" height="20" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-37" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-25" 
target="uy255Msn6vtulWmyCIR1-11"> + <mxGeometry relative="1" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-25" value=".adapters" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="330" y="90" width="80" height="20" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-27" value=".adapters" style="text;html=1;strokeColor=none;fillColor=none;align=left;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="330" y="385" width="80" height="20" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-35" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-29" target="uy255Msn6vtulWmyCIR1-30"> + <mxGeometry relative="1" as="geometry" /> + </mxCell> + <UserObject label="<b>Connection</b><br><font face="Helvetica">object</font>" link="../api/connections.html" id="uy255Msn6vtulWmyCIR1-29"> + <mxCell style="rounded=1;whiteSpace=wrap;html=1;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="160" y="170" width="120" height="50" as="geometry" /> + </mxCell> + </UserObject> + <mxCell id="uy255Msn6vtulWmyCIR1-36" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Courier New;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-30" target="uy255Msn6vtulWmyCIR1-31"> + <mxGeometry relative="1" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-30" value="<b>Cursor</b><br><font face="Helvetica">object</font>" style="rounded=1;whiteSpace=wrap;html=1;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="160" y="270" width="120" height="50" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-31" value="<b>Transformer</b><br><font face="Helvetica">object</font>" style="rounded=1;whiteSpace=wrap;html=1;fontFamily=Courier New;" vertex="1" parent="1"> + <mxGeometry x="160" y="370" width="120" height="50" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-46" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Helvetica;endArrow=none;endFill=0;dashed=1;dashPattern=1 1;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-41"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="310" y="100" as="targetPoint" /> + </mxGeometry> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-41" value="Has a" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Helvetica;" vertex="1" parent="1"> + <mxGeometry x="300" y="55" width="40" height="20" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-45" style="rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Helvetica;endArrow=none;endFill=0;dashed=1;dashPattern=1 1;startSize=4;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-42"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="220" y="150" as="targetPoint" /> + </mxGeometry> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-42" value="Create" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Helvetica;" vertex="1" parent="1"> + <mxGeometry x="150" y="130" width="40" height="20" as="geometry" /> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-47" 
style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fontFamily=Helvetica;endArrow=none;endFill=0;dashed=1;dashPattern=1 1;" edge="1" parent="1" source="uy255Msn6vtulWmyCIR1-43"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="370" y="150" as="targetPoint" /> + </mxGeometry> + </mxCell> + <mxCell id="uy255Msn6vtulWmyCIR1-43" value="Copy" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontFamily=Helvetica;" vertex="1" parent="1"> + <mxGeometry x="394" y="130" width="40" height="20" as="geometry" /> + </mxCell> + </root> + </mxGraphModel> + </diagram> +</mxfile> diff --git a/docs/pictures/adapt.svg b/docs/pictures/adapt.svg new file mode 100644 index 0000000..2c39755 --- /dev/null +++ b/docs/pictures/adapt.svg @@ -0,0 +1,3 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> +<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" width="285px" height="366px" viewBox="-0.5 -0.5 285 366" style="background-color: rgb(255, 255, 255);"><defs/><g><path d="M 130 140 L 173.63 140" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 178.88 140 L 171.88 143.5 L 173.63 140 L 171.88 136.5 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><path d="M 130 240 L 173.63 240" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 178.88 240 L 171.88 243.5 L 173.63 240 L 171.88 236.5 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><path d="M 220 150 L 220 223.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 220 228.88 L 216.5 221.88 L 220 223.63 L 223.5 221.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="180" y="130" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 140px; margin-left: 182px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.adapters</div></div></div></foreignObject><text x="182" y="144" fill="#000000" font-family="Courier New" font-size="12px">.adapters</text></switch></g><path d="M 220 250 L 220 323.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 220 328.88 L 216.5 321.88 L 220 323.63 L 223.5 321.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="180" y="230" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 
78px; height: 1px; padding-top: 240px; margin-left: 182px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.adapters</div></div></div></foreignObject><text x="182" y="244" fill="#000000" font-family="Courier New" font-size="12px">.adapters</text></switch></g><path d="M 130 340 L 173.63 340" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 178.88 340 L 171.88 343.5 L 173.63 340 L 171.88 336.5 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="70" y="165" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 175px; margin-left: 72px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.cursor()</div></div></div></foreignObject><text x="72" y="179" fill="#000000" font-family="Courier New" font-size="12px">.cursor()</text></switch></g><path d="M 130 45 L 173.63 45" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 178.88 45 L 171.88 48.5 L 173.63 45 L 171.88 41.5 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><path d="M 70 70 L 70 108.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 70 113.88 L 66.5 106.88 L 70 108.63 L 73.5 106.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="10" y="20" width="120" height="50" rx="7.5" ry="7.5" fill="#ffffff" stroke="#000000" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 45px; margin-left: 11px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; "><b>psycopg</b><br /><font face="Helvetica">module</font></div></div></div></foreignObject><text x="70" y="49" fill="#000000" font-family="Courier New" font-size="12px" text-anchor="middle">psycopg...</text></switch></g><a xlink:href="../api/connections.html"><rect x="70" y="70" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; 
justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 80px; margin-left: 72px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.connect()</div></div></div></foreignObject><text x="72" y="84" fill="#000000" font-family="Courier New" font-size="12px">.connect()</text></switch></g></a><rect x="70" y="265" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 275px; margin-left: 72px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.execute()</div></div></div></foreignObject><text x="72" y="279" fill="#000000" font-family="Courier New" font-size="12px">.execute()</text></switch></g><path d="M 220 55 L 220 123.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 220 128.88 L 216.5 121.88 L 220 123.63 L 223.5 121.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="180" y="35" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 45px; margin-left: 182px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.adapters</div></div></div></foreignObject><text x="182" y="49" fill="#000000" font-family="Courier New" font-size="12px">.adapters</text></switch></g><rect x="180" y="330" width="80" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 78px; height: 1px; padding-top: 340px; margin-left: 182px;"><div style="box-sizing: border-box; font-size: 0; text-align: left; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">.adapters</div></div></div></foreignObject><text x="182" y="344" fill="#000000" font-family="Courier New" font-size="12px">.adapters</text></switch></g><path d="M 70 165 L 70 208.63" fill="none" stroke="#000000" stroke-miterlimit="10" 
pointer-events="stroke"/><path d="M 70 213.88 L 66.5 206.88 L 70 208.63 L 73.5 206.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><a xlink:href="../api/connections.html"><rect x="10" y="115" width="120" height="50" rx="7.5" ry="7.5" fill="#ffffff" stroke="#000000" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 140px; margin-left: 11px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; "><b>Connection</b><br /><font face="Helvetica">object</font></div></div></div></foreignObject><text x="70" y="144" fill="#000000" font-family="Courier New" font-size="12px" text-anchor="middle">Connection...</text></switch></g></a><path d="M 70 265 L 70 308.63" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke"/><path d="M 70 313.88 L 66.5 306.88 L 70 308.63 L 73.5 306.88 Z" fill="#000000" stroke="#000000" stroke-miterlimit="10" pointer-events="all"/><rect x="10" y="215" width="120" height="50" rx="7.5" ry="7.5" fill="#ffffff" stroke="#000000" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 240px; margin-left: 11px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; "><b>Cursor</b><br /><font face="Helvetica">object</font></div></div></div></foreignObject><text x="70" y="244" fill="#000000" font-family="Courier New" font-size="12px" text-anchor="middle">Cursor...</text></switch></g><rect x="10" y="315" width="120" height="50" rx="7.5" ry="7.5" fill="#ffffff" stroke="#000000" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 340px; margin-left: 11px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Courier New; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; "><b>Transformer</b><br /><font face="Helvetica">object</font></div></div></div></foreignObject><text x="70" y="344" fill="#000000" font-family="Courier New" font-size="12px" text-anchor="middle">Transformer...</text></switch></g><path d="M 167.14 20 L 160 45" fill="none" stroke="#000000" stroke-miterlimit="10" 
stroke-dasharray="1 1" pointer-events="stroke"/><rect x="150" y="0" width="40" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 10px; margin-left: 151px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">Has a</div></div></div></foreignObject><text x="170" y="14" fill="#000000" font-family="Helvetica" font-size="12px" text-anchor="middle">Has a</text></switch></g><path d="M 40 89 L 70 95" fill="none" stroke="#000000" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke"/><rect x="0" y="75" width="40" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 85px; margin-left: 1px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">Create</div></div></div></foreignObject><text x="20" y="89" fill="#000000" font-family="Helvetica" font-size="12px" text-anchor="middle">Create</text></switch></g><path d="M 244 89.55 L 220 95" fill="none" stroke="#000000" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke"/><rect x="244" y="75" width="40" height="20" fill="none" stroke="none" pointer-events="all"/><g transform="translate(-0.5 -0.5)"><switch><foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"><div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 85px; margin-left: 245px;"><div style="box-sizing: border-box; font-size: 0; text-align: center; "><div style="display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; ">Copy</div></div></div></foreignObject><text x="264" y="89" fill="#000000" font-family="Helvetica" font-size="12px" text-anchor="middle">Copy</text></switch></g></g><switch><g requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"/><a transform="translate(0,-5)" xlink:href="https://www.diagrams.net/doc/faq/svg-export-text-problems" target="_blank"><text text-anchor="middle" font-size="10px" x="50%" y="100%">Viewer does not support full SVG 1.1</text></a></switch></svg>
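The ``adapt.drawio``/``adapt.svg`` diagram added above shows the chain of adaptation contexts: the ``psycopg`` module, the ``Connection``, the ``Cursor`` and the ``Transformer`` each expose an ``.adapters`` map, and each child context starts from a copy of its parent's map. A minimal sketch of how this plays out in practice, assuming a reachable database (the DSN and the ``IntAsTextDumper`` class are hypothetical placeholders, not part of this diff)::

    import psycopg
    from psycopg.adapt import Dumper

    class IntAsTextDumper(Dumper):
        """Hypothetical dumper sending Python ints to PostgreSQL as text."""
        def dump(self, obj):
            return str(obj).encode()

    # Registering on the global map only affects connections created afterwards.
    psycopg.adapters.register_dumper(int, IntAsTextDumper)

    with psycopg.connect("dbname=test") as conn:  # hypothetical DSN
        # The connection received a copy of the global map; registering here
        # customises this connection (and the cursors created from it)
        # without touching psycopg.adapters.
        conn.adapters.register_dumper(int, IntAsTextDumper)
        with conn.cursor() as cur:
            cur.execute("SELECT %s", [42])
            print(cur.fetchone())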
\ No newline at end of file diff --git a/docs/release.rst b/docs/release.rst new file mode 100644 index 0000000..8fcadaf --- /dev/null +++ b/docs/release.rst @@ -0,0 +1,39 @@ +:orphan: + +How to make a psycopg release +============================= + +- Change version number in: + + - ``psycopg_c/psycopg_c/version.py`` + - ``psycopg/psycopg/version.py`` + - ``psycopg_pool/psycopg_pool/version.py`` + +- Change docs/news.rst to drop the "unreleased" mark from the version + +- Push to GitHub to run `the tests workflow`__. + + .. __: https://github.com/psycopg/psycopg/actions/workflows/tests.yml + +- Build the packages by triggering manually the `Build packages workflow`__. + + .. __: https://github.com/psycopg/psycopg/actions/workflows/packages.yml + +- If all went fine, create a tag named after the version:: + + git tag -a -s 3.0.dev1 + git push --tags + +- Download the ``artifacts.zip`` package from the last Packages workflow run. + +- Unpack the packages locally:: + + mkdir tmp + cd tmp + unzip ~/Downloads/artifact.zip + +- If the package is a testing one, upload it on TestPyPI with:: + + $ twine upload -s -r testpypi * + +- If the package is stable, omit ``-r testpypi``. diff --git a/psycopg/.flake8 b/psycopg/.flake8 new file mode 100644 index 0000000..67fb024 --- /dev/null +++ b/psycopg/.flake8 @@ -0,0 +1,6 @@ +[flake8] +max-line-length = 88 +ignore = W503, E203 +per-file-ignores = + # Autogenerated section + psycopg/errors.py: E125, E128, E302 diff --git a/psycopg/LICENSE.txt b/psycopg/LICENSE.txt new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/psycopg/LICENSE.txt @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. 
+ + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. 
(If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/psycopg/README.rst b/psycopg/README.rst new file mode 100644 index 0000000..45eeac3 --- /dev/null +++ b/psycopg/README.rst @@ -0,0 +1,31 @@ +Psycopg 3: PostgreSQL database adapter for Python +================================================= + +Psycopg 3 is a modern implementation of a PostgreSQL adapter for Python. + +This distribution contains the pure Python package ``psycopg``. + + +Installation +------------ + +In short, run the following:: + + pip install --upgrade pip # to upgrade pip + pip install "psycopg[binary,pool]" # to install package and dependencies + +If something goes wrong, and for more information about installation, please +check out the `Installation documentation`__. + +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html# + + +Hacking +------- + +For development information check out `the project readme`__. + +.. __: https://github.com/psycopg/psycopg#readme + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py new file mode 100644 index 0000000..baadf30 --- /dev/null +++ b/psycopg/psycopg/__init__.py @@ -0,0 +1,110 @@ +""" +psycopg -- PostgreSQL database adapter for Python +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging + +from . 
import pq # noqa: F401 import early to stabilize side effects +from . import types +from . import postgres +from ._tpc import Xid +from .copy import Copy, AsyncCopy +from ._enums import IsolationLevel +from .cursor import Cursor +from .errors import Warning, Error, InterfaceError, DatabaseError +from .errors import DataError, OperationalError, IntegrityError +from .errors import InternalError, ProgrammingError, NotSupportedError +from ._column import Column +from .conninfo import ConnectionInfo +from ._pipeline import Pipeline, AsyncPipeline +from .connection import BaseConnection, Connection, Notify +from .transaction import Rollback, Transaction, AsyncTransaction +from .cursor_async import AsyncCursor +from .server_cursor import AsyncServerCursor, ServerCursor +from .client_cursor import AsyncClientCursor, ClientCursor +from .connection_async import AsyncConnection + +from . import dbapi20 +from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING +from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks +from .dbapi20 import Timestamp, TimestampFromTicks + +from .version import __version__ as __version__ # noqa: F401 + +# Set the logger to a quiet default, can be enabled if needed +logger = logging.getLogger("psycopg") +if logger.level == logging.NOTSET: + logger.setLevel(logging.WARNING) + +# DBAPI compliance +connect = Connection.connect +apilevel = "2.0" +threadsafety = 2 +paramstyle = "pyformat" + +# register default adapters for PostgreSQL +adapters = postgres.adapters # exposed by the package +postgres.register_default_adapters(adapters) + +# After the default ones, because these can deal with the bytea oid better +dbapi20.register_dbapi20_adapters(adapters) + +# Must come after all the types have been registered +types.array.register_all_arrays(adapters) + +# Note: defining the exported methods helps both Sphynx in documenting that +# this is the canonical place to obtain them and should be used by MyPy too, +# so that function signatures are consistent with the documentation. +__all__ = [ + "AsyncClientCursor", + "AsyncConnection", + "AsyncCopy", + "AsyncCursor", + "AsyncPipeline", + "AsyncServerCursor", + "AsyncTransaction", + "BaseConnection", + "ClientCursor", + "Column", + "Connection", + "ConnectionInfo", + "Copy", + "Cursor", + "IsolationLevel", + "Notify", + "Pipeline", + "Rollback", + "ServerCursor", + "Transaction", + "Xid", + # DBAPI exports + "connect", + "apilevel", + "threadsafety", + "paramstyle", + "Warning", + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", + # DBAPI type constructors and singletons + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "BINARY", + "DATETIME", + "NUMBER", + "ROWID", + "STRING", +] diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py new file mode 100644 index 0000000..a3a6ef8 --- /dev/null +++ b/psycopg/psycopg/_adapters_map.py @@ -0,0 +1,289 @@ +""" +Mapping from types/oids to Dumpers/Loaders +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from typing import cast, TYPE_CHECKING + +from . import pq +from . 
import errors as e +from .abc import Dumper, Loader +from ._enums import PyFormat as PyFormat +from ._cmodule import _psycopg +from ._typeinfo import TypesRegistry + +if TYPE_CHECKING: + from .connection import BaseConnection + +RV = TypeVar("RV") + + +class AdaptersMap: + r""" + Establish how types should be converted between Python and PostgreSQL in + an `~psycopg.abc.AdaptContext`. + + `!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to + define how Python types are converted to PostgreSQL, and maps OIDs to + `~psycopg.adapt.Loader` classes to establish how query results are + converted to Python. + + Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how + types are converted in that context, exposed as the + `~psycopg.abc.AdaptContext.adapters` attribute: changing such map allows + to customise adaptation in a context without changing separated contexts. + + When a context is created from another context (for instance when a + `~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's + `!adapters` are used as template for the child's `!adapters`, so that every + cursor created from the same connection use the connection's types + configuration, but separate connections have independent mappings. + + Once created, `!AdaptersMap` are independent. This means that objects + already created are not affected if a wider scope (e.g. the global one) is + changed. + + The connections adapters are initialised using a global `!AdptersMap` + template, exposed as `psycopg.adapters`: changing such mapping allows to + customise the type mapping for every connections created afterwards. + + The object can start empty or copy from another object of the same class. + Copies are copy-on-write: if the maps are updated make a copy. This way + extending e.g. global map by a connection or a connection map from a cursor + is cheap: a copy is only made on customisation. + """ + + __module__ = "psycopg.adapt" + + types: TypesRegistry + + _dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]] + _dumpers_by_oid: List[Dict[int, Type[Dumper]]] + _loaders: List[Dict[int, Type[Loader]]] + + # Record if a dumper or loader has an optimised version. + _optimised: Dict[type, type] = {} + + def __init__( + self, + template: Optional["AdaptersMap"] = None, + types: Optional[TypesRegistry] = None, + ): + if template: + self._dumpers = template._dumpers.copy() + self._own_dumpers = _dumpers_shared.copy() + template._own_dumpers = _dumpers_shared.copy() + + self._dumpers_by_oid = template._dumpers_by_oid[:] + self._own_dumpers_by_oid = [False, False] + template._own_dumpers_by_oid = [False, False] + + self._loaders = template._loaders[:] + self._own_loaders = [False, False] + template._own_loaders = [False, False] + + self.types = TypesRegistry(template.types) + + else: + self._dumpers = {fmt: {} for fmt in PyFormat} + self._own_dumpers = _dumpers_owned.copy() + + self._dumpers_by_oid = [{}, {}] + self._own_dumpers_by_oid = [True, True] + + self._loaders = [{}, {}] + self._own_loaders = [True, True] + + self.types = types or TypesRegistry() + + # implement the AdaptContext protocol too + @property + def adapters(self) -> "AdaptersMap": + return self + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + return None + + def register_dumper( + self, cls: Union[type, str, None], dumper: Type[Dumper] + ) -> None: + """ + Configure the context to use `!dumper` to convert objects of type `!cls`. 
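+
+        For example, assuming ``MyStrDumper`` is a `~psycopg.adapt.Dumper`
+        subclass defined elsewhere (a hypothetical name used for
+        illustration)::
+
+            conn.adapters.register_dumper(str, MyStrDumper)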
+ + If two dumpers with different `~Dumper.format` are registered for the + same type, the last one registered will be chosen when the query + doesn't specify a format (i.e. when the value is used with a ``%s`` + "`~PyFormat.AUTO`" placeholder). + + :param cls: The type to manage. + :param dumper: The dumper to register for `!cls`. + + If `!cls` is specified as string it will be lazy-loaded, so that it + will be possible to register it without importing it before. In this + case it should be the fully qualified name of the object (e.g. + ``"uuid.UUID"``). + + If `!cls` is None, only use the dumper when looking up using + `get_dumper_by_oid()`, which happens when we know the Postgres type to + adapt to, but not the Python type that will be adapted (e.g. in COPY + after using `~psycopg.Copy.set_types()`). + + """ + if not (cls is None or isinstance(cls, (str, type))): + raise TypeError( + f"dumpers should be registered on classes, got {cls} instead" + ) + + if _psycopg: + dumper = self._get_optimised(dumper) + + # Register the dumper both as its format and as auto + # so that the last dumper registered is used in auto (%s) format + if cls: + for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO): + if not self._own_dumpers[fmt]: + self._dumpers[fmt] = self._dumpers[fmt].copy() + self._own_dumpers[fmt] = True + + self._dumpers[fmt][cls] = dumper + + # Register the dumper by oid, if the oid of the dumper is fixed + if dumper.oid: + if not self._own_dumpers_by_oid[dumper.format]: + self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[ + dumper.format + ].copy() + self._own_dumpers_by_oid[dumper.format] = True + + self._dumpers_by_oid[dumper.format][dumper.oid] = dumper + + def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None: + """ + Configure the context to use `!loader` to convert data of oid `!oid`. + + :param oid: The PostgreSQL OID or type name to manage. + :param loader: The loar to register for `!oid`. + + If `oid` is specified as string, it refers to a type name, which is + looked up in the `types` registry. ` + + """ + if isinstance(oid, str): + oid = self.types[oid].oid + if not isinstance(oid, int): + raise TypeError(f"loaders should be registered on oid, got {oid} instead") + + if _psycopg: + loader = self._get_optimised(loader) + + fmt = loader.format + if not self._own_loaders[fmt]: + self._loaders[fmt] = self._loaders[fmt].copy() + self._own_loaders[fmt] = True + + self._loaders[fmt][oid] = loader + + def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]: + """ + Return the dumper class for the given type and format. + + Raise `~psycopg.ProgrammingError` if a class is not available. + + :param cls: The class to adapt. + :param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`, + use the last one of the dumpers registered on `!cls`. + """ + try: + dmap = self._dumpers[format] + except KeyError: + raise ValueError(f"bad dumper format: {format}") + + # Look for the right class, including looking at superclasses + for scls in cls.__mro__: + if scls in dmap: + return dmap[scls] + + # If the adapter is not found, look for its name as a string + fqn = scls.__module__ + "." 
+ scls.__qualname__ + if fqn in dmap: + # Replace the class name with the class itself + d = dmap[scls] = dmap.pop(fqn) + return d + + raise e.ProgrammingError( + f"cannot adapt type {cls.__name__!r} using placeholder '%{format}'" + f" (format: {PyFormat(format).name})" + ) + + def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]: + """ + Return the dumper class for the given oid and format. + + Raise `~psycopg.ProgrammingError` if a class is not available. + + :param oid: The oid of the type to dump to. + :param format: The format to dump to. + """ + try: + dmap = self._dumpers_by_oid[format] + except KeyError: + raise ValueError(f"bad dumper format: {format}") + + try: + return dmap[oid] + except KeyError: + info = self.types.get(oid) + if info: + msg = ( + f"cannot find a dumper for type {info.name} (oid {oid})" + f" format {pq.Format(format).name}" + ) + else: + msg = ( + f"cannot find a dumper for unknown type with oid {oid}" + f" format {pq.Format(format).name}" + ) + raise e.ProgrammingError(msg) + + def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]: + """ + Return the loader class for the given oid and format. + + Return `!None` if not found. + + :param oid: The oid of the type to load. + :param format: The format to load from. + """ + return self._loaders[format].get(oid) + + @classmethod + def _get_optimised(self, cls: Type[RV]) -> Type[RV]: + """Return the optimised version of a Dumper or Loader class. + + Return the input class itself if there is no optimised version. + """ + try: + return self._optimised[cls] + except KeyError: + pass + + # Check if the class comes from psycopg.types and there is a class + # with the same name in psycopg_c._psycopg. + from psycopg import types + + if cls.__module__.startswith(types.__name__): + new = cast(Type[RV], getattr(_psycopg, cls.__name__, None)) + if new: + self._optimised[cls] = new + return new + + self._optimised[cls] = cls + return cls + + +# Micro-optimization: copying these objects is faster than creating new dicts +_dumpers_owned = dict.fromkeys(PyFormat, True) +_dumpers_shared = dict.fromkeys(PyFormat, False) diff --git a/psycopg/psycopg/_cmodule.py b/psycopg/psycopg/_cmodule.py new file mode 100644 index 0000000..288ef1b --- /dev/null +++ b/psycopg/psycopg/_cmodule.py @@ -0,0 +1,24 @@ +""" +Simplify access to the _psycopg module +""" + +# Copyright (C) 2021 The Psycopg Team + +from typing import Optional + +from . import pq + +__version__: Optional[str] = None + +# Note: "c" must the first attempt so that mypy associates the variable the +# right module interface. It will not result Optional, but hey. 
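+# pq.__impl__ reports which libpq wrapper implementation is in use: "c" and
+# "binary" ship the optimised _psycopg C module (from the psycopg_c and
+# psycopg_binary distributions respectively), while "python" means no C
+# extension is available and the pure Python code paths are used.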
+if pq.__impl__ == "c": + from psycopg_c import _psycopg as _psycopg + from psycopg_c import __version__ as __version__ # noqa: F401 +elif pq.__impl__ == "binary": + from psycopg_binary import _psycopg as _psycopg # type: ignore + from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401 +elif pq.__impl__ == "python": + _psycopg = None # type: ignore +else: + raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}") diff --git a/psycopg/psycopg/_column.py b/psycopg/psycopg/_column.py new file mode 100644 index 0000000..9e4e735 --- /dev/null +++ b/psycopg/psycopg/_column.py @@ -0,0 +1,143 @@ +""" +The Column object in Cursor.description +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING +from operator import attrgetter + +if TYPE_CHECKING: + from .cursor import BaseCursor + + +class ColumnData(NamedTuple): + ftype: int + fmod: int + fsize: int + + +class Column(Sequence[Any]): + + __module__ = "psycopg" + + def __init__(self, cursor: "BaseCursor[Any, Any]", index: int): + res = cursor.pgresult + assert res + + fname = res.fname(index) + if fname: + self._name = fname.decode(cursor._encoding) + else: + # COPY_OUT results have columns but no name + self._name = f"column_{index + 1}" + + self._data = ColumnData( + ftype=res.ftype(index), + fmod=res.fmod(index), + fsize=res.fsize(index), + ) + self._type = cursor.adapters.types.get(self._data.ftype) + + _attrs = tuple( + attrgetter(attr) + for attr in """ + name type_code display_size internal_size precision scale null_ok + """.split() + ) + + def __repr__(self) -> str: + return ( + f"<Column {self.name!r}," + f" type: {self._type_display()} (oid: {self.type_code})>" + ) + + def __len__(self) -> int: + return 7 + + def _type_display(self) -> str: + parts = [] + parts.append(self._type.name if self._type else str(self.type_code)) + + mod1 = self.precision + if mod1 is None: + mod1 = self.display_size + if mod1: + parts.append(f"({mod1}") + if self.scale: + parts.append(f", {self.scale}") + parts.append(")") + + if self._type and self.type_code == self._type.array_oid: + parts.append("[]") + + return "".join(parts) + + def __getitem__(self, index: Any) -> Any: + if isinstance(index, slice): + return tuple(getter(self) for getter in self._attrs[index]) + else: + return self._attrs[index](self) + + @property + def name(self) -> str: + """The name of the column.""" + return self._name + + @property + def type_code(self) -> int: + """The numeric OID of the column.""" + return self._data.ftype + + @property + def display_size(self) -> Optional[int]: + """The field size, for :sql:`varchar(n)`, None otherwise.""" + if not self._type: + return None + + if self._type.name in ("varchar", "char"): + fmod = self._data.fmod + if fmod >= 0: + return fmod - 4 + + return None + + @property + def internal_size(self) -> Optional[int]: + """The internal field size for fixed-size types, None otherwise.""" + fsize = self._data.fsize + return fsize if fsize >= 0 else None + + @property + def precision(self) -> Optional[int]: + """The number of digits for fixed precision types.""" + if not self._type: + return None + + dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval") + if self._type.name == "numeric": + fmod = self._data.fmod + if fmod >= 0: + return fmod >> 16 + + elif self._type.name in dttypes: + fmod = self._data.fmod + if fmod >= 0: + return fmod & 0xFFFF + + return None + + @property + def scale(self) -> Optional[int]: + """The 
number of digits after the decimal point if available.""" + if self._type and self._type.name == "numeric": + fmod = self._data.fmod - 4 + if fmod >= 0: + return fmod & 0xFFFF + + return None + + @property + def null_ok(self) -> Optional[bool]: + """Always `!None`""" + return None diff --git a/psycopg/psycopg/_compat.py b/psycopg/psycopg/_compat.py new file mode 100644 index 0000000..7dbae79 --- /dev/null +++ b/psycopg/psycopg/_compat.py @@ -0,0 +1,72 @@ +""" +compatibility functions for different Python versions +""" + +# Copyright (C) 2021 The Psycopg Team + +import sys +import asyncio +from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar + +# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it. +# For this raisin it must be imported directly from typing_extension where used. +# See https://github.com/microsoft/pyright/issues/4197 +from typing_extensions import TypeAlias + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +T = TypeVar("T") +FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]] + +if sys.version_info >= (3, 8): + create_task = asyncio.create_task + from math import prod + +else: + + def create_task( + coro: FutureT[T], name: Optional[str] = None + ) -> "asyncio.Future[T]": + return asyncio.create_task(coro) + + from functools import reduce + + def prod(seq: Sequence[int]) -> int: + return reduce(int.__mul__, seq, 1) + + +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo + from functools import cache + from collections import Counter, deque as Deque +else: + from typing import Counter, Deque + from functools import lru_cache + from backports.zoneinfo import ZoneInfo + + cache = lru_cache(maxsize=None) + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +if sys.version_info >= (3, 11): + from typing import LiteralString +else: + from typing_extensions import LiteralString + +__all__ = [ + "Counter", + "Deque", + "LiteralString", + "Protocol", + "TypeGuard", + "ZoneInfo", + "cache", + "create_task", + "prod", +] diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py new file mode 100644 index 0000000..1e146ba --- /dev/null +++ b/psycopg/psycopg/_dns.py @@ -0,0 +1,223 @@ +# type: ignore # dnspython is currently optional and mypy fails if missing +""" +DNS query support +""" + +# Copyright (C) 2021 The Psycopg Team + +import os +import re +import warnings +from random import randint +from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence +from typing import TYPE_CHECKING +from collections import defaultdict + +try: + from dns.resolver import Resolver, Cache + from dns.asyncresolver import Resolver as AsyncResolver + from dns.exception import DNSException +except ImportError: + raise ImportError( + "the module psycopg._dns requires the package 'dnspython' installed" + ) + +from . import errors as e +from .conninfo import resolve_hostaddr_async as resolve_hostaddr_async_ + +if TYPE_CHECKING: + from dns.rdtypes.IN.SRV import SRV + +resolver = Resolver() +resolver.cache = Cache() + +async_resolver = AsyncResolver() +async_resolver.cache = Cache() + + +async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform async DNS lookup of the hosts and return a new params dict. + + .. 
deprecated:: 3.1 + The use of this function is not necessary anymore, because + `psycopg.AsyncConnection.connect()` performs non-blocking name + resolution automatically. + """ + warnings.warn( + "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore", + DeprecationWarning, + ) + return await resolve_hostaddr_async_(params) + + +def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]: + """Apply SRV DNS lookup as defined in :RFC:`2782`.""" + return Rfc2782Resolver().resolve(params) + + +async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]: + """Async equivalent of `resolve_srv()`.""" + return await Rfc2782Resolver().resolve_async(params) + + +class HostPort(NamedTuple): + host: str + port: str + totry: bool = False + target: Optional[str] = None + + +class Rfc2782Resolver: + """Implement SRV RR Resolution as per RFC 2782 + + The class is organised to minimise code duplication between the sync and + the async paths. + """ + + re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)") + + def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Update the parameters host and port after SRV lookup.""" + attempts = self._get_attempts(params) + if not attempts: + return params + + hps = [] + for hp in attempts: + if hp.totry: + hps.extend(self._resolve_srv(hp)) + else: + hps.append(hp) + + return self._return_params(params, hps) + + async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Update the parameters host and port after SRV lookup.""" + attempts = self._get_attempts(params) + if not attempts: + return params + + hps = [] + for hp in attempts: + if hp.totry: + hps.extend(await self._resolve_srv_async(hp)) + else: + hps.append(hp) + + return self._return_params(params, hps) + + def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]: + """ + Return the list of host, and for each host if SRV lookup must be tried. + + Return an empty list if no lookup is requested. + """ + # If hostaddr is defined don't do any resolution. + if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")): + return [] + + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + hosts_in = host_arg.split(",") + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) + ports_in = port_arg.split(",") + + if len(ports_in) == 1: + # If only one port is specified, it applies to all the hosts. + ports_in *= len(hosts_in) + if len(ports_in) != len(hosts_in): + # ProgrammingError would have been more appropriate, but this is + # what the raise if the libpq fails connect in the same case. 
+ raise e.OperationalError( + f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers" + ) + + out = [] + srv_found = False + for host, port in zip(hosts_in, ports_in): + m = self.re_srv_rr.match(host) + if m or port.lower() == "srv": + srv_found = True + target = m.group("target") if m else None + hp = HostPort(host=host, port=port, totry=True, target=target) + else: + hp = HostPort(host=host, port=port) + out.append(hp) + + return out if srv_found else [] + + def _resolve_srv(self, hp: HostPort) -> List[HostPort]: + try: + ans = resolver.resolve(hp.host, "SRV") + except DNSException: + ans = () + return self._get_solved_entries(hp, ans) + + async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]: + try: + ans = await async_resolver.resolve(hp.host, "SRV") + except DNSException: + ans = () + return self._get_solved_entries(hp, ans) + + def _get_solved_entries( + self, hp: HostPort, entries: "Sequence[SRV]" + ) -> List[HostPort]: + if not entries: + # No SRV entry found. Delegate the libpq a QNAME=target lookup + if hp.target and hp.port.lower() != "srv": + return [HostPort(host=hp.target, port=hp.port)] + else: + return [] + + # If there is precisely one SRV RR, and its Target is "." (the root + # domain), abort. + if len(entries) == 1 and str(entries[0].target) == ".": + return [] + + return [ + HostPort(host=str(entry.target).rstrip("."), port=str(entry.port)) + for entry in self.sort_rfc2782(entries) + ] + + def _return_params( + self, params: Dict[str, Any], hps: List[HostPort] + ) -> Dict[str, Any]: + if not hps: + # Nothing found, we ended up with an empty list + raise e.OperationalError("no host found after SRV RR lookup") + + out = params.copy() + out["host"] = ",".join(hp.host for hp in hps) + out["port"] = ",".join(str(hp.port) for hp in hps) + return out + + def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]": + """ + Implement the priority/weight ordering defined in RFC 2782. + """ + # Divide the entries by priority: + priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list) + out: "List[SRV]" = [] + for entry in ans: + priorities[entry.priority].append(entry) + + for pri, entries in sorted(priorities.items()): + if len(entries) == 1: + out.append(entries[0]) + continue + + entries.sort(key=lambda ent: ent.weight) + total_weight = sum(ent.weight for ent in entries) + while entries: + r = randint(0, total_weight) + csum = 0 + for i, ent in enumerate(entries): + csum += ent.weight + if csum >= r: + break + out.append(ent) + total_weight -= ent.weight + del entries[i] + + return out diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py new file mode 100644 index 0000000..c584b26 --- /dev/null +++ b/psycopg/psycopg/_encodings.py @@ -0,0 +1,170 @@ +""" +Mappings between PostgreSQL and Python encodings. 
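+
+Only encodings available on both sides are mapped; a few PostgreSQL
+encodings (for instance EUC_TW and MULE_INTERNAL) have no Python
+equivalent and are not listed.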
+""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import string +import codecs +from typing import Any, Dict, Optional, TYPE_CHECKING + +from .pq._enums import ConnStatus +from .errors import NotSupportedError +from ._compat import cache + +if TYPE_CHECKING: + from .pq.abc import PGconn + from .connection import BaseConnection + +OK = ConnStatus.OK + + +_py_codecs = { + "BIG5": "big5", + "EUC_CN": "gb2312", + "EUC_JIS_2004": "euc_jis_2004", + "EUC_JP": "euc_jp", + "EUC_KR": "euc_kr", + # "EUC_TW": not available in Python + "GB18030": "gb18030", + "GBK": "gbk", + "ISO_8859_5": "iso8859-5", + "ISO_8859_6": "iso8859-6", + "ISO_8859_7": "iso8859-7", + "ISO_8859_8": "iso8859-8", + "JOHAB": "johab", + "KOI8R": "koi8-r", + "KOI8U": "koi8-u", + "LATIN1": "iso8859-1", + "LATIN10": "iso8859-16", + "LATIN2": "iso8859-2", + "LATIN3": "iso8859-3", + "LATIN4": "iso8859-4", + "LATIN5": "iso8859-9", + "LATIN6": "iso8859-10", + "LATIN7": "iso8859-13", + "LATIN8": "iso8859-14", + "LATIN9": "iso8859-15", + # "MULE_INTERNAL": not available in Python + "SHIFT_JIS_2004": "shift_jis_2004", + "SJIS": "shift_jis", + # this actually means no encoding, see PostgreSQL docs + # it is special-cased by the text loader. + "SQL_ASCII": "ascii", + "UHC": "cp949", + "UTF8": "utf-8", + "WIN1250": "cp1250", + "WIN1251": "cp1251", + "WIN1252": "cp1252", + "WIN1253": "cp1253", + "WIN1254": "cp1254", + "WIN1255": "cp1255", + "WIN1256": "cp1256", + "WIN1257": "cp1257", + "WIN1258": "cp1258", + "WIN866": "cp866", + "WIN874": "cp874", +} + +py_codecs: Dict[bytes, str] = {} +py_codecs.update((k.encode(), v) for k, v in _py_codecs.items()) + +# Add an alias without underscore, for lenient lookups +py_codecs.update( + (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k +) + +pg_codecs = {v: k.encode() for k, v in _py_codecs.items()} + + +def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str: + """ + Return the Python encoding name of a psycopg connection. + + Default to utf8 if the connection has no encoding info. + """ + if not conn or conn.closed: + return "utf-8" + + pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8" + return pg2pyenc(pgenc) + + +def pgconn_encoding(pgconn: "PGconn") -> str: + """ + Return the Python encoding name of a libpq connection. + + Default to utf8 if the connection has no encoding info. + """ + if pgconn.status != OK: + return "utf-8" + + pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8" + return pg2pyenc(pgenc) + + +def conninfo_encoding(conninfo: str) -> str: + """ + Return the Python encoding name passed in a conninfo string. Default to utf8. + + Because the input is likely to come from the user and not normalised by the + server, be somewhat lenient (non-case-sensitive lookup, ignore noise chars). + """ + from .conninfo import conninfo_to_dict + + params = conninfo_to_dict(conninfo) + pgenc = params.get("client_encoding") + if pgenc: + try: + return pg2pyenc(pgenc.encode()) + except NotSupportedError: + pass + + return "utf-8" + + +@cache +def py2pgenc(name: str) -> bytes: + """Convert a Python encoding name to PostgreSQL encoding name. + + Raise LookupError if the Python encoding is unknown. + """ + return pg_codecs[codecs.lookup(name).name] + + +@cache +def pg2pyenc(name: bytes) -> str: + """Convert a Python encoding name to PostgreSQL encoding name. + + Raise NotSupportedError if the PostgreSQL encoding is not supported by + Python. 
+ """ + try: + return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()] + except KeyError: + sname = name.decode("utf8", "replace") + raise NotSupportedError(f"codec not available in Python: {sname!r}") + + +def _as_python_identifier(s: str, prefix: str = "f") -> str: + """ + Reduce a string to a valid Python identifier. + + Replace all non-valid chars with '_' and prefix the value with `!prefix` if + the first letter is an '_'. + """ + if not s.isidentifier(): + if s[0] in "1234567890": + s = prefix + s + if not s.isidentifier(): + s = _re_clean.sub("_", s) + # namedtuple fields cannot start with underscore. So... + if s[0] == "_": + s = prefix + s + return s + + +_re_clean = re.compile( + f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]" +) diff --git a/psycopg/psycopg/_enums.py b/psycopg/psycopg/_enums.py new file mode 100644 index 0000000..a7cb78d --- /dev/null +++ b/psycopg/psycopg/_enums.py @@ -0,0 +1,79 @@ +""" +Enum values for psycopg + +These values are defined by us and are not necessarily dependent on +libpq-defined enums. +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import Enum, IntEnum +from selectors import EVENT_READ, EVENT_WRITE + +from . import pq + + +class Wait(IntEnum): + R = EVENT_READ + W = EVENT_WRITE + RW = EVENT_READ | EVENT_WRITE + + +class Ready(IntEnum): + R = EVENT_READ + W = EVENT_WRITE + RW = EVENT_READ | EVENT_WRITE + + +class PyFormat(str, Enum): + """ + Enum representing the format wanted for a query argument. + + The value `AUTO` allows psycopg to choose the best format for a certain + parameter. + """ + + __module__ = "psycopg.adapt" + + AUTO = "s" + """Automatically chosen (``%s`` placeholder).""" + TEXT = "t" + """Text parameter (``%t`` placeholder).""" + BINARY = "b" + """Binary parameter (``%b`` placeholder).""" + + @classmethod + def from_pq(cls, fmt: pq.Format) -> "PyFormat": + return _pg2py[fmt] + + @classmethod + def as_pq(cls, fmt: "PyFormat") -> pq.Format: + return _py2pg[fmt] + + +class IsolationLevel(IntEnum): + """ + Enum representing the isolation level for a transaction. + """ + + __module__ = "psycopg" + + READ_UNCOMMITTED = 1 + """:sql:`READ UNCOMMITTED` isolation level.""" + READ_COMMITTED = 2 + """:sql:`READ COMMITTED` isolation level.""" + REPEATABLE_READ = 3 + """:sql:`REPEATABLE READ` isolation level.""" + SERIALIZABLE = 4 + """:sql:`SERIALIZABLE` isolation level.""" + + +_py2pg = { + PyFormat.TEXT: pq.Format.TEXT, + PyFormat.BINARY: pq.Format.BINARY, +} + +_pg2py = { + pq.Format.TEXT: PyFormat.TEXT, + pq.Format.BINARY: PyFormat.BINARY, +} diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py new file mode 100644 index 0000000..c818d86 --- /dev/null +++ b/psycopg/psycopg/_pipeline.py @@ -0,0 +1,288 @@ +""" +commands pipeline management +""" + +# Copyright (C) 2021 The Psycopg Team + +import logging +from types import TracebackType +from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import pq +from . 
import errors as e +from .abc import PipelineCommand, PQGen +from ._compat import Deque +from ._encodings import pgconn_encoding +from ._preparing import Key, Prepare +from .generators import pipeline_communicate, fetch_many, send + +if TYPE_CHECKING: + from .pq.abc import PGresult + from .cursor import BaseCursor + from .connection import BaseConnection, Connection + from .connection_async import AsyncConnection + + +PendingResult: TypeAlias = Union[ + None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]] +] + +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR +PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED +BAD = pq.ConnStatus.BAD + +ACTIVE = pq.TransactionStatus.ACTIVE + +logger = logging.getLogger("psycopg") + + +class BasePipeline: + + command_queue: Deque[PipelineCommand] + result_queue: Deque[PendingResult] + _is_supported: Optional[bool] = None + + def __init__(self, conn: "BaseConnection[Any]") -> None: + self._conn = conn + self.pgconn = conn.pgconn + self.command_queue = Deque[PipelineCommand]() + self.result_queue = Deque[PendingResult]() + self.level = 0 + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._conn.pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + @property + def status(self) -> pq.PipelineStatus: + return pq.PipelineStatus(self.pgconn.pipeline_status) + + @classmethod + def is_supported(cls) -> bool: + """Return `!True` if the psycopg libpq wrapper supports pipeline mode.""" + if BasePipeline._is_supported is None: + BasePipeline._is_supported = not cls._not_supported_reason() + return BasePipeline._is_supported + + @classmethod + def _not_supported_reason(cls) -> str: + """Return the reason why the pipeline mode is not supported. + + Return an empty string if pipeline mode is supported. + """ + # Support only depends on the libpq functions available in the pq + # wrapper, not on the database version. + if pq.version() < 140000: + return ( + f"libpq too old {pq.version()};" + " v14 or greater required for pipeline mode" + ) + + if pq.__build_version__ < 140000: + return ( + f"libpq too old: module built for {pq.__build_version__};" + " v14 or greater required for pipeline mode" + ) + + return "" + + def _enter_gen(self) -> PQGen[None]: + if not self.is_supported(): + raise e.NotSupportedError( + f"pipeline mode not supported: {self._not_supported_reason()}" + ) + if self.level == 0: + self.pgconn.enter_pipeline_mode() + elif self.command_queue or self.pgconn.transaction_status == ACTIVE: + # Nested pipeline case. + # Transaction might be ACTIVE when the pipeline uses an "implicit + # transaction", typically in autocommit mode. But when entering a + # Psycopg transaction(), we expect the IDLE state. By sync()-ing, + # we make sure all previous commands are completed and the + # transaction gets back to IDLE. + yield from self._sync_gen() + self.level += 1 + + def _exit(self, exc: Optional[BaseException]) -> None: + self.level -= 1 + if self.level == 0 and self.pgconn.status != BAD: + try: + self.pgconn.exit_pipeline_mode() + except e.OperationalError as exc2: + # Notice that this error might be pretty irrecoverable. 
It + # happens on COPY, for instance: even if sync succeeds, exiting + # fails with "cannot exit pipeline mode with uncollected results" + if exc: + logger.warning("error ignored exiting %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + + def _sync_gen(self) -> PQGen[None]: + self._enqueue_sync() + yield from self._communicate_gen() + yield from self._fetch_gen(flush=False) + + def _exit_gen(self) -> PQGen[None]: + """ + Exit current pipeline by sending a Sync and fetch back all remaining results. + """ + try: + self._enqueue_sync() + yield from self._communicate_gen() + finally: + # No need to force flush since we emitted a sync just before. + yield from self._fetch_gen(flush=False) + + def _communicate_gen(self) -> PQGen[None]: + """Communicate with pipeline to send commands and possibly fetch + results, which are then processed. + """ + fetched = yield from pipeline_communicate(self.pgconn, self.command_queue) + to_process = [(self.result_queue.popleft(), results) for results in fetched] + for queued, results in to_process: + self._process_results(queued, results) + + def _fetch_gen(self, *, flush: bool) -> PQGen[None]: + """Fetch available results from the connection and process them with + pipeline queued items. + + If 'flush' is True, a PQsendFlushRequest() is issued in order to make + sure results can be fetched. Otherwise, the caller may emit a + PQpipelineSync() call to ensure the output buffer gets flushed before + fetching. + """ + if not self.result_queue: + return + + if flush: + self.pgconn.send_flush_request() + yield from send(self.pgconn) + + to_process = [] + while self.result_queue: + results = yield from fetch_many(self.pgconn) + if not results: + # No more results to fetch, but there may still be pending + # commands. + break + queued = self.result_queue.popleft() + to_process.append((queued, results)) + + for queued, results in to_process: + self._process_results(queued, results) + + def _process_results( + self, queued: PendingResult, results: List["PGresult"] + ) -> None: + """Process a results set fetched from the current pipeline. + + This matches 'results' with its respective element in the pipeline + queue. For commands (None value in the pipeline queue), results are + checked directly. For prepare statement creation requests, update the + cache. Otherwise, results are attached to their respective cursor. + """ + if queued is None: + (result,) = results + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn)) + elif result.status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + else: + cursor, prepinfo = queued + cursor._set_results_from_pipeline(results) + if prepinfo: + key, prep, name = prepinfo + # Update the prepare state of the query. + cursor._conn._prepared.validate(key, prep, name, results) + + def _enqueue_sync(self) -> None: + """Enqueue a PQpipelineSync() command.""" + self.command_queue.append(self.pgconn.pipeline_sync) + self.result_queue.append(None) + + +class Pipeline(BasePipeline): + """Handler for connection in pipeline mode.""" + + __module__ = "psycopg" + _conn: "Connection[Any]" + _Self = TypeVar("_Self", bound="Pipeline") + + def __init__(self, conn: "Connection[Any]") -> None: + super().__init__(conn) + + def sync(self) -> None: + """Sync the pipeline, send any pending command and receive and process + all available results. 
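+
+        A minimal usage sketch (assuming an open `Connection` ``conn`` and a
+        table ``t``)::
+
+            with conn.pipeline() as p:
+                conn.execute("INSERT INTO t (x) VALUES (%s)", [10])
+                conn.execute("INSERT INTO t (x) VALUES (%s)", [20])
+                p.sync()  # flush queued commands and process their results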
+ """ + try: + with self._conn.lock: + self._conn.wait(self._sync_gen()) + except e.Error as ex: + raise ex.with_traceback(None) + + def __enter__(self: _Self) -> _Self: + with self._conn.lock: + self._conn.wait(self._enter_gen()) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + try: + with self._conn.lock: + self._conn.wait(self._exit_gen()) + except Exception as exc2: + # Don't clobber an exception raised in the block with this one + if exc_val: + logger.warning("error ignored terminating %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + finally: + self._exit(exc_val) + + +class AsyncPipeline(BasePipeline): + """Handler for async connection in pipeline mode.""" + + __module__ = "psycopg" + _conn: "AsyncConnection[Any]" + _Self = TypeVar("_Self", bound="AsyncPipeline") + + def __init__(self, conn: "AsyncConnection[Any]") -> None: + super().__init__(conn) + + async def sync(self) -> None: + try: + async with self._conn.lock: + await self._conn.wait(self._sync_gen()) + except e.Error as ex: + raise ex.with_traceback(None) + + async def __aenter__(self: _Self) -> _Self: + async with self._conn.lock: + await self._conn.wait(self._enter_gen()) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + try: + async with self._conn.lock: + await self._conn.wait(self._exit_gen()) + except Exception as exc2: + # Don't clobber an exception raised in the block with this one + if exc_val: + logger.warning("error ignored terminating %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + finally: + self._exit(exc_val) diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py new file mode 100644 index 0000000..f60c0cb --- /dev/null +++ b/psycopg/psycopg/_preparing.py @@ -0,0 +1,198 @@ +""" +Support for prepared statements +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import IntEnum, auto +from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING +from collections import OrderedDict +from typing_extensions import TypeAlias + +from . import pq +from ._compat import Deque +from ._queries import PostgresQuery + +if TYPE_CHECKING: + from .pq.abc import PGresult + +Key: TypeAlias = Tuple[bytes, Tuple[int, ...]] + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK + + +class Prepare(IntEnum): + NO = auto() + YES = auto() + SHOULD = auto() + + +class PrepareManager: + # Number of times a query is executed before it is prepared. + prepare_threshold: Optional[int] = 5 + + # Maximum number of prepared statements on the connection. + prepared_max: int = 100 + + def __init__(self) -> None: + # Map (query, types) to the number of times the query was seen. + self._counts: OrderedDict[Key, int] = OrderedDict() + + # Map (query, types) to the name of the statement if prepared. + self._names: OrderedDict[Key, bytes] = OrderedDict() + + # Counter to generate prepared statements names + self._prepared_idx = 0 + + self._maint_commands = Deque[bytes]() + + @staticmethod + def key(query: PostgresQuery) -> Key: + return (query.query, query.types) + + def get( + self, query: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + """ + Check if a query is prepared, tell back whether to prepare it. 
+ """ + if prepare is False or self.prepare_threshold is None: + # The user doesn't want this query to be prepared + return Prepare.NO, b"" + + key = self.key(query) + name = self._names.get(key) + if name: + # The query was already prepared in this session + return Prepare.YES, name + + count = self._counts.get(key, 0) + if count >= self.prepare_threshold or prepare: + # The query has been executed enough times and needs to be prepared + name = f"_pg3_{self._prepared_idx}".encode() + self._prepared_idx += 1 + return Prepare.SHOULD, name + else: + # The query is not to be prepared yet + return Prepare.NO, b"" + + def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool: + """Check if we need to discard our entire state: it should happen on + rollback or on dropping objects, because the same object may get + recreated and postgres would fail internal lookups. + """ + if self._names or prep == Prepare.SHOULD: + for result in results: + if result.status != COMMAND_OK: + continue + cmdstat = result.command_status + if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"): + return self.clear() + return False + + @staticmethod + def _check_results(results: Sequence["PGresult"]) -> bool: + """Return False if 'results' are invalid for prepared statement cache.""" + if len(results) != 1: + # We cannot prepare a multiple statement + return False + + status = results[0].status + if COMMAND_OK != status != TUPLES_OK: + # We don't prepare failed queries or other weird results + return False + + return True + + def _rotate(self) -> None: + """Evict an old value from the cache. + + If it was prepared, deallocate it. Do it only once: if the cache was + resized, deallocate gradually. + """ + if len(self._counts) > self.prepared_max: + self._counts.popitem(last=False) + + if len(self._names) > self.prepared_max: + name = self._names.popitem(last=False)[1] + self._maint_commands.append(b"DEALLOCATE " + name) + + def maybe_add_to_cache( + self, query: PostgresQuery, prep: Prepare, name: bytes + ) -> Optional[Key]: + """Handle 'query' for possible addition to the cache. + + If a new entry has been added, return its key. Return None otherwise + (meaning the query is already in cache or cache is not enabled). + + Note: This method is only called in pipeline mode. + """ + # don't do anything if prepared statements are disabled + if self.prepare_threshold is None: + return None + + key = self.key(query) + if key in self._counts: + if prep is Prepare.SHOULD: + del self._counts[key] + self._names[key] = name + else: + self._counts[key] += 1 + self._counts.move_to_end(key) + return None + + elif key in self._names: + self._names.move_to_end(key) + return None + + else: + if prep is Prepare.SHOULD: + self._names[key] = name + else: + self._counts[key] = 1 + return key + + def validate( + self, + key: Key, + prep: Prepare, + name: bytes, + results: Sequence["PGresult"], + ) -> None: + """Validate cached entry with 'key' by checking query 'results'. + + Possibly return a command to perform maintenance on database side. + + Note: this method is only called in pipeline mode. + """ + if self._should_discard(prep, results): + return + + if not self._check_results(results): + self._names.pop(key, None) + self._counts.pop(key, None) + else: + self._rotate() + + def clear(self) -> bool: + """Clear the cache of the maintenance commands. + + Clear the internal state and prepare a command to clear the state of + the server. 
+ """ + self._counts.clear() + if self._names: + self._names.clear() + self._maint_commands.clear() + self._maint_commands.append(b"DEALLOCATE ALL") + return True + else: + return False + + def get_maintenance_commands(self) -> Iterator[bytes]: + """ + Iterate over the commands needed to align the server state to our state + """ + while self._maint_commands: + yield self._maint_commands.popleft() diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py new file mode 100644 index 0000000..2a7554c --- /dev/null +++ b/psycopg/psycopg/_queries.py @@ -0,0 +1,375 @@ +""" +Utility module to manipulate queries +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional +from typing import Sequence, Tuple, Union, TYPE_CHECKING +from functools import lru_cache + +from . import pq +from . import errors as e +from .sql import Composable +from .abc import Buffer, Query, Params +from ._enums import PyFormat +from ._encodings import conn_encoding + +if TYPE_CHECKING: + from .abc import Transformer + + +class QueryPart(NamedTuple): + pre: bytes + item: Union[int, str] + format: PyFormat + + +class PostgresQuery: + """ + Helper to convert a Python query and parameters into Postgres format. + """ + + __slots__ = """ + query params types formats + _tx _want_formats _parts _encoding _order + """.split() + + def __init__(self, transformer: "Transformer"): + self._tx = transformer + + self.params: Optional[Sequence[Optional[Buffer]]] = None + # these are tuples so they can be used as keys e.g. in prepared stmts + self.types: Tuple[int, ...] = () + + # The format requested by the user and the ones to really pass Postgres + self._want_formats: Optional[List[PyFormat]] = None + self.formats: Optional[Sequence[pq.Format]] = None + + self._encoding = conn_encoding(transformer.connection) + self._parts: List[QueryPart] + self.query = b"" + self._order: Optional[List[str]] = None + + def convert(self, query: Query, vars: Optional[Params]) -> None: + """ + Set up the query and parameters to convert. + + The results of this function can be obtained accessing the object + attributes (`query`, `params`, `types`, `formats`). + """ + if isinstance(query, str): + bquery = query.encode(self._encoding) + elif isinstance(query, Composable): + bquery = query.as_bytes(self._tx) + else: + bquery = query + + if vars is not None: + ( + self.query, + self._want_formats, + self._order, + self._parts, + ) = _query2pg(bquery, self._encoding) + else: + self.query = bquery + self._want_formats = self._order = None + + self.dump(vars) + + def dump(self, vars: Optional[Params]) -> None: + """ + Process a new set of variables on the query processed by `convert()`. + + This method updates `params` and `types`. + """ + if vars is not None: + params = _validate_and_reorder_params(self._parts, vars, self._order) + assert self._want_formats is not None + self.params = self._tx.dump_sequence(params, self._want_formats) + self.types = self._tx.types or () + self.formats = self._tx.formats + else: + self.params = None + self.types = () + self.formats = None + + +class PostgresClientQuery(PostgresQuery): + """ + PostgresQuery subclass merging query and arguments client-side. + """ + + __slots__ = ("template",) + + def convert(self, query: Query, vars: Optional[Params]) -> None: + """ + Set up the query and parameters to convert. + + The results of this function can be obtained accessing the object + attributes (`query`, `params`, `types`, `formats`). 
+ """ + if isinstance(query, str): + bquery = query.encode(self._encoding) + elif isinstance(query, Composable): + bquery = query.as_bytes(self._tx) + else: + bquery = query + + if vars is not None: + (self.template, self._order, self._parts) = _query2pg_client( + bquery, self._encoding + ) + else: + self.query = bquery + self._order = None + + self.dump(vars) + + def dump(self, vars: Optional[Params]) -> None: + """ + Process a new set of variables on the query processed by `convert()`. + + This method updates `params` and `types`. + """ + if vars is not None: + params = _validate_and_reorder_params(self._parts, vars, self._order) + self.params = tuple( + self._tx.as_literal(p) if p is not None else b"NULL" for p in params + ) + self.query = self.template % self.params + else: + self.params = None + + +@lru_cache() +def _query2pg( + query: bytes, encoding: str +) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into something Postgres understands. + + - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres + format (``$1``, ``$2``) + - placeholders can be %s, %t, or %b (auto, text or binary) + - return ``query`` (bytes), ``formats`` (list of formats) ``order`` + (sequence of names used in the query, in the position they appear) + ``parts`` (splits of queries and placeholders). + """ + parts = _split_query(query, encoding) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + formats = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"$%d" % (part.item + 1)) + formats.append(part.format) + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"$%d" % (len(seen) + 1) + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + formats.append(part.format) + else: + if seen[part.item][1] != part.format: + raise e.ProgrammingError( + f"placeholder '{part.item}' cannot have different formats" + ) + chunks.append(seen[part.item][0]) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), formats, order, parts + + +@lru_cache() +def _query2pg_client( + query: bytes, encoding: str +) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into a template to perform client-side binding + """ + parts = _split_query(query, encoding, collapse_double_percent=False) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"%s") + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"%s" + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + else: + chunks.append(seen[part.item][0]) + order.append(part.item) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), order, parts + + +def _validate_and_reorder_params( + parts: List[QueryPart], vars: Params, order: Optional[List[str]] +) -> Sequence[Any]: + """ + Verify the compatibility between a query and a set of params. 
+ """ + # Try concrete types, then abstract types + t = type(vars) + if t is list or t is tuple: + sequence = True + elif t is dict: + sequence = False + elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): + sequence = True + elif isinstance(vars, Mapping): + sequence = False + else: + raise TypeError( + "query parameters should be a sequence or a mapping," + f" got {type(vars).__name__}" + ) + + if sequence: + if len(vars) != len(parts) - 1: + raise e.ProgrammingError( + f"the query has {len(parts) - 1} placeholders but" + f" {len(vars)} parameters were passed" + ) + if vars and not isinstance(parts[0].item, int): + raise TypeError("named placeholders require a mapping of parameters") + return vars # type: ignore[return-value] + + else: + if vars and len(parts) > 1 and not isinstance(parts[0][1], str): + raise TypeError( + "positional placeholders (%s) require a sequence of parameters" + ) + try: + return [vars[item] for item in order or ()] # type: ignore[call-overload] + except KeyError: + raise e.ProgrammingError( + "query parameter missing:" + f" {', '.join(sorted(i for i in order or () if i not in vars))}" + ) + + +_re_placeholder = re.compile( + rb"""(?x) + % # a literal % + (?: + (?: + \( ([^)]+) \) # or a name in (braces) + . # followed by a format + ) + | + (?:.) # or any char, really + ) + """ +) + + +def _split_query( + query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True +) -> List[QueryPart]: + parts: List[Tuple[bytes, Optional[Match[bytes]]]] = [] + cur = 0 + + # pairs [(fragment, match], with the last match None + m = None + for m in _re_placeholder.finditer(query): + pre = query[cur : m.span(0)[0]] + parts.append((pre, m)) + cur = m.span(0)[1] + if m: + parts.append((query[cur:], None)) + else: + parts.append((query, None)) + + rv = [] + + # drop the "%%", validate + i = 0 + phtype = None + while i < len(parts): + pre, m = parts[i] + if m is None: + # last part + rv.append(QueryPart(pre, 0, PyFormat.AUTO)) + break + + ph = m.group(0) + if ph == b"%%": + # unescape '%%' to '%' if necessary, then merge the parts + if collapse_double_percent: + ph = b"%" + pre1, m1 = parts[i + 1] + parts[i + 1] = (pre + ph + pre1, m1) + del parts[i] + continue + + if ph == b"%(": + raise e.ProgrammingError( + "incomplete placeholder:" + f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'" + ) + elif ph == b"% ": + # explicit messasge for a typical error + raise e.ProgrammingError( + "incomplete placeholder: '%'; if you want to use '%' as an" + " operator you can double it up, i.e. use '%%'" + ) + elif ph[-1:] not in b"sbt": + raise e.ProgrammingError( + "only '%s', '%b', '%t' are allowed as placeholders, got" + f" '{m.group(0).decode(encoding)}'" + ) + + # Index or name + item: Union[int, str] + item = m.group(1).decode(encoding) if m.group(1) else i + + if not phtype: + phtype = type(item) + elif phtype is not type(item): + raise e.ProgrammingError( + "positional and named placeholders cannot be mixed" + ) + + format = _ph_to_fmt[ph[-1:]] + rv.append(QueryPart(pre, item, format)) + i += 1 + + return rv + + +_ph_to_fmt = { + b"s": PyFormat.AUTO, + b"t": PyFormat.TEXT, + b"b": PyFormat.BINARY, +} diff --git a/psycopg/psycopg/_struct.py b/psycopg/psycopg/_struct.py new file mode 100644 index 0000000..28a6084 --- /dev/null +++ b/psycopg/psycopg/_struct.py @@ -0,0 +1,57 @@ +""" +Utility functions to deal with binary structs. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +import struct +from typing import Callable, cast, Optional, Tuple +from typing_extensions import TypeAlias + +from .abc import Buffer +from . import errors as e +from ._compat import Protocol + +PackInt: TypeAlias = Callable[[int], bytes] +UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]] +PackFloat: TypeAlias = Callable[[float], bytes] +UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]] + + +class UnpackLen(Protocol): + def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]: + ... + + +pack_int2 = cast(PackInt, struct.Struct("!h").pack) +pack_uint2 = cast(PackInt, struct.Struct("!H").pack) +pack_int4 = cast(PackInt, struct.Struct("!i").pack) +pack_uint4 = cast(PackInt, struct.Struct("!I").pack) +pack_int8 = cast(PackInt, struct.Struct("!q").pack) +pack_float4 = cast(PackFloat, struct.Struct("!f").pack) +pack_float8 = cast(PackFloat, struct.Struct("!d").pack) + +unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack) +unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack) +unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack) +unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack) +unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack) +unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack) +unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack) + +_struct_len = struct.Struct("!i") +pack_len = cast(Callable[[int], bytes], _struct_len.pack) +unpack_len = cast(UnpackLen, _struct_len.unpack_from) + + +def pack_float4_bug_304(x: float) -> bytes: + raise e.InterfaceError( + "cannot dump Float4: Python affected by bug #304. Note that the psycopg-c" + " and psycopg-binary packages are not affected by this issue." + " See https://github.com/psycopg/psycopg/issues/304" + ) + + +# If issue #304 is detected, raise an error instead of dumping wrong data. +if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"): + pack_float4 = pack_float4_bug_304 diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py new file mode 100644 index 0000000..3528188 --- /dev/null +++ b/psycopg/psycopg/_tpc.py @@ -0,0 +1,116 @@ +""" +psycopg two-phase commit support +""" + +# Copyright (C) 2021 The Psycopg Team + +import re +import datetime as dt +from base64 import b64encode, b64decode +from typing import Optional, Union +from dataclasses import dataclass, replace + +_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$") + + +@dataclass(frozen=True) +class Xid: + """A two-phase commit transaction identifier. + + The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`, + `bqual`). + + """ + + format_id: Optional[int] + gtrid: str + bqual: Optional[str] + prepared: Optional[dt.datetime] = None + owner: Optional[str] = None + database: Optional[str] = None + + @classmethod + def from_string(cls, s: str) -> "Xid": + """Try to parse an XA triple from the string. + + This may fail for several reasons. In such case return an unparsed Xid. 
+ """ + try: + return cls._parse_string(s) + except Exception: + return Xid(None, s, None) + + def __str__(self) -> str: + return self._as_tid() + + def __len__(self) -> int: + return 3 + + def __getitem__(self, index: int) -> Union[int, str, None]: + return (self.format_id, self.gtrid, self.bqual)[index] + + @classmethod + def _parse_string(cls, s: str) -> "Xid": + m = _re_xid.match(s) + if not m: + raise ValueError("bad Xid format") + + format_id = int(m.group(1)) + gtrid = b64decode(m.group(2)).decode() + bqual = b64decode(m.group(3)).decode() + return cls.from_parts(format_id, gtrid, bqual) + + @classmethod + def from_parts( + cls, format_id: Optional[int], gtrid: str, bqual: Optional[str] + ) -> "Xid": + if format_id is not None: + if bqual is None: + raise TypeError("if format_id is specified, bqual must be too") + if not 0 <= format_id < 0x80000000: + raise ValueError("format_id must be a non-negative 32-bit integer") + if len(bqual) > 64: + raise ValueError("bqual must be not longer than 64 chars") + if len(gtrid) > 64: + raise ValueError("gtrid must be not longer than 64 chars") + + elif bqual is None: + raise TypeError("if format_id is None, bqual must be None too") + + return Xid(format_id, gtrid, bqual) + + def _as_tid(self) -> str: + """ + Return the PostgreSQL transaction_id for this XA xid. + + PostgreSQL wants just a string, while the DBAPI supports the XA + standard and thus a triple. We use the same conversion algorithm + implemented by JDBC in order to allow some form of interoperation. + + see also: the pgjdbc implementation + http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/ + postgresql/xa/RecoveredXid.java?rev=1.2 + """ + if self.format_id is None or self.bqual is None: + # Unparsed xid: return the gtrid. + return self.gtrid + + # XA xid: mash together the components. + egtrid = b64encode(self.gtrid.encode()).decode() + ebqual = b64encode(self.bqual.encode()).decode() + + return f"{self.format_id}_{egtrid}_{ebqual}" + + @classmethod + def _get_recover_query(cls) -> str: + return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts" + + @classmethod + def _from_record( + cls, gid: str, prepared: dt.datetime, owner: str, database: str + ) -> "Xid": + xid = Xid.from_string(gid) + return replace(xid, prepared=prepared, owner=owner, database=database) + + +Xid.__module__ = "psycopg" diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py new file mode 100644 index 0000000..19bd6ae --- /dev/null +++ b/psycopg/psycopg/_transform.py @@ -0,0 +1,350 @@ +""" +Helper object to transform values between Python and PostgreSQL +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import DefaultDict, TYPE_CHECKING +from collections import defaultdict +from typing_extensions import TypeAlias + +from . import pq +from . import postgres +from . 
import errors as e +from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType +from .rows import Row, RowMaker +from .postgres import INVALID_OID, TEXT_OID +from ._encodings import pgconn_encoding + +if TYPE_CHECKING: + from .abc import Dumper, Loader + from .adapt import AdaptersMap + from .pq.abc import PGresult + from .connection import BaseConnection + +DumperCache: TypeAlias = Dict[DumperKey, "Dumper"] +OidDumperCache: TypeAlias = Dict[int, "Dumper"] +LoaderCache: TypeAlias = Dict[int, "Loader"] + +TEXT = pq.Format.TEXT +PY_TEXT = PyFormat.TEXT + + +class Transformer(AdaptContext): + """ + An object that can adapt efficiently between Python and PostgreSQL. + + The life cycle of the object is the query, so it is assumed that attributes + such as the server version or the connection encoding will not change. The + object have its state so adapting several values of the same type can be + optimised. + + """ + + __module__ = "psycopg.adapt" + + __slots__ = """ + types formats + _conn _adapters _pgresult _dumpers _loaders _encoding _none_oid + _oid_dumpers _oid_types _row_dumpers _row_loaders + """.split() + + types: Optional[Tuple[int, ...]] + formats: Optional[List[pq.Format]] + + _adapters: "AdaptersMap" + _pgresult: Optional["PGresult"] + _none_oid: int + + def __init__(self, context: Optional[AdaptContext] = None): + self._pgresult = self.types = self.formats = None + + # WARNING: don't store context, or you'll create a loop with the Cursor + if context: + self._adapters = context.adapters + self._conn = context.connection + else: + self._adapters = postgres.adapters + self._conn = None + + # mapping fmt, class -> Dumper instance + self._dumpers: DefaultDict[PyFormat, DumperCache] + self._dumpers = defaultdict(dict) + + # mapping fmt, oid -> Dumper instance + # Not often used, so create it only if needed. + self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]] + self._oid_dumpers = None + + # mapping fmt, oid -> Loader instance + self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {}) + + self._row_dumpers: Optional[List["Dumper"]] = None + + # sequence of load functions from value to python + # the length of the result columns + self._row_loaders: List[LoadFunc] = [] + + # mapping oid -> type sql representation + self._oid_types: Dict[int, bytes] = {} + + self._encoding = "" + + @classmethod + def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": + """ + Return a Transformer from an AdaptContext. + + If the context is a Transformer instance, just return it. 
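+        Otherwise build a new `!Transformer` using the context's adapters
+        and connection.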
+ """ + if isinstance(context, Transformer): + return context + else: + return cls(context) + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + return self._conn + + @property + def encoding(self) -> str: + if not self._encoding: + conn = self.connection + self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8" + return self._encoding + + @property + def adapters(self) -> "AdaptersMap": + return self._adapters + + @property + def pgresult(self) -> Optional["PGresult"]: + return self._pgresult + + def set_pgresult( + self, + result: Optional["PGresult"], + *, + set_loaders: bool = True, + format: Optional[pq.Format] = None, + ) -> None: + self._pgresult = result + + if not result: + self._nfields = self._ntuples = 0 + if set_loaders: + self._row_loaders = [] + return + + self._ntuples = result.ntuples + nf = self._nfields = result.nfields + + if not set_loaders: + return + + if not nf: + self._row_loaders = [] + return + + fmt: pq.Format + fmt = result.fformat(0) if format is None else format # type: ignore + self._row_loaders = [ + self.get_loader(result.ftype(i), fmt).load for i in range(nf) + ] + + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: + self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types] + self.types = tuple(types) + self.formats = [format] * len(types) + + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: + self._row_loaders = [self.get_loader(oid, format).load for oid in types] + + def dump_sequence( + self, params: Sequence[Any], formats: Sequence[PyFormat] + ) -> Sequence[Optional[Buffer]]: + nparams = len(params) + out: List[Optional[Buffer]] = [None] * nparams + + # If we have dumpers, it means set_dumper_types had been called, in + # which case self.types and self.formats are set to sequences of the + # right size. + if self._row_dumpers: + for i in range(nparams): + param = params[i] + if param is not None: + out[i] = self._row_dumpers[i].dump(param) + return out + + types = [self._get_none_oid()] * nparams + pqformats = [TEXT] * nparams + + for i in range(nparams): + param = params[i] + if param is None: + continue + dumper = self.get_dumper(param, formats[i]) + out[i] = dumper.dump(param) + types[i] = dumper.oid + pqformats[i] = dumper.format + + self.types = tuple(types) + self.formats = pqformats + + return out + + def as_literal(self, obj: Any) -> bytes: + dumper = self.get_dumper(obj, PY_TEXT) + rv = dumper.quote(obj) + # If the result is quoted, and the oid not unknown or text, + # add an explicit type cast. + # Check the last char because the first one might be 'E'. + oid = dumper.oid + if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID: + try: + type_sql = self._oid_types[oid] + except KeyError: + ti = self.adapters.types.get(oid) + if ti: + if oid < 8192: + # builtin: prefer "timestamptz" to "timestamp with time zone" + type_sql = ti.name.encode(self.encoding) + else: + type_sql = ti.regtype.encode(self.encoding) + if oid == ti.array_oid: + type_sql += b"[]" + else: + type_sql = b"" + self._oid_types[oid] = type_sql + + if type_sql: + rv = b"%s::%s" % (rv, type_sql) + + if not isinstance(rv, bytes): + rv = bytes(rv) + return rv + + def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper": + """ + Return a Dumper instance to dump `!obj`. 
+ """ + # Normally, the type of the object dictates how to dump it + key = type(obj) + + # Reuse an existing Dumper class for objects of the same type + cache = self._dumpers[format] + try: + dumper = cache[key] + except KeyError: + # If it's the first time we see this type, look for a dumper + # configured for it. + dcls = self.adapters.get_dumper(key, format) + cache[key] = dumper = dcls(key, self) + + # Check if the dumper requires an upgrade to handle this specific value + key1 = dumper.get_key(obj, format) + if key1 is key: + return dumper + + # If it does, ask the dumper to create its own upgraded version + try: + return cache[key1] + except KeyError: + dumper = cache[key1] = dumper.upgrade(obj, format) + return dumper + + def _get_none_oid(self) -> int: + try: + return self._none_oid + except AttributeError: + pass + + try: + rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid + except KeyError: + raise e.InterfaceError("None dumper not found") + + return rv + + def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper": + """ + Return a Dumper to dump an object to the type with given oid. + """ + if not self._oid_dumpers: + self._oid_dumpers = ({}, {}) + + # Reuse an existing Dumper class for objects of the same type + cache = self._oid_dumpers[format] + try: + return cache[oid] + except KeyError: + # If it's the first time we see this type, look for a dumper + # configured for it. + dcls = self.adapters.get_dumper_by_oid(oid, format) + cache[oid] = dumper = dcls(NoneType, self) + + return dumper + + def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: + res = self._pgresult + if not res: + raise e.InterfaceError("result not set") + + if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples): + raise e.InterfaceError( + f"rows must be included between 0 and {self._ntuples}" + ) + + records = [] + for row in range(row0, row1): + record: List[Any] = [None] * self._nfields + for col in range(self._nfields): + val = res.get_value(row, col) + if val is not None: + record[col] = self._row_loaders[col](val) + records.append(make_row(record)) + + return records + + def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: + res = self._pgresult + if not res: + return None + + if not 0 <= row < self._ntuples: + return None + + record: List[Any] = [None] * self._nfields + for col in range(self._nfields): + val = res.get_value(row, col) + if val is not None: + record[col] = self._row_loaders[col](val) + + return make_row(record) + + def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: + if len(self._row_loaders) != len(record): + raise e.ProgrammingError( + f"cannot load sequence of {len(record)} items:" + f" {len(self._row_loaders)} loaders registered" + ) + + return tuple( + (self._row_loaders[i](val) if val is not None else None) + for i, val in enumerate(record) + ) + + def get_loader(self, oid: int, format: pq.Format) -> "Loader": + try: + return self._loaders[format][oid] + except KeyError: + pass + + loader_cls = self._adapters.get_loader(oid, format) + if not loader_cls: + loader_cls = self._adapters.get_loader(INVALID_OID, format) + if not loader_cls: + raise e.InterfaceError("unknown oid loader not found") + loader = self._loaders[format][oid] = loader_cls(oid, self) + return loader diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py new file mode 100644 index 0000000..2f1a24d --- /dev/null +++ b/psycopg/psycopg/_typeinfo.py @@ -0,0 +1,461 @@ +""" 
+Information about PostgreSQL types + +These types allow to read information from the system catalog and provide +information to the adapters if needed. +""" + +# Copyright (C) 2020 The Psycopg Team +from enum import Enum +from typing import Any, Dict, Iterator, Optional, overload +from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import errors as e +from .abc import AdaptContext +from .rows import dict_row + +if TYPE_CHECKING: + from .connection import Connection + from .connection_async import AsyncConnection + from .sql import Identifier + +T = TypeVar("T", bound="TypeInfo") +RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]] + + +class TypeInfo: + """ + Hold information about a PostgreSQL base type. + """ + + __module__ = "psycopg.types" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + delimiter: str = ",", + ): + self.name = name + self.oid = oid + self.array_oid = array_oid + self.regtype = regtype or name + self.delimiter = delimiter + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__qualname__}:" + f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>" + ) + + @overload + @classmethod + def fetch( + cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"] + ) -> Optional[T]: + ... + + @overload + @classmethod + async def fetch( + cls: Type[T], + conn: "AsyncConnection[Any]", + name: Union[str, "Identifier"], + ) -> Optional[T]: + ... + + @classmethod + def fetch( + cls: Type[T], + conn: "Union[Connection[Any], AsyncConnection[Any]]", + name: Union[str, "Identifier"], + ) -> Any: + """Query a system catalog to read information about a type.""" + from .sql import Composable + from .connection_async import AsyncConnection + + if isinstance(name, Composable): + name = name.as_string(conn) + + if isinstance(conn, AsyncConnection): + return cls._fetch_async(conn, name) + + # This might result in a nested transaction. What we want is to leave + # the function with the connection in the state we found (either idle + # or intrans) + try: + with conn.transaction(): + with conn.cursor(binary=True, row_factory=dict_row) as cur: + cur.execute(cls._get_info_query(conn), {"name": name}) + recs = cur.fetchall() + except e.UndefinedObject: + return None + + return cls._from_records(name, recs) + + @classmethod + async def _fetch_async( + cls: Type[T], conn: "AsyncConnection[Any]", name: str + ) -> Optional[T]: + """ + Query a system catalog to read information about a type. + + Similar to `fetch()` but can use an asynchronous connection. + """ + try: + async with conn.transaction(): + async with conn.cursor(binary=True, row_factory=dict_row) as cur: + await cur.execute(cls._get_info_query(conn), {"name": name}) + recs = await cur.fetchall() + except e.UndefinedObject: + return None + + return cls._from_records(name, recs) + + @classmethod + def _from_records( + cls: Type[T], name: str, recs: Sequence[Dict[str, Any]] + ) -> Optional[T]: + if len(recs) == 1: + return cls(**recs[0]) + elif not recs: + return None + else: + raise e.ProgrammingError(f"found {len(recs)} different types named {name}") + + def register(self, context: Optional[AdaptContext] = None) -> None: + """ + Register the type information, globally or in the specified `!context`. + """ + if context: + types = context.adapters.types + else: + from . 
import postgres + + types = postgres.types + + types.add(self) + + if self.array_oid: + from .types.array import register_array + + register_array(self, context) + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT + typname AS name, oid, typarray AS array_oid, + oid::regtype::text AS regtype, typdelim AS delimiter +FROM pg_type t +WHERE t.oid = %(name)s::regtype +ORDER BY t.oid +""" + + def _added(self, registry: "TypesRegistry") -> None: + """Method called by the `!registry` when the object is added there.""" + pass + + +class RangeInfo(TypeInfo): + """Manage information about a range type.""" + + __module__ = "psycopg.types.range" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + subtype_oid: int, + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.subtype_oid = subtype_oid + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + r.rngsubtype AS subtype_oid +FROM pg_type t +JOIN pg_range r ON t.oid = r.rngtypid +WHERE t.oid = %(name)s::regtype +""" + + def _added(self, registry: "TypesRegistry") -> None: + # Map ranges subtypes to info + registry._registry[RangeInfo, self.subtype_oid] = self + + +class MultirangeInfo(TypeInfo): + """Manage information about a multirange type.""" + + # TODO: expose to multirange module once added + # __module__ = "psycopg.types.multirange" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + range_oid: int, + subtype_oid: int, + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.range_oid = range_oid + self.subtype_oid = subtype_oid + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + if conn.info.server_version < 140000: + raise e.NotSupportedError( + "multirange types are only available from PostgreSQL 14" + ) + return """\ +SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid +FROM pg_type t +JOIN pg_range r ON t.oid = r.rngmultitypid +WHERE t.oid = %(name)s::regtype +""" + + def _added(self, registry: "TypesRegistry") -> None: + # Map multiranges ranges and subtypes to info + registry._registry[MultirangeInfo, self.range_oid] = self + registry._registry[MultirangeInfo, self.subtype_oid] = self + + +class CompositeInfo(TypeInfo): + """Manage information about a composite type.""" + + __module__ = "psycopg.types.composite" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + field_names: Sequence[str], + field_types: Sequence[int], + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.field_names = field_names + self.field_types = field_types + # Will be set by register() if the `factory` is a type + self.python_type: Optional[type] = None + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT + t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + coalesce(a.fnames, '{}') AS field_names, + coalesce(a.ftypes, '{}') AS field_types +FROM pg_type t +LEFT JOIN ( + SELECT + attrelid, + array_agg(attname) AS fnames, + array_agg(atttypid) AS ftypes + 
FROM ( + SELECT a.attrelid, a.attname, a.atttypid + FROM pg_attribute a + JOIN pg_type t ON t.typrelid = a.attrelid + WHERE t.oid = %(name)s::regtype + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum + ) x + GROUP BY attrelid +) a ON a.attrelid = t.typrelid +WHERE t.oid = %(name)s::regtype +""" + + +class EnumInfo(TypeInfo): + """Manage information about an enum type.""" + + __module__ = "psycopg.types.enum" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + labels: Sequence[str], + ): + super().__init__(name, oid, array_oid) + self.labels = labels + # Will be set by register_enum() + self.enum: Optional[Type[Enum]] = None + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT name, oid, array_oid, array_agg(label) AS labels +FROM ( + SELECT + t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + e.enumlabel AS label + FROM pg_type t + LEFT JOIN pg_enum e + ON e.enumtypid = t.oid + WHERE t.oid = %(name)s::regtype + ORDER BY e.enumsortorder +) x +GROUP BY name, oid, array_oid +""" + + +class TypesRegistry: + """ + Container for the information about types in a database. + """ + + __module__ = "psycopg.types" + + def __init__(self, template: Optional["TypesRegistry"] = None): + self._registry: Dict[RegistryKey, TypeInfo] + + # Make a shallow copy: it will become a proper copy if the registry + # is edited. + if template: + self._registry = template._registry + self._own_state = False + template._own_state = False + else: + self.clear() + + def clear(self) -> None: + self._registry = {} + self._own_state = True + + def add(self, info: TypeInfo) -> None: + self._ensure_own_state() + if info.oid: + self._registry[info.oid] = info + if info.array_oid: + self._registry[info.array_oid] = info + self._registry[info.name] = info + + if info.regtype and info.regtype not in self._registry: + self._registry[info.regtype] = info + + # Allow info to customise further their relation with the registry + info._added(self) + + def __iter__(self) -> Iterator[TypeInfo]: + seen = set() + for t in self._registry.values(): + if id(t) not in seen: + seen.add(id(t)) + yield t + + @overload + def __getitem__(self, key: Union[str, int]) -> TypeInfo: + ... + + @overload + def __getitem__(self, key: Tuple[Type[T], int]) -> T: + ... + + def __getitem__(self, key: RegistryKey) -> TypeInfo: + """ + Return info about a type, specified by name or oid + + :param key: the name or oid of the type to look for. + + Raise KeyError if not found. + """ + if isinstance(key, str): + if key.endswith("[]"): + key = key[:-2] + elif not isinstance(key, (int, tuple)): + raise TypeError(f"the key must be an oid or a name, got {type(key)}") + try: + return self._registry[key] + except KeyError: + raise KeyError(f"couldn't find the type {key!r} in the types registry") + + @overload + def get(self, key: Union[str, int]) -> Optional[TypeInfo]: + ... + + @overload + def get(self, key: Tuple[Type[T], int]) -> Optional[T]: + ... + + def get(self, key: RegistryKey) -> Optional[TypeInfo]: + """ + Return info about a type, specified by name or oid + + :param key: the name or oid of the type to look for. + + Unlike `__getitem__`, return None if not found. + """ + try: + return self[key] + except KeyError: + return None + + def get_oid(self, name: str) -> int: + """ + Return the oid of a PostgreSQL type by name. + + :param key: the name of the type to look for. 
+ + Return the array oid if the type ends with "``[]``" + + Raise KeyError if the name is unknown. + """ + t = self[name] + if name.endswith("[]"): + return t.array_oid + else: + return t.oid + + def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]: + """ + Return info about a `TypeInfo` subclass by its element name or oid. + + :param cls: the subtype of `!TypeInfo` to look for. Currently + supported are `~psycopg.types.range.RangeInfo` and + `~psycopg.types.multirange.MultirangeInfo`. + :param subtype: The name or OID of the subtype of the element to look for. + :return: The `!TypeInfo` object of class `!cls` whose subtype is + `!subtype`. `!None` if the element or its range are not found. + """ + try: + info = self[subtype] + except KeyError: + return None + return self.get((cls, info.oid)) + + def _ensure_own_state(self) -> None: + # Time to write! so, copy. + if not self._own_state: + self._registry = self._registry.copy() + self._own_state = True diff --git a/psycopg/psycopg/_tz.py b/psycopg/psycopg/_tz.py new file mode 100644 index 0000000..813ed62 --- /dev/null +++ b/psycopg/psycopg/_tz.py @@ -0,0 +1,44 @@ +""" +Timezone utility functions. +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +from typing import Dict, Optional, Union +from datetime import timezone, tzinfo + +from .pq.abc import PGconn +from ._compat import ZoneInfo + +logger = logging.getLogger("psycopg") + +_timezones: Dict[Union[None, bytes], tzinfo] = { + None: timezone.utc, + b"UTC": timezone.utc, +} + + +def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo: + """Return the Python timezone info of the connection's timezone.""" + tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None + try: + return _timezones[tzname] + except KeyError: + sname = tzname.decode() if tzname else "UTC" + try: + zi: tzinfo = ZoneInfo(sname) + except (KeyError, OSError): + logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname) + zi = timezone.utc + except Exception as ex: + logger.warning( + "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)", + sname, + type(ex).__name__, + ex, + ) + zi = timezone.utc + + _timezones[tzname] = zi + return zi diff --git a/psycopg/psycopg/_wrappers.py b/psycopg/psycopg/_wrappers.py new file mode 100644 index 0000000..f861741 --- /dev/null +++ b/psycopg/psycopg/_wrappers.py @@ -0,0 +1,137 @@ +""" +Wrappers for numeric types. +""" + +# Copyright (C) 2020 The Psycopg Team + +# Wrappers to force numbers to be cast as specific PostgreSQL types + +# These types are implemented here but exposed by `psycopg.types.numeric`. +# They are defined here to avoid a circular import. +_MODULE = "psycopg.types.numeric" + + +class Int2(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int2": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Int4(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int4": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Int8(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`. 
+ """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int8": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class IntNumeric(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "IntNumeric": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Float4(float): + """ + Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: float) -> "Float4": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Float8(float): + """ + Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: float) -> "Float8": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Oid(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`oid`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Oid": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py new file mode 100644 index 0000000..80c8fbf --- /dev/null +++ b/psycopg/psycopg/abc.py @@ -0,0 +1,266 @@ +""" +Protocol objects representing different implementations of the same classes. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, Generator, Mapping +from typing import List, Optional, Sequence, Tuple, TypeVar, Union +from typing import TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import pq +from ._enums import PyFormat as PyFormat +from ._compat import Protocol, LiteralString + +if TYPE_CHECKING: + from . import sql + from .rows import Row, RowMaker + from .pq.abc import PGresult + from .waiting import Wait, Ready + from .connection import BaseConnection + from ._adapters_map import AdaptersMap + +NoneType: type = type(None) + +# An object implementing the buffer protocol +Buffer: TypeAlias = Union[bytes, bytearray, memoryview] + +Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"] +Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]] +ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]") +PipelineCommand: TypeAlias = Callable[[], None] +DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]] + +# Waiting protocol types + +RV = TypeVar("RV") + +PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV] +"""Generator for processes where the connection file number can change. + +This can happen in connection and reset, but not in normal querying. +""" + +PQGen: TypeAlias = Generator["Wait", "Ready", RV] +"""Generator for processes where the connection file number won't change. 
+""" + + +class WaitFunc(Protocol): + """ + Wait on the connection which generated `PQgen` and return its final result. + """ + + def __call__( + self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None + ) -> RV: + ... + + +# Adaptation types + +DumpFunc: TypeAlias = Callable[[Any], Buffer] +LoadFunc: TypeAlias = Callable[[Buffer], Any] + + +class AdaptContext(Protocol): + """ + A context describing how types are adapted. + + Example of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`, + `~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`. + + Note that this is a `~typing.Protocol`, so objects implementing + `!AdaptContext` don't need to explicitly inherit from this class. + + """ + + @property + def adapters(self) -> "AdaptersMap": + """The adapters configuration that this object uses.""" + ... + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + """The connection used by this object, if available. + + :rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None` + """ + ... + + +class Dumper(Protocol): + """ + Convert Python objects of type `!cls` to PostgreSQL representation. + """ + + format: pq.Format + """ + The format that this class `dump()` method produces, + `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`. + + This is a class attribute. + """ + + oid: int + """The oid to pass to the server, if known; 0 otherwise (class attribute).""" + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + ... + + def dump(self, obj: Any) -> Buffer: + """Convert the object `!obj` to PostgreSQL representation. + + :param obj: the object to convert. + """ + ... + + def quote(self, obj: Any) -> Buffer: + """Convert the object `!obj` to escaped representation. + + :param obj: the object to convert. + """ + ... + + def get_key(self, obj: Any, format: PyFormat) -> DumperKey: + """Return an alternative key to upgrade the dumper to represent `!obj`. + + :param obj: The object to convert + :param format: The format to convert to + + Normally the type of the object is all it takes to define how to dump + the object to the database. For instance, a Python `~datetime.date` can + be simply converted into a PostgreSQL :sql:`date`. + + In a few cases, just the type is not enough. For example: + + - A Python `~datetime.datetime` could be represented as a + :sql:`timestamptz` or a :sql:`timestamp`, according to whether it + specifies a `!tzinfo` or not. + + - A Python int could be stored as several Postgres types: int2, int4, + int8, numeric. If a type too small is used, it may result in an + overflow. If a type too large is used, PostgreSQL may not want to + cast it to a smaller type. + + - Python lists should be dumped according to the type they contain to + convert them to e.g. array of strings, array of ints (and which + size of int?...) + + In these cases, a dumper can implement `!get_key()` and return a new + class, or sequence of classes, that can be used to identify the same + dumper again. If the mechanism is not needed, the method should return + the same `!cls` object passed in the constructor. + + If a dumper implements `get_key()` it should also implement + `upgrade()`. + + """ + ... + + def upgrade(self, obj: Any, format: PyFormat) -> "Dumper": + """Return a new dumper to manage `!obj`. 
+ + :param obj: The object to convert + :param format: The format to convert to + + Once `Transformer.get_dumper()` has been notified by `get_key()` that + this Dumper class cannot handle `!obj` itself, it will invoke + `!upgrade()`, which should return a new `Dumper` instance, which will + be reused for every objects for which `!get_key()` returns the same + result. + """ + ... + + +class Loader(Protocol): + """ + Convert PostgreSQL values with type OID `!oid` to Python objects. + """ + + format: pq.Format + """ + The format that this class `load()` method can convert, + `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`. + + This is a class attribute. + """ + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + ... + + def load(self, data: Buffer) -> Any: + """ + Convert the data returned by the database into a Python object. + + :param data: the data to convert. + """ + ... + + +class Transformer(Protocol): + + types: Optional[Tuple[int, ...]] + formats: Optional[List[pq.Format]] + + def __init__(self, context: Optional[AdaptContext] = None): + ... + + @classmethod + def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": + ... + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + ... + + @property + def encoding(self) -> str: + ... + + @property + def adapters(self) -> "AdaptersMap": + ... + + @property + def pgresult(self) -> Optional["PGresult"]: + ... + + def set_pgresult( + self, + result: Optional["PGresult"], + *, + set_loaders: bool = True, + format: Optional[pq.Format] = None + ) -> None: + ... + + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: + ... + + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: + ... + + def dump_sequence( + self, params: Sequence[Any], formats: Sequence[PyFormat] + ) -> Sequence[Optional[Buffer]]: + ... + + def as_literal(self, obj: Any) -> bytes: + ... + + def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: + ... + + def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]: + ... + + def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]: + ... + + def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: + ... + + def get_loader(self, oid: int, format: pq.Format) -> Loader: + ... diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py new file mode 100644 index 0000000..7ec4a55 --- /dev/null +++ b/psycopg/psycopg/adapt.py @@ -0,0 +1,162 @@ +""" +Entry point into the adaptation system. +""" + +# Copyright (C) 2020 The Psycopg Team + +from abc import ABC, abstractmethod +from typing import Any, Optional, Type, TYPE_CHECKING + +from . import pq, abc +from . import _adapters_map +from ._enums import PyFormat as PyFormat +from ._cmodule import _psycopg + +if TYPE_CHECKING: + from .connection import BaseConnection + +AdaptersMap = _adapters_map.AdaptersMap +Buffer = abc.Buffer + +ORD_BS = ord("\\") + + +class Dumper(abc.Dumper, ABC): + """ + Convert Python object of the type `!cls` to PostgreSQL representation. 
+ """ + + oid: int = 0 + """The oid to pass to the server, if known.""" + + format: pq.Format = pq.Format.TEXT + """The format of the data dumped.""" + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + self.cls = cls + self.connection: Optional["BaseConnection[Any]"] = ( + context.connection if context else None + ) + + def __repr__(self) -> str: + return ( + f"<{type(self).__module__}.{type(self).__qualname__}" + f" (oid={self.oid}) at 0x{id(self):x}>" + ) + + @abstractmethod + def dump(self, obj: Any) -> Buffer: + ... + + def quote(self, obj: Any) -> Buffer: + """ + By default return the `dump()` value quoted and sanitised, so + that the result can be used to build a SQL string. This works well + for most types and you won't likely have to implement this method in a + subclass. + """ + value = self.dump(obj) + + if self.connection: + esc = pq.Escaping(self.connection.pgconn) + # escaping and quoting + return esc.escape_literal(value) + + # This path is taken when quote is asked without a connection, + # usually it means by psycopg.sql.quote() or by + # 'Composible.as_string(None)'. Most often than not this is done by + # someone generating a SQL file to consume elsewhere. + + # No quoting, only quote escaping, random bs escaping. See further. + esc = pq.Escaping() + out = esc.escape_string(value) + + # b"\\" in memoryview doesn't work so search for the ascii value + if ORD_BS not in out: + # If the string has no backslash, the result is correct and we + # don't need to bother with standard_conforming_strings. + return b"'" + out + b"'" + + # The libpq has a crazy behaviour: PQescapeString uses the last + # standard_conforming_strings setting seen on a connection. This + # means that backslashes might be escaped or might not. + # + # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH, + # if scs is off, '\\' raises a warning and '\' is an error. + # + # Check what the libpq does, and if it doesn't escape the backslash + # let's do it on our own. Never mind the race condition. + rv: bytes = b" E'" + out + b"'" + if esc.escape_string(b"\\") == b"\\": + rv = rv.replace(b"\\", b"\\\\") + return rv + + def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey: + """ + Implementation of the `~psycopg.abc.Dumper.get_key()` member of the + `~psycopg.abc.Dumper` protocol. Look at its definition for details. + + This implementation returns the `!cls` passed in the constructor. + Subclasses needing to specialise the PostgreSQL type according to the + *value* of the object dumped (not only according to to its type) + should override this class. + + """ + return self.cls + + def upgrade(self, obj: Any, format: PyFormat) -> "Dumper": + """ + Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the + `~psycopg.abc.Dumper` protocol. Look at its definition for details. + + This implementation just returns `!self`. If a subclass implements + `get_key()` it should probably override `!upgrade()` too. + """ + return self + + +class Loader(abc.Loader, ABC): + """ + Convert PostgreSQL values with type OID `!oid` to Python objects. + """ + + format: pq.Format = pq.Format.TEXT + """The format of the data loaded.""" + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + self.oid = oid + self.connection: Optional["BaseConnection[Any]"] = ( + context.connection if context else None + ) + + @abstractmethod + def load(self, data: Buffer) -> Any: + """Convert a PostgreSQL value to a Python object.""" + ... 
+ + +Transformer: Type["abc.Transformer"] + +# Override it with fast object if available +if _psycopg: + Transformer = _psycopg.Transformer +else: + from . import _transform + + Transformer = _transform.Transformer + + +class RecursiveDumper(Dumper): + """Dumper with a transformer to help dumping recursive types.""" + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + super().__init__(cls, context) + self._tx = Transformer.from_context(context) + + +class RecursiveLoader(Loader): + """Loader with a transformer to help loading recursive types.""" + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer.from_context(context) diff --git a/psycopg/psycopg/client_cursor.py b/psycopg/psycopg/client_cursor.py new file mode 100644 index 0000000..6271ec5 --- /dev/null +++ b/psycopg/psycopg/client_cursor.py @@ -0,0 +1,95 @@ +""" +psycopg client-side binding cursors +""" + +# Copyright (C) 2022 The Psycopg Team + +from typing import Optional, Tuple, TYPE_CHECKING +from functools import partial + +from ._queries import PostgresQuery, PostgresClientQuery + +from . import pq +from . import adapt +from . import errors as e +from .abc import ConnectionType, Query, Params +from .rows import Row +from .cursor import BaseCursor, Cursor +from ._preparing import Prepare +from .cursor_async import AsyncCursor + +if TYPE_CHECKING: + from typing import Any # noqa: F401 + from .connection import Connection # noqa: F401 + from .connection_async import AsyncConnection # noqa: F401 + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + + +class ClientCursorMixin(BaseCursor[ConnectionType, Row]): + def mogrify(self, query: Query, params: Optional[Params] = None) -> str: + """ + Return the query and parameters merged. + + Parameters are adapted and merged to the query the same way that + `!execute()` would do. + + """ + self._tx = adapt.Transformer(self) + pgq = self._convert_query(query, params) + return pgq.query.decode(self._tx.encoding) + + def _execute_send( + self, + query: PostgresQuery, + *, + force_extended: bool = False, + binary: Optional[bool] = None, + ) -> None: + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + if fmt == BINARY: + raise e.NotSupportedError( + "client-side cursors don't support binary results" + ) + + self._query = query + + if self._conn._pipeline: + # In pipeline mode always use PQsendQueryParams - see #314 + # Multiple statements in the same query are not allowed anyway. + self._conn._pipeline.command_queue.append( + partial(self._pgconn.send_query_params, query.query, None) + ) + elif force_extended: + self._pgconn.send_query_params(query.query, None) + else: + # If we can, let's use simple query protocol, + # as it can execute more than one statement in a single query. 
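+            # (With client-side binding the parameters have already been
+            # merged into the query text by PostgresClientQuery, so there is
+            # nothing left to bind at protocol level.)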
+ self._pgconn.send_query(query.query) + + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = PostgresClientQuery(self._tx) + pgq.convert(query, params) + return pgq + + def _get_prepared( + self, pgq: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + return (Prepare.NO, b"") + + +class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]): + __module__ = "psycopg" + + +class AsyncClientCursor( + ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row] +): + __module__ = "psycopg" diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py new file mode 100644 index 0000000..78ad577 --- /dev/null +++ b/psycopg/psycopg/connection.py @@ -0,0 +1,1031 @@ +""" +psycopg connection objects +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +import threading +from types import TracebackType +from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator +from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union +from typing import overload, TYPE_CHECKING +from weakref import ref, ReferenceType +from warnings import warn +from functools import partial +from contextlib import contextmanager +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e +from . import waiting +from . import postgres +from .abc import AdaptContext, ConnectionType, Params, Query, RV +from .abc import PQGen, PQGenConn +from .sql import Composable, SQL +from ._tpc import Xid +from .rows import Row, RowFactory, tuple_row, TupleRow, args_row +from .adapt import AdaptersMap +from ._enums import IsolationLevel +from .cursor import Cursor +from ._compat import LiteralString +from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from ._pipeline import BasePipeline, Pipeline +from .generators import notifies, connect, execute +from ._encodings import pgconn_encoding +from ._preparing import PrepareManager +from .transaction import Transaction +from .server_cursor import ServerCursor + +if TYPE_CHECKING: + from .pq.abc import PGconn, PGresult + from psycopg_pool.base import BasePool + + +# Row Type variable for Cursor (when it needs to be distinguished from the +# connection's one) +CursorRow = TypeVar("CursorRow") + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +OK = pq.ConnStatus.OK +BAD = pq.ConnStatus.BAD + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + +logger = logging.getLogger("psycopg") + + +class Notify(NamedTuple): + """An asynchronous notification received from the database.""" + + channel: str + """The name of the channel on which the notification was received.""" + + payload: str + """The message attached to the notification.""" + + pid: int + """The PID of the backend process which sent the notification.""" + + +Notify.__module__ = "psycopg" + +NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None] +NotifyHandler: TypeAlias = Callable[[Notify], None] + + +class BaseConnection(Generic[Row]): + """ + Base class for different types of connections. + + Share common functionalities such as access to the wrapped PGconn, but + allow different interfaces (sync/async). 
+ """ + + # DBAPI2 exposed exceptions + Warning = e.Warning + Error = e.Error + InterfaceError = e.InterfaceError + DatabaseError = e.DatabaseError + DataError = e.DataError + OperationalError = e.OperationalError + IntegrityError = e.IntegrityError + InternalError = e.InternalError + ProgrammingError = e.ProgrammingError + NotSupportedError = e.NotSupportedError + + # Enums useful for the connection + ConnStatus = pq.ConnStatus + TransactionStatus = pq.TransactionStatus + + def __init__(self, pgconn: "PGconn"): + self.pgconn = pgconn + self._autocommit = False + + # None, but set to a copy of the global adapters map as soon as requested. + self._adapters: Optional[AdaptersMap] = None + + self._notice_handlers: List[NoticeHandler] = [] + self._notify_handlers: List[NotifyHandler] = [] + + # Number of transaction blocks currently entered + self._num_transactions = 0 + + self._closed = False # closed by an explicit close() + self._prepared: PrepareManager = PrepareManager() + self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared + + wself = ref(self) + pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) + pgconn.notify_handler = partial(BaseConnection._notify_handler, wself) + + # Attribute is only set if the connection is from a pool so we can tell + # apart a connection in the pool too (when _pool = None) + self._pool: Optional["BasePool[Any]"] + + self._pipeline: Optional[BasePipeline] = None + + # Time after which the connection should be closed + self._expire_at: float + + self._isolation_level: Optional[IsolationLevel] = None + self._read_only: Optional[bool] = None + self._deferrable: Optional[bool] = None + self._begin_statement = b"" + + def __del__(self) -> None: + # If fails on connection we might not have this attribute yet + if not hasattr(self, "pgconn"): + return + + # Connection correctly closed + if self.closed: + return + + # Connection in a pool so terminating with the program is normal + if hasattr(self, "_pool"): + return + + warn( + f"connection {self} was deleted while still open." + " Please use 'with' or '.close()' to close the connection", + ResourceWarning, + ) + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self.pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + @property + def closed(self) -> bool: + """`!True` if the connection is closed.""" + return self.pgconn.status == BAD + + @property + def broken(self) -> bool: + """ + `!True` if the connection was interrupted. + + A broken connection is always `closed`, but wasn't closed in a clean + way, such as using `close()` or a `!with` block. + """ + return self.pgconn.status == BAD and not self._closed + + @property + def autocommit(self) -> bool: + """The autocommit state of the connection.""" + return self._autocommit + + @autocommit.setter + def autocommit(self, value: bool) -> None: + self._set_autocommit(value) + + def _set_autocommit(self, value: bool) -> None: + raise NotImplementedError + + def _set_autocommit_gen(self, value: bool) -> PQGen[None]: + yield from self._check_intrans_gen("autocommit") + self._autocommit = bool(value) + + @property + def isolation_level(self) -> Optional[IsolationLevel]: + """ + The isolation level of the new transactions started on the connection. 
+ """ + return self._isolation_level + + @isolation_level.setter + def isolation_level(self, value: Optional[IsolationLevel]) -> None: + self._set_isolation_level(value) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + raise NotImplementedError + + def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]: + yield from self._check_intrans_gen("isolation_level") + self._isolation_level = IsolationLevel(value) if value is not None else None + self._begin_statement = b"" + + @property + def read_only(self) -> Optional[bool]: + """ + The read-only state of the new transactions started on the connection. + """ + return self._read_only + + @read_only.setter + def read_only(self, value: Optional[bool]) -> None: + self._set_read_only(value) + + def _set_read_only(self, value: Optional[bool]) -> None: + raise NotImplementedError + + def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]: + yield from self._check_intrans_gen("read_only") + self._read_only = bool(value) + self._begin_statement = b"" + + @property + def deferrable(self) -> Optional[bool]: + """ + The deferrable state of the new transactions started on the connection. + """ + return self._deferrable + + @deferrable.setter + def deferrable(self, value: Optional[bool]) -> None: + self._set_deferrable(value) + + def _set_deferrable(self, value: Optional[bool]) -> None: + raise NotImplementedError + + def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]: + yield from self._check_intrans_gen("deferrable") + self._deferrable = bool(value) + self._begin_statement = b"" + + def _check_intrans_gen(self, attribute: str) -> PQGen[None]: + # Raise an exception if we are in a transaction + status = self.pgconn.transaction_status + if status == IDLE and self._pipeline: + yield from self._pipeline._sync_gen() + status = self.pgconn.transaction_status + if status != IDLE: + if self._num_transactions: + raise e.ProgrammingError( + f"can't change {attribute!r} now: " + "connection.transaction() context in progress" + ) + else: + raise e.ProgrammingError( + f"can't change {attribute!r} now: " + "connection in transaction status " + f"{pq.TransactionStatus(status).name}" + ) + + @property + def info(self) -> ConnectionInfo: + """A `ConnectionInfo` attribute to inspect connection properties.""" + return ConnectionInfo(self.pgconn) + + @property + def adapters(self) -> AdaptersMap: + if not self._adapters: + self._adapters = AdaptersMap(postgres.adapters) + + return self._adapters + + @property + def connection(self) -> "BaseConnection[Row]": + # implement the AdaptContext protocol + return self + + def fileno(self) -> int: + """Return the file descriptor of the connection. + + This function allows to use the connection as file-like object in + functions waiting for readiness, such as the ones defined in the + `selectors` module. + """ + return self.pgconn.socket + + def cancel(self) -> None: + """Cancel the current operation on the connection.""" + # No-op if the connection is closed + # this allows to use the method as callback handler without caring + # about its life. + if self.closed: + return + + if self._tpc and self._tpc[1]: + raise e.ProgrammingError( + "cancel() cannot be used with a prepared two-phase transaction" + ) + + c = self.pgconn.get_cancel() + c.cancel() + + def add_notice_handler(self, callback: NoticeHandler) -> None: + """ + Register a callable to be invoked when a notice message is received. 
+ + :param callback: the callback to call upon message received. + :type callback: Callable[[~psycopg.errors.Diagnostic], None] + """ + self._notice_handlers.append(callback) + + def remove_notice_handler(self, callback: NoticeHandler) -> None: + """ + Unregister a notice message callable previously registered. + + :param callback: the callback to remove. + :type callback: Callable[[~psycopg.errors.Diagnostic], None] + """ + self._notice_handlers.remove(callback) + + @staticmethod + def _notice_handler( + wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult" + ) -> None: + self = wself() + if not (self and self._notice_handlers): + return + + diag = e.Diagnostic(res, pgconn_encoding(self.pgconn)) + for cb in self._notice_handlers: + try: + cb(diag) + except Exception as ex: + logger.exception("error processing notice callback '%s': %s", cb, ex) + + def add_notify_handler(self, callback: NotifyHandler) -> None: + """ + Register a callable to be invoked whenever a notification is received. + + :param callback: the callback to call upon notification received. + :type callback: Callable[[~psycopg.Notify], None] + """ + self._notify_handlers.append(callback) + + def remove_notify_handler(self, callback: NotifyHandler) -> None: + """ + Unregister a notification callable previously registered. + + :param callback: the callback to remove. + :type callback: Callable[[~psycopg.Notify], None] + """ + self._notify_handlers.remove(callback) + + @staticmethod + def _notify_handler( + wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify + ) -> None: + self = wself() + if not (self and self._notify_handlers): + return + + enc = pgconn_encoding(self.pgconn) + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + for cb in self._notify_handlers: + cb(n) + + @property + def prepare_threshold(self) -> Optional[int]: + """ + Number of times a query is executed before it is prepared. + + - If it is set to 0, every query is prepared the first time it is + executed. + - If it is set to `!None`, prepared statements are disabled on the + connection. + + Default value: 5 + """ + return self._prepared.prepare_threshold + + @prepare_threshold.setter + def prepare_threshold(self, value: Optional[int]) -> None: + self._prepared.prepare_threshold = value + + @property + def prepared_max(self) -> int: + """ + Maximum number of prepared statements on the connection. + + Default value: 100 + """ + return self._prepared.prepared_max + + @prepared_max.setter + def prepared_max(self, value: int) -> None: + self._prepared.prepared_max = value + + # Generators to perform high-level operations on the connection + # + # These operations are expressed in terms of non-blocking generators + # and the task of waiting when needed (when the generators yield) is left + # to the connections subclass, which might wait either in blocking mode + # or through asyncio. + # + # All these generators assume exclusive access to the connection: subclasses + # should have a lock and hold it before calling and consuming them. 
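+
+    # Usage sketch for the handlers and prepared-statement knobs above
+    # (illustrative, not part of the class):
+    #
+    #     def log_notice(diag):
+    #         # diag is a psycopg.errors.Diagnostic instance
+    #         print(f"{diag.severity}: {diag.message_primary}")
+    #
+    #     conn.add_notice_handler(log_notice)
+    #     conn.add_notify_handler(lambda n: print(f"{n.channel}: {n.payload}"))
+    #     conn.prepare_threshold = 0  # prepare every statement at first use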
+ + @classmethod + def _connect_gen( + cls: Type[ConnectionType], + conninfo: str = "", + *, + autocommit: bool = False, + ) -> PQGenConn[ConnectionType]: + """Generator to connect to the database and create a new instance.""" + pgconn = yield from connect(conninfo) + conn = cls(pgconn) + conn._autocommit = bool(autocommit) + return conn + + def _exec_command( + self, command: Query, result_format: pq.Format = TEXT + ) -> PQGen[Optional["PGresult"]]: + """ + Generator to send a command and receive the result to the backend. + + Only used to implement internal commands such as "commit", with eventual + arguments bound client-side. The cursor can do more complex stuff. + """ + self._check_connection_ok() + + if isinstance(command, str): + command = command.encode(pgconn_encoding(self.pgconn)) + elif isinstance(command, Composable): + command = command.as_bytes(self) + + if self._pipeline: + cmd = partial( + self.pgconn.send_query_params, + command, + None, + result_format=result_format, + ) + self._pipeline.command_queue.append(cmd) + self._pipeline.result_queue.append(None) + return None + + self.pgconn.send_query_params(command, None, result_format=result_format) + + result = (yield from execute(self.pgconn))[-1] + if result.status != COMMAND_OK and result.status != TUPLES_OK: + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn)) + else: + raise e.InterfaceError( + f"unexpected result {pq.ExecStatus(result.status).name}" + f" from command {command.decode()!r}" + ) + return result + + def _check_connection_ok(self) -> None: + if self.pgconn.status == OK: + return + + if self.pgconn.status == BAD: + raise e.OperationalError("the connection is closed") + raise e.InterfaceError( + "cannot execute operations: the connection is" + f" in status {self.pgconn.status}" + ) + + def _start_query(self) -> PQGen[None]: + """Generator to start a transaction if necessary.""" + if self._autocommit: + return + + if self.pgconn.transaction_status != IDLE: + return + + yield from self._exec_command(self._get_tx_start_command()) + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _get_tx_start_command(self) -> bytes: + if self._begin_statement: + return self._begin_statement + + parts = [b"BEGIN"] + + if self.isolation_level is not None: + val = IsolationLevel(self.isolation_level) + parts.append(b"ISOLATION LEVEL") + parts.append(val.name.replace("_", " ").encode()) + + if self.read_only is not None: + parts.append(b"READ ONLY" if self.read_only else b"READ WRITE") + + if self.deferrable is not None: + parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE") + + self._begin_statement = b" ".join(parts) + return self._begin_statement + + def _commit_gen(self) -> PQGen[None]: + """Generator implementing `Connection.commit()`.""" + if self._num_transactions: + raise e.ProgrammingError( + "Explicit commit() forbidden within a Transaction " + "context. (Transaction will be automatically committed " + "on successful exit from context.)" + ) + if self._tpc: + raise e.ProgrammingError( + "commit() cannot be used during a two-phase transaction" + ) + if self.pgconn.transaction_status == IDLE: + return + + yield from self._exec_command(b"COMMIT") + + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _rollback_gen(self) -> PQGen[None]: + """Generator implementing `Connection.rollback()`.""" + if self._num_transactions: + raise e.ProgrammingError( + "Explicit rollback() forbidden within a Transaction " + "context. 
(Either raise Rollback() or allow " + "an exception to propagate out of the context.)" + ) + if self._tpc: + raise e.ProgrammingError( + "rollback() cannot be used during a two-phase transaction" + ) + + # Get out of a "pipeline aborted" state + if self._pipeline: + yield from self._pipeline._sync_gen() + + if self.pgconn.transaction_status == IDLE: + return + + yield from self._exec_command(b"ROLLBACK") + self._prepared.clear() + for cmd in self._prepared.get_maintenance_commands(): + yield from self._exec_command(cmd) + + if self._pipeline: + yield from self._pipeline._sync_gen() + + def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid: + """ + Returns a `Xid` to pass to the `!tpc_*()` methods of this connection. + + The argument types and constraints are explained in + :ref:`two-phase-commit`. + + The values passed to the method will be available on the returned + object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`. + """ + self._check_tpc() + return Xid.from_parts(format_id, gtrid, bqual) + + def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]: + self._check_tpc() + + if not isinstance(xid, Xid): + xid = Xid.from_string(xid) + + if self.pgconn.transaction_status != IDLE: + raise e.ProgrammingError( + "can't start two-phase transaction: connection in status" + f" {pq.TransactionStatus(self.pgconn.transaction_status).name}" + ) + + if self._autocommit: + raise e.ProgrammingError( + "can't use two-phase transactions in autocommit mode" + ) + + self._tpc = (xid, False) + yield from self._exec_command(self._get_tx_start_command()) + + def _tpc_prepare_gen(self) -> PQGen[None]: + if not self._tpc: + raise e.ProgrammingError( + "'tpc_prepare()' must be called inside a two-phase transaction" + ) + if self._tpc[1]: + raise e.ProgrammingError( + "'tpc_prepare()' cannot be used during a prepared two-phase transaction" + ) + xid = self._tpc[0] + self._tpc = (xid, True) + yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid))) + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _tpc_finish_gen( + self, action: LiteralString, xid: Union[Xid, str, None] + ) -> PQGen[None]: + fname = f"tpc_{action.lower()}()" + if xid is None: + if not self._tpc: + raise e.ProgrammingError( + f"{fname} without xid must must be" + " called inside a two-phase transaction" + ) + xid = self._tpc[0] + else: + if self._tpc: + raise e.ProgrammingError( + f"{fname} with xid must must be called" + " outside a two-phase transaction" + ) + if not isinstance(xid, Xid): + xid = Xid.from_string(xid) + + if self._tpc and not self._tpc[1]: + meth: Callable[[], PQGen[None]] + meth = getattr(self, f"_{action.lower()}_gen") + self._tpc = None + yield from meth() + else: + yield from self._exec_command( + SQL("{} PREPARED {}").format(SQL(action), str(xid)) + ) + self._tpc = None + + def _check_tpc(self) -> None: + """Raise NotSupportedError if TPC is not supported.""" + # TPC supported on every supported PostgreSQL version. + pass + + +class Connection(BaseConnection[Row]): + """ + Wrapper for a connection to the database. 
+ """ + + __module__ = "psycopg" + + cursor_factory: Type[Cursor[Row]] + server_cursor_factory: Type[ServerCursor[Row]] + row_factory: RowFactory[Row] + _pipeline: Optional[Pipeline] + _Self = TypeVar("_Self", bound="Connection[Any]") + + def __init__( + self, + pgconn: "PGconn", + row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row), + ): + super().__init__(pgconn) + self.row_factory = row_factory + self.lock = threading.Lock() + self.cursor_factory = Cursor + self.server_cursor_factory = ServerCursor + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[Row], + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[Cursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "Connection[Row]": + # TODO: returned type should be _Self. See #308. + ... + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[Cursor[Any]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "Connection[TupleRow]": + ... + + @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004 + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: Optional[RowFactory[Row]] = None, + cursor_factory: Optional[Type[Cursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Any, + ) -> "Connection[Any]": + """ + Connect to a database server and return a new `Connection` instance. + """ + params = cls._get_connection_params(conninfo, **kwargs) + conninfo = make_conninfo(**params) + + try: + rv = cls._wait_conn( + cls._connect_gen(conninfo, autocommit=autocommit), + timeout=params["connect_timeout"], + ) + except e.Error as ex: + raise ex.with_traceback(None) + + if row_factory: + rv.row_factory = row_factory + if cursor_factory: + rv.cursor_factory = cursor_factory + if context: + rv._adapters = AdaptersMap(context.adapters) + rv.prepare_threshold = prepare_threshold + return rv + + def __enter__(self: _Self) -> _Self: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self.closed: + return + + if exc_type: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + self.rollback() + except Exception as exc2: + logger.warning( + "error ignored in rollback on %s: %s", + self, + exc2, + ) + else: + self.commit() + + # Close the connection only if it doesn't belong to a pool. + if not getattr(self, "_pool", None): + self.close() + + @classmethod + def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]: + """Manipulate connection parameters before connecting. + + :param conninfo: Connection string as received by `~Connection.connect()`. + :param kwargs: Overriding connection arguments as received by `!connect()`. + :return: Connection arguments merged and eventually modified, in a + format similar to `~conninfo.conninfo_to_dict()`. 
+ """ + params = conninfo_to_dict(conninfo, **kwargs) + + # Make sure there is an usable connect_timeout + if "connect_timeout" in params: + params["connect_timeout"] = int(params["connect_timeout"]) + else: + params["connect_timeout"] = None + + return params + + def close(self) -> None: + """Close the database connection.""" + if self.closed: + return + self._closed = True + self.pgconn.finish() + + @overload + def cursor(self, *, binary: bool = False) -> Cursor[Row]: + ... + + @overload + def cursor( + self, *, binary: bool = False, row_factory: RowFactory[CursorRow] + ) -> Cursor[CursorRow]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> ServerCursor[Row]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + row_factory: RowFactory[CursorRow], + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> ServerCursor[CursorRow]: + ... + + def cursor( + self, + name: str = "", + *, + binary: bool = False, + row_factory: Optional[RowFactory[Any]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> Union[Cursor[Any], ServerCursor[Any]]: + """ + Return a new cursor to send commands and queries to the connection. + """ + self._check_connection_ok() + + if not row_factory: + row_factory = self.row_factory + + cur: Union[Cursor[Any], ServerCursor[Any]] + if name: + cur = self.server_cursor_factory( + self, + name=name, + row_factory=row_factory, + scrollable=scrollable, + withhold=withhold, + ) + else: + cur = self.cursor_factory(self, row_factory=row_factory) + + if binary: + cur.format = BINARY + + return cur + + def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: bool = False, + ) -> Cursor[Row]: + """Execute a query and return a cursor to read its results.""" + try: + cur = self.cursor() + if binary: + cur.format = BINARY + + return cur.execute(query, params, prepare=prepare) + + except e.Error as ex: + raise ex.with_traceback(None) + + def commit(self) -> None: + """Commit any pending transaction to the database.""" + with self.lock: + self.wait(self._commit_gen()) + + def rollback(self) -> None: + """Roll back to the start of any pending transaction.""" + with self.lock: + self.wait(self._rollback_gen()) + + @contextmanager + def transaction( + self, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ) -> Iterator[Transaction]: + """ + Start a context block with a new transaction or nested transaction. + + :param savepoint_name: Name of the savepoint used to manage a nested + transaction. If `!None`, one will be chosen automatically. + :param force_rollback: Roll back the transaction at the end of the + block even if there were no error (e.g. to try a no-op process). + :rtype: Transaction + """ + tx = Transaction(self, savepoint_name, force_rollback) + if self._pipeline: + with self.pipeline(), tx, self.pipeline(): + yield tx + else: + with tx: + yield tx + + def notifies(self) -> Generator[Notify, None, None]: + """ + Yield `Notify` objects as soon as they are received from the database. 
+ """ + while True: + with self.lock: + try: + ns = self.wait(notifies(self.pgconn)) + except e.Error as ex: + raise ex.with_traceback(None) + enc = pgconn_encoding(self.pgconn) + for pgn in ns: + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + yield n + + @contextmanager + def pipeline(self) -> Iterator[Pipeline]: + """Switch the connection into pipeline mode.""" + with self.lock: + self._check_connection_ok() + + pipeline = self._pipeline + if pipeline is None: + # WARNING: reference loop, broken ahead. + pipeline = self._pipeline = Pipeline(self) + + try: + with pipeline: + yield pipeline + finally: + if pipeline.level == 0: + with self.lock: + assert pipeline is self._pipeline + self._pipeline = None + + def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: + """ + Consume a generator operating on the connection. + + The function must be used on generators that don't change connection + fd (i.e. not on connect and reset). + """ + try: + return waiting.wait(gen, self.pgconn.socket, timeout=timeout) + except KeyboardInterrupt: + # On Ctrl-C, try to cancel the query in the server, otherwise + # the connection will remain stuck in ACTIVE state. + c = self.pgconn.get_cancel() + c.cancel() + try: + waiting.wait(gen, self.pgconn.socket, timeout=timeout) + except e.QueryCanceled: + pass # as expected + raise + + @classmethod + def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: + """Consume a connection generator.""" + return waiting.wait_conn(gen, timeout=timeout) + + def _set_autocommit(self, value: bool) -> None: + with self.lock: + self.wait(self._set_autocommit_gen(value)) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + with self.lock: + self.wait(self._set_isolation_level_gen(value)) + + def _set_read_only(self, value: Optional[bool]) -> None: + with self.lock: + self.wait(self._set_read_only_gen(value)) + + def _set_deferrable(self, value: Optional[bool]) -> None: + with self.lock: + self.wait(self._set_deferrable_gen(value)) + + def tpc_begin(self, xid: Union[Xid, str]) -> None: + """ + Begin a TPC transaction with the given transaction ID `!xid`. + """ + with self.lock: + self.wait(self._tpc_begin_gen(xid)) + + def tpc_prepare(self) -> None: + """ + Perform the first phase of a transaction started with `tpc_begin()`. + """ + try: + with self.lock: + self.wait(self._tpc_prepare_gen()) + except e.ObjectNotInPrerequisiteState as ex: + raise e.NotSupportedError(str(ex)) from None + + def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None: + """ + Commit a prepared two-phase transaction. + """ + with self.lock: + self.wait(self._tpc_finish_gen("COMMIT", xid)) + + def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None: + """ + Roll back a prepared two-phase transaction. 
+ """ + with self.lock: + self.wait(self._tpc_finish_gen("ROLLBACK", xid)) + + def tpc_recover(self) -> List[Xid]: + self._check_tpc() + status = self.info.transaction_status + with self.cursor(row_factory=args_row(Xid._from_record)) as cur: + cur.execute(Xid._get_recover_query()) + res = cur.fetchall() + + if status == IDLE and self.info.transaction_status == INTRANS: + self.rollback() + + return res diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py new file mode 100644 index 0000000..aa02dc0 --- /dev/null +++ b/psycopg/psycopg/connection_async.py @@ -0,0 +1,436 @@ +""" +psycopg async connection objects +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import asyncio +import logging +from types import TracebackType +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional +from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING +from contextlib import asynccontextmanager + +from . import pq +from . import errors as e +from . import waiting +from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV +from ._tpc import Xid +from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row +from .adapt import AdaptersMap +from ._enums import IsolationLevel +from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async +from ._pipeline import AsyncPipeline +from ._encodings import pgconn_encoding +from .connection import BaseConnection, CursorRow, Notify +from .generators import notifies +from .transaction import AsyncTransaction +from .cursor_async import AsyncCursor +from .server_cursor import AsyncServerCursor + +if TYPE_CHECKING: + from .pq.abc import PGconn + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + +logger = logging.getLogger("psycopg") + + +class AsyncConnection(BaseConnection[Row]): + """ + Asynchronous wrapper for a connection to the database. + """ + + __module__ = "psycopg" + + cursor_factory: Type[AsyncCursor[Row]] + server_cursor_factory: Type[AsyncServerCursor[Row]] + row_factory: AsyncRowFactory[Row] + _pipeline: Optional[AsyncPipeline] + _Self = TypeVar("_Self", bound="AsyncConnection[Any]") + + def __init__( + self, + pgconn: "PGconn", + row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row), + ): + super().__init__(pgconn) + self.row_factory = row_factory + self.lock = asyncio.Lock() + self.cursor_factory = AsyncCursor + self.server_cursor_factory = AsyncServerCursor + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: AsyncRowFactory[Row], + cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncConnection[Row]": + # TODO: returned type should be _Self. See #308. + ... + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[AsyncCursor[Any]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncConnection[TupleRow]": + ... 
+ + @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004 + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + context: Optional[AdaptContext] = None, + row_factory: Optional[AsyncRowFactory[Row]] = None, + cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, + **kwargs: Any, + ) -> "AsyncConnection[Any]": + + if sys.platform == "win32": + loop = asyncio.get_running_loop() + if isinstance(loop, asyncio.ProactorEventLoop): + raise e.InterfaceError( + "Psycopg cannot use the 'ProactorEventLoop' to run in async" + " mode. Please use a compatible event loop, for instance by" + " setting 'asyncio.set_event_loop_policy" + "(WindowsSelectorEventLoopPolicy())'" + ) + + params = await cls._get_connection_params(conninfo, **kwargs) + conninfo = make_conninfo(**params) + + try: + rv = await cls._wait_conn( + cls._connect_gen(conninfo, autocommit=autocommit), + timeout=params["connect_timeout"], + ) + except e.Error as ex: + raise ex.with_traceback(None) + + if row_factory: + rv.row_factory = row_factory + if cursor_factory: + rv.cursor_factory = cursor_factory + if context: + rv._adapters = AdaptersMap(context.adapters) + rv.prepare_threshold = prepare_threshold + return rv + + async def __aenter__(self: _Self) -> _Self: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self.closed: + return + + if exc_type: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + await self.rollback() + except Exception as exc2: + logger.warning( + "error ignored in rollback on %s: %s", + self, + exc2, + ) + else: + await self.commit() + + # Close the connection only if it doesn't belong to a pool. + if not getattr(self, "_pool", None): + await self.close() + + @classmethod + async def _get_connection_params( + cls, conninfo: str, **kwargs: Any + ) -> Dict[str, Any]: + """Manipulate connection parameters before connecting. + + .. versionchanged:: 3.1 + Unlike the sync counterpart, perform non-blocking address + resolution and populate the ``hostaddr`` connection parameter, + unless the user has provided one themselves. See + `~psycopg._dns.resolve_hostaddr_async()` for details. + + """ + params = conninfo_to_dict(conninfo, **kwargs) + + # Make sure there is an usable connect_timeout + if "connect_timeout" in params: + params["connect_timeout"] = int(params["connect_timeout"]) + else: + params["connect_timeout"] = None + + # Resolve host addresses in non-blocking way + params = await resolve_hostaddr_async(params) + + return params + + async def close(self) -> None: + if self.closed: + return + self._closed = True + self.pgconn.finish() + + @overload + def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: + ... + + @overload + def cursor( + self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow] + ) -> AsyncCursor[CursorRow]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> AsyncServerCursor[Row]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + row_factory: AsyncRowFactory[CursorRow], + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> AsyncServerCursor[CursorRow]: + ... 
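+
+    # Illustrative sketch: passing a name creates a server-side cursor, which
+    # fetches rows in batches instead of loading the whole result set at once.
+    # "handle_record" is a hypothetical callback, not part of psycopg.
+    #
+    #     async with aconn.cursor("big_query") as cur:
+    #         await cur.execute("SELECT generate_series(1, 1000000)")
+    #         async for record in cur:
+    #             handle_record(record)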
+ + def cursor( + self, + name: str = "", + *, + binary: bool = False, + row_factory: Optional[AsyncRowFactory[Any]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]: + """ + Return a new `AsyncCursor` to send commands and queries to the connection. + """ + self._check_connection_ok() + + if not row_factory: + row_factory = self.row_factory + + cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]] + if name: + cur = self.server_cursor_factory( + self, + name=name, + row_factory=row_factory, + scrollable=scrollable, + withhold=withhold, + ) + else: + cur = self.cursor_factory(self, row_factory=row_factory) + + if binary: + cur.format = BINARY + + return cur + + async def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: bool = False, + ) -> AsyncCursor[Row]: + try: + cur = self.cursor() + if binary: + cur.format = BINARY + + return await cur.execute(query, params, prepare=prepare) + + except e.Error as ex: + raise ex.with_traceback(None) + + async def commit(self) -> None: + async with self.lock: + await self.wait(self._commit_gen()) + + async def rollback(self) -> None: + async with self.lock: + await self.wait(self._rollback_gen()) + + @asynccontextmanager + async def transaction( + self, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ) -> AsyncIterator[AsyncTransaction]: + """ + Start a context block with a new transaction or nested transaction. + + :rtype: AsyncTransaction + """ + tx = AsyncTransaction(self, savepoint_name, force_rollback) + if self._pipeline: + async with self.pipeline(), tx, self.pipeline(): + yield tx + else: + async with tx: + yield tx + + async def notifies(self) -> AsyncGenerator[Notify, None]: + while True: + async with self.lock: + try: + ns = await self.wait(notifies(self.pgconn)) + except e.Error as ex: + raise ex.with_traceback(None) + enc = pgconn_encoding(self.pgconn) + for pgn in ns: + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + yield n + + @asynccontextmanager + async def pipeline(self) -> AsyncIterator[AsyncPipeline]: + """Context manager to switch the connection into pipeline mode.""" + async with self.lock: + self._check_connection_ok() + + pipeline = self._pipeline + if pipeline is None: + # WARNING: reference loop, broken ahead. + pipeline = self._pipeline = AsyncPipeline(self) + + try: + async with pipeline: + yield pipeline + finally: + if pipeline.level == 0: + async with self.lock: + assert pipeline is self._pipeline + self._pipeline = None + + async def wait(self, gen: PQGen[RV]) -> RV: + try: + return await waiting.wait_async(gen, self.pgconn.socket) + except KeyboardInterrupt: + # TODO: this doesn't seem to work as it does for sync connections + # see tests/test_concurrency_async.py::test_ctrl_c + # In the test, the code doesn't reach this branch. 
+ + # On Ctrl-C, try to cancel the query in the server, otherwise + # otherwise the connection will be stuck in ACTIVE state + c = self.pgconn.get_cancel() + c.cancel() + try: + await waiting.wait_async(gen, self.pgconn.socket) + except e.QueryCanceled: + pass # as expected + raise + + @classmethod + async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: + return await waiting.wait_conn_async(gen, timeout) + + def _set_autocommit(self, value: bool) -> None: + self._no_set_async("autocommit") + + async def set_autocommit(self, value: bool) -> None: + """Async version of the `~Connection.autocommit` setter.""" + async with self.lock: + await self.wait(self._set_autocommit_gen(value)) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + self._no_set_async("isolation_level") + + async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + """Async version of the `~Connection.isolation_level` setter.""" + async with self.lock: + await self.wait(self._set_isolation_level_gen(value)) + + def _set_read_only(self, value: Optional[bool]) -> None: + self._no_set_async("read_only") + + async def set_read_only(self, value: Optional[bool]) -> None: + """Async version of the `~Connection.read_only` setter.""" + async with self.lock: + await self.wait(self._set_read_only_gen(value)) + + def _set_deferrable(self, value: Optional[bool]) -> None: + self._no_set_async("deferrable") + + async def set_deferrable(self, value: Optional[bool]) -> None: + """Async version of the `~Connection.deferrable` setter.""" + async with self.lock: + await self.wait(self._set_deferrable_gen(value)) + + def _no_set_async(self, attribute: str) -> None: + raise AttributeError( + f"'the {attribute!r} property is read-only on async connections:" + f" please use 'await .set_{attribute}()' instead." + ) + + async def tpc_begin(self, xid: Union[Xid, str]) -> None: + async with self.lock: + await self.wait(self._tpc_begin_gen(xid)) + + async def tpc_prepare(self) -> None: + try: + async with self.lock: + await self.wait(self._tpc_prepare_gen()) + except e.ObjectNotInPrerequisiteState as ex: + raise e.NotSupportedError(str(ex)) from None + + async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None: + async with self.lock: + await self.wait(self._tpc_finish_gen("commit", xid)) + + async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None: + async with self.lock: + await self.wait(self._tpc_finish_gen("rollback", xid)) + + async def tpc_recover(self) -> List[Xid]: + self._check_tpc() + status = self.info.transaction_status + async with self.cursor(row_factory=args_row(Xid._from_record)) as cur: + await cur.execute(Xid._get_recover_query()) + res = await cur.fetchall() + + if status == IDLE and self.info.transaction_status == INTRANS: + await self.rollback() + + return res diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py new file mode 100644 index 0000000..3b21f83 --- /dev/null +++ b/psycopg/psycopg/conninfo.py @@ -0,0 +1,378 @@ +""" +Functions to manipulate conninfo strings +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +import socket +import asyncio +from typing import Any, Dict, List, Optional +from pathlib import Path +from datetime import tzinfo +from functools import lru_cache +from ipaddress import ip_address + +from . import pq +from . 
import errors as e +from ._tz import get_tzinfo +from ._encodings import pgconn_encoding + + +def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: + """ + Merge a string and keyword params into a single conninfo string. + + :param conninfo: A `connection string`__ as accepted by PostgreSQL. + :param kwargs: Parameters overriding the ones specified in `!conninfo`. + :return: A connection string valid for PostgreSQL, with the `!kwargs` + parameters merged. + + Raise `~psycopg.ProgrammingError` if the input doesn't make a valid + conninfo string. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-CONNSTRING + """ + if not conninfo and not kwargs: + return "" + + # If no kwarg specified don't mung the conninfo but check if it's correct. + # Make sure to return a string, not a subtype, to avoid making Liskov sad. + if not kwargs: + _parse_conninfo(conninfo) + return str(conninfo) + + # Override the conninfo with the parameters + # Drop the None arguments + kwargs = {k: v for (k, v) in kwargs.items() if v is not None} + + if conninfo: + tmp = conninfo_to_dict(conninfo) + tmp.update(kwargs) + kwargs = tmp + + conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items()) + + # Verify the result is valid + _parse_conninfo(conninfo) + + return conninfo + + +def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]: + """ + Convert the `!conninfo` string into a dictionary of parameters. + + :param conninfo: A `connection string`__ as accepted by PostgreSQL. + :param kwargs: Parameters overriding the ones specified in `!conninfo`. + :return: Dictionary with the parameters parsed from `!conninfo` and + `!kwargs`. + + Raise `~psycopg.ProgrammingError` if `!conninfo` is not a a valid connection + string. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-CONNSTRING + """ + opts = _parse_conninfo(conninfo) + rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None} + for k, v in kwargs.items(): + if v is not None: + rv[k] = v + return rv + + +def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]: + """ + Verify that `!conninfo` is a valid connection string. + + Raise ProgrammingError if the string is not valid. + + Return the result of pq.Conninfo.parse() on success. + """ + try: + return pq.Conninfo.parse(conninfo.encode()) + except e.OperationalError as ex: + raise e.ProgrammingError(str(ex)) + + +re_escape = re.compile(r"([\\'])") +re_space = re.compile(r"\s") + + +def _param_escape(s: str) -> str: + """ + Apply the escaping rule required by PQconnectdb + """ + if not s: + return "''" + + s = re_escape.sub(r"\\\1", s) + if re_space.search(s): + s = "'" + s + "'" + + return s + + +class ConnectionInfo: + """Allow access to information about the connection.""" + + __module__ = "psycopg" + + def __init__(self, pgconn: pq.abc.PGconn): + self.pgconn = pgconn + + @property + def vendor(self) -> str: + """A string representing the database vendor connected to.""" + return "PostgreSQL" + + @property + def host(self) -> str: + """The server host name of the active connection. See :pq:`PQhost()`.""" + return self._get_pgconn_attr("host") + + @property + def hostaddr(self) -> str: + """The server IP address of the connection. See :pq:`PQhostaddr()`.""" + return self._get_pgconn_attr("hostaddr") + + @property + def port(self) -> int: + """The port of the active connection. 
See :pq:`PQport()`.""" + return int(self._get_pgconn_attr("port")) + + @property + def dbname(self) -> str: + """The database name of the connection. See :pq:`PQdb()`.""" + return self._get_pgconn_attr("db") + + @property + def user(self) -> str: + """The user name of the connection. See :pq:`PQuser()`.""" + return self._get_pgconn_attr("user") + + @property + def password(self) -> str: + """The password of the connection. See :pq:`PQpass()`.""" + return self._get_pgconn_attr("password") + + @property + def options(self) -> str: + """ + The command-line options passed in the connection request. + See :pq:`PQoptions`. + """ + return self._get_pgconn_attr("options") + + def get_parameters(self) -> Dict[str, str]: + """Return the connection parameters values. + + Return all the parameters set to a non-default value, which might come + either from the connection string and parameters passed to + `~Connection.connect()` or from environment variables. The password + is never returned (you can read it using the `password` attribute). + """ + pyenc = self.encoding + + # Get the known defaults to avoid reporting them + defaults = { + i.keyword: i.compiled + for i in pq.Conninfo.get_defaults() + if i.compiled is not None + } + # Not returned by the libq. Bug? Bet we're using SSH. + defaults.setdefault(b"channel_binding", b"prefer") + defaults[b"passfile"] = str(Path.home() / ".pgpass").encode() + + return { + i.keyword.decode(pyenc): i.val.decode(pyenc) + for i in self.pgconn.info + if i.val is not None + and i.keyword != b"password" + and i.val != defaults.get(i.keyword) + } + + @property + def dsn(self) -> str: + """Return the connection string to connect to the database. + + The string contains all the parameters set to a non-default value, + which might come either from the connection string and parameters + passed to `~Connection.connect()` or from environment variables. The + password is never returned (you can read it using the `password` + attribute). + """ + return make_conninfo(**self.get_parameters()) + + @property + def status(self) -> pq.ConnStatus: + """The status of the connection. See :pq:`PQstatus()`.""" + return pq.ConnStatus(self.pgconn.status) + + @property + def transaction_status(self) -> pq.TransactionStatus: + """ + The current in-transaction status of the session. + See :pq:`PQtransactionStatus()`. + """ + return pq.TransactionStatus(self.pgconn.transaction_status) + + @property + def pipeline_status(self) -> pq.PipelineStatus: + """ + The current pipeline status of the client. + See :pq:`PQpipelineStatus()`. + """ + return pq.PipelineStatus(self.pgconn.pipeline_status) + + def parameter_status(self, param_name: str) -> Optional[str]: + """ + Return a parameter setting of the connection. + + Return `None` is the parameter is unknown. + """ + res = self.pgconn.parameter_status(param_name.encode(self.encoding)) + return res.decode(self.encoding) if res is not None else None + + @property + def server_version(self) -> int: + """ + An integer representing the server version. See :pq:`PQserverVersion()`. + """ + return self.pgconn.server_version + + @property + def backend_pid(self) -> int: + """ + The process ID (PID) of the backend process handling this connection. + See :pq:`PQbackendPID()`. + """ + return self.pgconn.backend_pid + + @property + def error_message(self) -> str: + """ + The error message most recently generated by an operation on the connection. + See :pq:`PQerrorMessage()`. 
+ """ + return self._get_pgconn_attr("error_message") + + @property + def timezone(self) -> tzinfo: + """The Python timezone info of the connection's timezone.""" + return get_tzinfo(self.pgconn) + + @property + def encoding(self) -> str: + """The Python codec name of the connection's client encoding.""" + return pgconn_encoding(self.pgconn) + + def _get_pgconn_attr(self, name: str) -> str: + value: bytes = getattr(self.pgconn, name) + return value.decode(self.encoding) + + +async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform async DNS lookup of the hosts and return a new params dict. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses dynamically. + + The function may change the input ``host``, ``hostname``, ``port`` to allow + connecting without further DNS lookups, eventually removing hosts that are + not resolved, keeping the lists of hosts and ports consistent. + + Raise `~psycopg.OperationalError` if connection is not possible (e.g. no + host resolve, inconsistent lists length). + """ + hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", "")) + if hostaddr_arg: + # Already resolved + return params + + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + if not host_arg: + # Nothing to resolve + return params + + hosts_in = host_arg.split(",") + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) + ports_in = port_arg.split(",") if port_arg else [] + default_port = "5432" + + if len(ports_in) == 1: + # If only one port is specified, the libpq will apply it to all + # the hosts, so don't mangle it. + default_port = ports_in.pop() + + elif len(ports_in) > 1: + if len(ports_in) != len(hosts_in): + # ProgrammingError would have been more appropriate, but this is + # what the raise if the libpq fails connect in the same case. 
+ raise e.OperationalError( + f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers" + ) + ports_out = [] + + hosts_out = [] + hostaddr_out = [] + loop = asyncio.get_running_loop() + for i, host in enumerate(hosts_in): + if not host or host.startswith("/") or host[1:2] == ":": + # Local path + hosts_out.append(host) + hostaddr_out.append("") + if ports_in: + ports_out.append(ports_in[i]) + continue + + # If the host is already an ip address don't try to resolve it + if is_ip_address(host): + hosts_out.append(host) + hostaddr_out.append(host) + if ports_in: + ports_out.append(ports_in[i]) + continue + + try: + port = ports_in[i] if ports_in else default_port + ans = await loop.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + except OSError as ex: + last_exc = ex + else: + for item in ans: + hosts_out.append(host) + hostaddr_out.append(item[4][0]) + if ports_in: + ports_out.append(ports_in[i]) + + # Throw an exception if no host could be resolved + if not hosts_out: + raise e.OperationalError(str(last_exc)) + + out = params.copy() + out["host"] = ",".join(hosts_out) + out["hostaddr"] = ",".join(hostaddr_out) + if ports_in: + out["port"] = ",".join(ports_out) + + return out + + +@lru_cache() +def is_ip_address(s: str) -> bool: + """Return True if the string represent a valid ip address.""" + try: + ip_address(s) + except ValueError: + return False + return True diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py new file mode 100644 index 0000000..7514306 --- /dev/null +++ b/psycopg/psycopg/copy.py @@ -0,0 +1,904 @@ +""" +psycopg copy support +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import queue +import struct +import asyncio +import threading +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO +from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING + +from . import pq +from . import adapt +from . import errors as e +from .abc import Buffer, ConnectionType, PQGen, Transformer +from ._compat import create_task +from ._cmodule import _psycopg +from ._encodings import pgconn_encoding +from .generators import copy_from, copy_to, copy_end + +if TYPE_CHECKING: + from .cursor import BaseCursor, Cursor + from .cursor_async import AsyncCursor + from .connection import Connection # noqa: F401 + from .connection_async import AsyncConnection # noqa: F401 + +PY_TEXT = adapt.PyFormat.TEXT +PY_BINARY = adapt.PyFormat.BINARY + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +COPY_IN = pq.ExecStatus.COPY_IN +COPY_OUT = pq.ExecStatus.COPY_OUT + +ACTIVE = pq.TransactionStatus.ACTIVE + +# Size of data to accumulate before sending it down the network. We fill a +# buffer this size field by field, and when it passes the threshold size +# we ship it, so it may end up being bigger than this. +BUFFER_SIZE = 32 * 1024 + +# Maximum data size we want to queue to send to the libpq copy. Sending a +# buffer too big to be handled can cause an infinite loop in the libpq +# (#255) so we want to split it in more digestable chunks. +MAX_BUFFER_SIZE = 4 * BUFFER_SIZE +# Note: making this buffer too large, e.g. +# MAX_BUFFER_SIZE = 1024 * 1024 +# makes operations *way* slower! Probably triggering some quadraticity +# in the libpq memory management and data sending. + +# Max size of the write queue of buffers. More than that copy will block +# Each buffer should be around BUFFER_SIZE size. 
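+# With the default values this bounds the data waiting in the queue to
+# roughly QUEUE_SIZE * BUFFER_SIZE (about 32 MiB), or more if single
+# buffers are larger than BUFFER_SIZE.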
+QUEUE_SIZE = 1024 + + +class BaseCopy(Generic[ConnectionType]): + """ + Base implementation for the copy user interface. + + Two subclasses expose real methods with the sync/async differences. + + The difference between the text and binary format is managed by two + different `Formatter` subclasses. + + Writing (the I/O part) is implemented in the subclasses by a `Writer` or + `AsyncWriter` instance. Normally writing implies sending copy data to a + database, but a different writer might be chosen, e.g. to stream data into + a file for later use. + """ + + _Self = TypeVar("_Self", bound="BaseCopy[Any]") + + formatter: "Formatter" + + def __init__( + self, + cursor: "BaseCursor[ConnectionType, Any]", + *, + binary: Optional[bool] = None, + ): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + result = cursor.pgresult + if result: + self._direction = result.status + if self._direction != COPY_IN and self._direction != COPY_OUT: + raise e.ProgrammingError( + "the cursor should have performed a COPY operation;" + f" its status is {pq.ExecStatus(self._direction).name} instead" + ) + else: + self._direction = COPY_IN + + if binary is None: + binary = bool(result and result.binary_tuples) + + tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) + if binary: + self.formatter = BinaryFormatter(tx) + else: + self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) + + self._finished = False + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + def _enter(self) -> None: + if self._finished: + raise TypeError("copy blocks can be used only once") + + def set_types(self, types: Sequence[Union[int, str]]) -> None: + """ + Set the types expected in a COPY operation. + + The types must be specified as a sequence of oid or PostgreSQL type + names (e.g. ``int4``, ``timestamptz[]``). + + This operation overcomes the lack of metadata returned by PostgreSQL + when a COPY operation begins: + + - On :sql:`COPY TO`, `!set_types()` allows to specify what types the + operation returns. If `!set_types()` is not used, the data will be + returned as unparsed strings or bytes instead of Python objects. + + - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the + database expects. This is especially useful in binary copy, because + PostgreSQL will apply no cast rule. + + """ + registry = self.cursor.adapters.types + oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types] + + if self._direction == COPY_IN: + self.formatter.transformer.set_dumper_types(oids, self.formatter.format) + else: + self.formatter.transformer.set_loader_types(oids, self.formatter.format) + + # High level copy protocol generators (state change of the Copy object) + + def _read_gen(self) -> PQGen[Buffer]: + if self._finished: + return memoryview(b"") + + res = yield from copy_from(self._pgconn) + if isinstance(res, memoryview): + return res + + # res is the final PGresult + self._finished = True + + # This result is a COMMAND_OK which has info about the number of rows + # returned, but not about the columns, which is instead an information + # that was received on the COPY_OUT result at the beginning of COPY. + # So, don't replace the results in the cursor, just update the rowcount. 
+ nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + return memoryview(b"") + + def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]: + data = yield from self._read_gen() + if not data: + return None + + row = self.formatter.parse_row(data) + if row is None: + # Get the final result to finish the copy operation + yield from self._read_gen() + self._finished = True + return None + + return row + + def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]: + if not exc: + return + + if self._pgconn.transaction_status != ACTIVE: + # The server has already finished to send copy data. The connection + # is already in a good state. + return + + # Throw a cancel to the server, then consume the rest of the copy data + # (which might or might not have been already transferred entirely to + # the client, so we won't necessary see the exception associated with + # canceling). + self.connection.cancel() + try: + while (yield from self._read_gen()): + pass + except e.QueryCanceled: + pass + + +class Copy(BaseCopy["Connection[Any]"]): + """Manage a :sql:`COPY` operation. + + :param cursor: the cursor where the operation is performed. + :param binary: if `!True`, write binary format. + :param writer: the object to write to destination. If not specified, write + to the `!cursor` connection. + + Choosing `!binary` is not necessary if the cursor has executed a + :sql:`COPY` operation, because the operation result describes the format + too. The parameter is useful when a `!Copy` object is created manually and + no operation is performed on the cursor, such as when using ``writer=``\\ + `~psycopg.copy.FileWriter`. + + """ + + __module__ = "psycopg" + + writer: "Writer" + + def __init__( + self, + cursor: "Cursor[Any]", + *, + binary: Optional[bool] = None, + writer: Optional["Writer"] = None, + ): + super().__init__(cursor, binary=binary) + if not writer: + writer = LibpqWriter(cursor) + + self.writer = writer + self._write = writer.write + + def __enter__(self: BaseCopy._Self) -> BaseCopy._Self: + self._enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.finish(exc_val) + + # End user sync interface + + def __iter__(self) -> Iterator[Buffer]: + """Implement block-by-block iteration on :sql:`COPY TO`.""" + while True: + data = self.read() + if not data: + break + yield data + + def read(self) -> Buffer: + """ + Read an unparsed row after a :sql:`COPY TO` operation. + + Return an empty string when the data is finished. + """ + return self.connection.wait(self._read_gen()) + + def rows(self) -> Iterator[Tuple[Any, ...]]: + """ + Iterate on the result of a :sql:`COPY TO` operation record by record. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + while True: + record = self.read_row() + if record is None: + break + yield record + + def read_row(self) -> Optional[Tuple[Any, ...]]: + """ + Read a parsed row of data from a table after a :sql:`COPY TO` operation. + + Return `!None` when the data is finished. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + return self.connection.wait(self._read_row_gen()) + + def write(self, buffer: Union[Buffer, str]) -> None: + """ + Write a block of data to a table after a :sql:`COPY FROM` operation. 
+ + If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In + text mode it can be either `!bytes` or `!str`. + """ + data = self.formatter.write(buffer) + if data: + self._write(data) + + def write_row(self, row: Sequence[Any]) -> None: + """Write a record to a table after a :sql:`COPY FROM` operation.""" + data = self.formatter.write_row(row) + if data: + self._write(data) + + def finish(self, exc: Optional[BaseException]) -> None: + """Terminate the copy operation and free the resources allocated. + + You shouldn't need to call this function yourself: it is usually called + by exit. It is available if, despite what is documented, you end up + using the `Copy` object outside a block. + """ + if self._direction == COPY_IN: + data = self.formatter.end() + if data: + self._write(data) + self.writer.finish(exc) + self._finished = True + else: + self.connection.wait(self._end_copy_out_gen(exc)) + + +class Writer(ABC): + """ + A class to write copy data somewhere. + """ + + @abstractmethod + def write(self, data: Buffer) -> None: + """ + Write some data to destination. + """ + ... + + def finish(self, exc: Optional[BaseException] = None) -> None: + """ + Called when write operations are finished. + + If operations finished with an error, it will be passed to ``exc``. + """ + pass + + +class LibpqWriter(Writer): + """ + A `Writer` to write copy data to a Postgres database. + """ + + def __init__(self, cursor: "Cursor[Any]"): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + + def finish(self, exc: Optional[BaseException] = None) -> None: + bmsg: Optional[bytes] + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + res = self.connection.wait(copy_end(self._pgconn, bmsg)) + self.cursor._results = [res] + + +class QueuedLibpqDriver(LibpqWriter): + """ + A writer using a buffer to queue data to write to a Postgres database. + + `write()` returns immediately, so that the main thread can be CPU-bound + formatting messages, while a worker thread can be IO-bound waiting to write + on the connection. + """ + + def __init__(self, cursor: "Cursor[Any]"): + super().__init__(cursor) + + self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE) + self._worker: Optional[threading.Thread] = None + self._worker_error: Optional[BaseException] = None + + def worker(self) -> None: + """Push data to the server when available from the copy queue. + + Terminate reading when the queue receives a false-y value, or in case + of error. + + The function is designed to be run in a separate thread. + """ + try: + while True: + data = self._queue.get(block=True, timeout=24 * 60 * 60) + if not data: + break + self.connection.wait(copy_to(self._pgconn, data)) + except BaseException as ex: + # Propagate the error to the main thread. 
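+            # write() and finish() check this attribute and re-raise the
+            # exception in the caller's thread.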
+ self._worker_error = ex + + def write(self, data: Buffer) -> None: + if not self._worker: + # warning: reference loop, broken by _write_end + self._worker = threading.Thread(target=self.worker) + self._worker.daemon = True + self._worker.start() + + # If the worker thread raies an exception, re-raise it to the caller. + if self._worker_error: + raise self._worker_error + + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + self._queue.put(data) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + + def finish(self, exc: Optional[BaseException] = None) -> None: + self._queue.put(b"") + + if self._worker: + self._worker.join() + self._worker = None # break the loop + + # Check if the worker thread raised any exception before terminating. + if self._worker_error: + raise self._worker_error + + super().finish(exc) + + +class FileWriter(Writer): + """ + A `Writer` to write copy data to a file-like object. + + :param file: the file where to write copy data. It must be open for writing + in binary mode. + """ + + def __init__(self, file: IO[bytes]): + self.file = file + + def write(self, data: Buffer) -> None: + self.file.write(data) # type: ignore[arg-type] + + +class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): + """Manage an asynchronous :sql:`COPY` operation.""" + + __module__ = "psycopg" + + writer: "AsyncWriter" + + def __init__( + self, + cursor: "AsyncCursor[Any]", + *, + binary: Optional[bool] = None, + writer: Optional["AsyncWriter"] = None, + ): + super().__init__(cursor, binary=binary) + + if not writer: + writer = AsyncLibpqWriter(cursor) + + self.writer = writer + self._write = writer.write + + async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self: + self._enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.finish(exc_val) + + async def __aiter__(self) -> AsyncIterator[Buffer]: + while True: + data = await self.read() + if not data: + break + yield data + + async def read(self) -> Buffer: + return await self.connection.wait(self._read_gen()) + + async def rows(self) -> AsyncIterator[Tuple[Any, ...]]: + while True: + record = await self.read_row() + if record is None: + break + yield record + + async def read_row(self) -> Optional[Tuple[Any, ...]]: + return await self.connection.wait(self._read_row_gen()) + + async def write(self, buffer: Union[Buffer, str]) -> None: + data = self.formatter.write(buffer) + if data: + await self._write(data) + + async def write_row(self, row: Sequence[Any]) -> None: + data = self.formatter.write_row(row) + if data: + await self._write(data) + + async def finish(self, exc: Optional[BaseException]) -> None: + if self._direction == COPY_IN: + data = self.formatter.end() + if data: + await self._write(data) + await self.writer.finish(exc) + self._finished = True + else: + await self.connection.wait(self._end_copy_out_gen(exc)) + + +class AsyncWriter(ABC): + """ + A class to write copy data somewhere (for async connections). + """ + + @abstractmethod + async def write(self, data: Buffer) -> None: + ... 
+ + async def finish(self, exc: Optional[BaseException] = None) -> None: + pass + + +class AsyncLibpqWriter(AsyncWriter): + """ + An `AsyncWriter` to write copy data to a Postgres database. + """ + + def __init__(self, cursor: "AsyncCursor[Any]"): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + async def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + await self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + await self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + + async def finish(self, exc: Optional[BaseException] = None) -> None: + bmsg: Optional[bytes] + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + res = await self.connection.wait(copy_end(self._pgconn, bmsg)) + self.cursor._results = [res] + + +class AsyncQueuedLibpqWriter(AsyncLibpqWriter): + """ + An `AsyncWriter` using a buffer to queue data to write. + + `write()` returns immediately, so that the main thread can be CPU-bound + formatting messages, while a worker thread can be IO-bound waiting to write + on the connection. + """ + + def __init__(self, cursor: "AsyncCursor[Any]"): + super().__init__(cursor) + + self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE) + self._worker: Optional[asyncio.Future[None]] = None + + async def worker(self) -> None: + """Push data to the server when available from the copy queue. + + Terminate reading when the queue receives a false-y value. + + The function is designed to be run in a separate task. + """ + while True: + data = await self._queue.get() + if not data: + break + await self.connection.wait(copy_to(self._pgconn, data)) + + async def write(self, data: Buffer) -> None: + if not self._worker: + self._worker = create_task(self.worker()) + + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + await self._queue.put(data) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + + async def finish(self, exc: Optional[BaseException] = None) -> None: + await self._queue.put(b"") + + if self._worker: + await asyncio.gather(self._worker) + self._worker = None # break reference loops if any + + await super().finish(exc) + + +class Formatter(ABC): + """ + A class which understand a copy format (text, binary). + """ + + format: pq.Format + + def __init__(self, transformer: Transformer): + self.transformer = transformer + self._write_buffer = bytearray() + self._row_mode = False # true if the user is using write_row() + + @abstractmethod + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + ... + + @abstractmethod + def write(self, buffer: Union[Buffer, str]) -> Buffer: + ... + + @abstractmethod + def write_row(self, row: Sequence[Any]) -> Buffer: + ... + + @abstractmethod + def end(self) -> Buffer: + ... 
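The formatters above are normally driven through a cursor's `copy()` method rather than used directly. A minimal sketch of row-by-row COPY in both directions, assuming a hypothetical DSN and table and the `cursor.copy()` entry point defined in `cursor.py`:

    import psycopg

    records = [(1, "foo"), (2, "bar")]

    with psycopg.connect("dbname=test") as conn, conn.cursor() as cur:
        # COPY FROM: each row is serialized by the Formatter and pushed to the
        # server by a LibpqWriter.
        with cur.copy("COPY mytable (id, name) FROM STDIN") as copy:
            for record in records:
                copy.write_row(record)

        # COPY TO: set_types() makes rows() return parsed Python values
        # instead of unparsed strings/bytes.
        with cur.copy("COPY mytable (id, name) TO STDOUT") as copy:
            copy.set_types(["int4", "text"])
            for row in copy.rows():
                print(row)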
+ + +class TextFormatter(Formatter): + + format = TEXT + + def __init__(self, transformer: Transformer, encoding: str = "utf-8"): + super().__init__(transformer) + self._encoding = encoding + + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + if data: + return parse_row_text(data, self.transformer) + else: + return None + + def write(self, buffer: Union[Buffer, str]) -> Buffer: + data = self._ensure_bytes(buffer) + self._signature_sent = True + return data + + def write_row(self, row: Sequence[Any]) -> Buffer: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + format_row_text(row, self.transformer, self._write_buffer) + if len(self._write_buffer) > BUFFER_SIZE: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + else: + return b"" + + def end(self) -> Buffer: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + + def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: + if isinstance(data, str): + return data.encode(self._encoding) + else: + # Assume, for simplicity, that the user is not passing stupid + # things to the write function. If that's the case, things + # will fail downstream. + return data + + +class BinaryFormatter(Formatter): + + format = BINARY + + def __init__(self, transformer: Transformer): + super().__init__(transformer) + self._signature_sent = False + + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + if not self._signature_sent: + if data[: len(_binary_signature)] != _binary_signature: + raise e.DataError( + "binary copy doesn't start with the expected signature" + ) + self._signature_sent = True + data = data[len(_binary_signature) :] + + elif data == _binary_trailer: + return None + + return parse_row_binary(data, self.transformer) + + def write(self, buffer: Union[Buffer, str]) -> Buffer: + data = self._ensure_bytes(buffer) + self._signature_sent = True + return data + + def write_row(self, row: Sequence[Any]) -> Buffer: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + if not self._signature_sent: + self._write_buffer += _binary_signature + self._signature_sent = True + + format_row_binary(row, self.transformer, self._write_buffer) + if len(self._write_buffer) > BUFFER_SIZE: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + else: + return b"" + + def end(self) -> Buffer: + # If we have sent no data we need to send the signature + # and the trailer + if not self._signature_sent: + self._write_buffer += _binary_signature + self._write_buffer += _binary_trailer + + elif self._row_mode: + # if we have sent data already, we have sent the signature + # too (either with the first row, or we assume that in + # block mode the signature is included). + # Write the trailer only if we are sending rows (with the + # assumption that who is copying binary data is sending the + # whole format). + self._write_buffer += _binary_trailer + + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + + def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: + if isinstance(data, str): + raise TypeError("cannot copy str data in binary mode: use bytes instead") + else: + # Assume, for simplicity, that the user is not passing stupid + # things to the write function. If that's the case, things + # will fail downstream. 
+ return data + + +def _format_row_text( + row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None +) -> bytearray: + """Convert a row of objects to the data to send for copy.""" + if out is None: + out = bytearray() + + if not row: + out += b"\n" + return out + + for item in row: + if item is not None: + dumper = tx.get_dumper(item, PY_TEXT) + b = dumper.dump(item) + out += _dump_re.sub(_dump_sub, b) + else: + out += rb"\N" + out += b"\t" + + out[-1:] = b"\n" + return out + + +def _format_row_binary( + row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None +) -> bytearray: + """Convert a row of objects to the data to send for binary copy.""" + if out is None: + out = bytearray() + + out += _pack_int2(len(row)) + adapted = tx.dump_sequence(row, [PY_BINARY] * len(row)) + for b in adapted: + if b is not None: + out += _pack_int4(len(b)) + out += b + else: + out += _binary_null + + return out + + +def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]: + if not isinstance(data, bytes): + data = bytes(data) + fields = data.split(b"\t") + fields[-1] = fields[-1][:-1] # drop \n + row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields] + return tx.load_sequence(row) + + +def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]: + row: List[Optional[Buffer]] = [] + nfields = _unpack_int2(data, 0)[0] + pos = 2 + for i in range(nfields): + length = _unpack_int4(data, pos)[0] + pos += 4 + if length >= 0: + row.append(data[pos : pos + length]) + pos += length + else: + row.append(None) + + return tx.load_sequence(row) + + +_pack_int2 = struct.Struct("!h").pack +_pack_int4 = struct.Struct("!i").pack +_unpack_int2 = struct.Struct("!h").unpack_from +_unpack_int4 = struct.Struct("!i").unpack_from + +_binary_signature = ( + b"PGCOPY\n\xff\r\n\0" # Signature + b"\x00\x00\x00\x00" # flags + b"\x00\x00\x00\x00" # extra length +) +_binary_trailer = b"\xff\xff" +_binary_null = b"\xff\xff\xff\xff" + +_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]") +_dump_repl = { + b"\b": b"\\b", + b"\t": b"\\t", + b"\n": b"\\n", + b"\v": b"\\v", + b"\f": b"\\f", + b"\r": b"\\r", + b"\\": b"\\\\", +} + + +def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes: + return __map[m.group(0)] + + +_load_re = re.compile(b"\\\\[btnvfr\\\\]") +_load_repl = {v: k for k, v in _dump_repl.items()} + + +def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes: + return __map[m.group(0)] + + +# Override functions with fast versions if available +if _psycopg: + format_row_text = _psycopg.format_row_text + format_row_binary = _psycopg.format_row_binary + parse_row_text = _psycopg.parse_row_text + parse_row_binary = _psycopg.parse_row_binary + +else: + format_row_text = _format_row_text + format_row_binary = _format_row_binary + parse_row_text = _parse_row_text + parse_row_binary = _parse_row_binary diff --git a/psycopg/psycopg/crdb/__init__.py b/psycopg/psycopg/crdb/__init__.py new file mode 100644 index 0000000..323903a --- /dev/null +++ b/psycopg/psycopg/crdb/__init__.py @@ -0,0 +1,19 @@ +""" +CockroachDB support package. +""" + +# Copyright (C) 2022 The Psycopg Team + +from . 
import _types +from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo + +adapters = _types.adapters # exposed by the package +connect = CrdbConnection.connect + +_types.register_crdb_adapters(adapters) + +__all__ = [ + "AsyncCrdbConnection", + "CrdbConnection", + "CrdbConnectionInfo", +] diff --git a/psycopg/psycopg/crdb/_types.py b/psycopg/psycopg/crdb/_types.py new file mode 100644 index 0000000..5311e05 --- /dev/null +++ b/psycopg/psycopg/crdb/_types.py @@ -0,0 +1,163 @@ +""" +Types configuration specific for CockroachDB. +""" + +# Copyright (C) 2022 The Psycopg Team + +from enum import Enum +from .._typeinfo import TypeInfo, TypesRegistry + +from ..abc import AdaptContext, NoneType +from ..postgres import TEXT_OID +from .._adapters_map import AdaptersMap +from ..types.enum import EnumDumper, EnumBinaryDumper +from ..types.none import NoneDumper + +types = TypesRegistry() + +# Global adapter maps with PostgreSQL types configuration +adapters = AdaptersMap(types=types) + + +class CrdbEnumDumper(EnumDumper): + oid = TEXT_OID + + +class CrdbEnumBinaryDumper(EnumBinaryDumper): + oid = TEXT_OID + + +class CrdbNoneDumper(NoneDumper): + oid = TEXT_OID + + +def register_postgres_adapters(context: AdaptContext) -> None: + # Same adapters used by PostgreSQL, or a good starting point for customization + + from ..types import array, bool, composite, datetime + from ..types import numeric, string, uuid + + array.register_default_adapters(context) + bool.register_default_adapters(context) + composite.register_default_adapters(context) + datetime.register_default_adapters(context) + numeric.register_default_adapters(context) + string.register_default_adapters(context) + uuid.register_default_adapters(context) + + +def register_crdb_adapters(context: AdaptContext) -> None: + from .. import dbapi20 + from ..types import array + + register_postgres_adapters(context) + + # String must come after enum to map text oid -> string dumper + register_crdb_enum_adapters(context) + register_crdb_string_adapters(context) + register_crdb_json_adapters(context) + register_crdb_net_adapters(context) + register_crdb_none_adapters(context) + + dbapi20.register_dbapi20_adapters(adapters) + + array.register_all_arrays(adapters) + + +def register_crdb_string_adapters(context: AdaptContext) -> None: + from ..types import string + + # Dump strings with text oid instead of unknown. + # Unlike PostgreSQL, CRDB seems able to cast text to most types. 
+ context.adapters.register_dumper(str, string.StrDumper) + context.adapters.register_dumper(str, string.StrBinaryDumper) + + +def register_crdb_enum_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper) + context.adapters.register_dumper(Enum, CrdbEnumDumper) + + +def register_crdb_json_adapters(context: AdaptContext) -> None: + from ..types import json + + adapters = context.adapters + + # CRDB doesn't have json/jsonb: both names map to the jsonb oid + adapters.register_dumper(json.Json, json.JsonbBinaryDumper) + adapters.register_dumper(json.Json, json.JsonbDumper) + + adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper) + adapters.register_dumper(json.Jsonb, json.JsonbDumper) + + adapters.register_loader("json", json.JsonLoader) + adapters.register_loader("jsonb", json.JsonbLoader) + adapters.register_loader("json", json.JsonBinaryLoader) + adapters.register_loader("jsonb", json.JsonbBinaryLoader) + + +def register_crdb_net_adapters(context: AdaptContext) -> None: + from ..types import net + + adapters = context.adapters + + adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper) + adapters.register_dumper(None, net.InetBinaryDumper) + adapters.register_loader("inet", net.InetLoader) + adapters.register_loader("inet", net.InetBinaryLoader) + + +def register_crdb_none_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(NoneType, CrdbNoneDumper) + + +for t in [ + TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb. 
+ TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8 + TypeInfo('"char"', 18, 1002), # special case, not generated + # autogenerated: start + # Generated from CockroachDB 22.1.0 + TypeInfo("bit", 1560, 1561), + TypeInfo("bool", 16, 1000, regtype="boolean"), + TypeInfo("bpchar", 1042, 1014, regtype="character"), + TypeInfo("bytea", 17, 1001), + TypeInfo("date", 1082, 1182), + TypeInfo("float4", 700, 1021, regtype="real"), + TypeInfo("float8", 701, 1022, regtype="double precision"), + TypeInfo("inet", 869, 1041), + TypeInfo("int2", 21, 1005, regtype="smallint"), + TypeInfo("int2vector", 22, 1006), + TypeInfo("int4", 23, 1007), + TypeInfo("int8", 20, 1016, regtype="bigint"), + TypeInfo("interval", 1186, 1187), + TypeInfo("jsonb", 3802, 3807), + TypeInfo("name", 19, 1003), + TypeInfo("numeric", 1700, 1231), + TypeInfo("oid", 26, 1028), + TypeInfo("oidvector", 30, 1013), + TypeInfo("record", 2249, 2287), + TypeInfo("regclass", 2205, 2210), + TypeInfo("regnamespace", 4089, 4090), + TypeInfo("regproc", 24, 1008), + TypeInfo("regprocedure", 2202, 2207), + TypeInfo("regrole", 4096, 4097), + TypeInfo("regtype", 2206, 2211), + TypeInfo("text", 25, 1009), + TypeInfo("time", 1083, 1183, regtype="time without time zone"), + TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"), + TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"), + TypeInfo("timetz", 1266, 1270, regtype="time with time zone"), + TypeInfo("unknown", 705, 0), + TypeInfo("uuid", 2950, 2951), + TypeInfo("varbit", 1562, 1563, regtype="bit varying"), + TypeInfo("varchar", 1043, 1015, regtype="character varying"), + # autogenerated: end +]: + types.add(t) diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py new file mode 100644 index 0000000..6e79ed1 --- /dev/null +++ b/psycopg/psycopg/crdb/connection.py @@ -0,0 +1,186 @@ +""" +CockroachDB-specific connections. +""" + +# Copyright (C) 2022 The Psycopg Team + +import re +from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING + +from .. import errors as e +from ..abc import AdaptContext +from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow +from ..conninfo import ConnectionInfo +from ..connection import Connection +from .._adapters_map import AdaptersMap +from ..connection_async import AsyncConnection +from ._types import adapters + +if TYPE_CHECKING: + from ..pq.abc import PGconn + from ..cursor import Cursor + from ..cursor_async import AsyncCursor + + +class _CrdbConnectionMixin: + + _adapters: Optional[AdaptersMap] + pgconn: "PGconn" + + @classmethod + def is_crdb( + cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"] + ) -> bool: + """ + Return `!True` if the server connected to `!conn` is CockroachDB. + """ + if isinstance(conn, (Connection, AsyncConnection)): + conn = conn.pgconn + + return bool(conn.parameter_status(b"crdb_version")) + + @property + def adapters(self) -> AdaptersMap: + if not self._adapters: + # By default, use CockroachDB adapters map + self._adapters = AdaptersMap(adapters) + + return self._adapters + + @property + def info(self) -> "CrdbConnectionInfo": + return CrdbConnectionInfo(self.pgconn) + + def _check_tpc(self) -> None: + if self.is_crdb(self.pgconn): + raise e.NotSupportedError("CockroachDB doesn't support prepared statements") + + +class CrdbConnection(_CrdbConnectionMixin, Connection[Row]): + """ + Wrapper for a connection to a CockroachDB database. 
+ """ + + __module__ = "psycopg.crdb" + + # TODO: this method shouldn't require re-definition if the base class + # implements a generic self. + # https://github.com/psycopg/psycopg/issues/308 + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[Row], + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[Cursor[Row]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "CrdbConnection[Row]": + ... + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[Cursor[Any]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "CrdbConnection[TupleRow]": + ... + + @classmethod + def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]": + """ + Connect to a database server and return a new `CrdbConnection` instance. + """ + return super().connect(conninfo, **kwargs) # type: ignore[return-value] + + +class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]): + """ + Wrapper for an async connection to a CockroachDB database. + """ + + __module__ = "psycopg.crdb" + + # TODO: this method shouldn't require re-definition if the base class + # implements a generic self. + # https://github.com/psycopg/psycopg/issues/308 + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: AsyncRowFactory[Row], + cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncCrdbConnection[Row]": + ... + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncCrdbConnection[TupleRow]": + ... + + @classmethod + async def connect( + cls, conninfo: str = "", **kwargs: Any + ) -> "AsyncCrdbConnection[Any]": + return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return] + + +class CrdbConnectionInfo(ConnectionInfo): + """ + `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database. + """ + + __module__ = "psycopg.crdb" + + @property + def vendor(self) -> str: + return "CockroachDB" + + @property + def server_version(self) -> int: + """ + Return the CockroachDB server version connected. + + Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210). 
+ """ + sver = self.parameter_status("crdb_version") + if not sver: + raise e.InternalError("'crdb_version' parameter status not set") + + ver = self.parse_crdb_version(sver) + if ver is None: + raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}") + + return ver + + @classmethod + def parse_crdb_version(self, sver: str) -> Optional[int]: + m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver) + if not m: + return None + + return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3)) diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py new file mode 100644 index 0000000..42c3804 --- /dev/null +++ b/psycopg/psycopg/cursor.py @@ -0,0 +1,921 @@ +""" +psycopg cursor objects +""" + +# Copyright (C) 2020 The Psycopg Team + +from functools import partial +from types import TracebackType +from typing import Any, Generic, Iterable, Iterator, List +from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar +from typing import overload, TYPE_CHECKING +from contextlib import contextmanager + +from . import pq +from . import adapt +from . import errors as e +from .abc import ConnectionType, Query, Params, PQGen +from .copy import Copy, Writer as CopyWriter +from .rows import Row, RowMaker, RowFactory +from ._column import Column +from ._queries import PostgresQuery, PostgresClientQuery +from ._pipeline import Pipeline +from ._encodings import pgconn_encoding +from ._preparing import Prepare +from .generators import execute, fetch, send + +if TYPE_CHECKING: + from .abc import Transformer + from .pq.abc import PGconn, PGresult + from .connection import Connection + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +COPY_OUT = pq.ExecStatus.COPY_OUT +COPY_IN = pq.ExecStatus.COPY_IN +COPY_BOTH = pq.ExecStatus.COPY_BOTH +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR +SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE +PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED + +ACTIVE = pq.TransactionStatus.ACTIVE + + +class BaseCursor(Generic[ConnectionType, Row]): + __slots__ = """ + _conn format _adapters arraysize _closed _results pgresult _pos + _iresult _rowcount _query _tx _last_query _row_factory _make_row + _pgconn _execmany_returning + __weakref__ + """.split() + + ExecStatus = pq.ExecStatus + + _tx: "Transformer" + _make_row: RowMaker[Row] + _pgconn: "PGconn" + + def __init__(self, connection: ConnectionType): + self._conn = connection + self.format = TEXT + self._pgconn = connection.pgconn + self._adapters = adapt.AdaptersMap(connection.adapters) + self.arraysize = 1 + self._closed = False + self._last_query: Optional[Query] = None + self._reset() + + def _reset(self, reset_query: bool = True) -> None: + self._results: List["PGresult"] = [] + self.pgresult: Optional["PGresult"] = None + self._pos = 0 + self._iresult = 0 + self._rowcount = -1 + self._query: Optional[PostgresQuery] + # None if executemany() not executing, True/False according to returning state + self._execmany_returning: Optional[bool] = None + if reset_query: + self._query = None + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._pgconn) + if self._closed: + status = "closed" + elif self.pgresult: + status = pq.ExecStatus(self.pgresult.status).name + else: + status = "no result" + return f"<{cls} [{status}] {info} at 0x{id(self):x}>" + + @property + def connection(self) -> ConnectionType: + 
"""The connection this cursor is using.""" + return self._conn + + @property + def adapters(self) -> adapt.AdaptersMap: + return self._adapters + + @property + def closed(self) -> bool: + """`True` if the cursor is closed.""" + return self._closed + + @property + def description(self) -> Optional[List[Column]]: + """ + A list of `Column` objects describing the current resultset. + + `!None` if the current resultset didn't return tuples. + """ + res = self.pgresult + + # We return columns if we have nfields, but also if we don't but + # the query said we got tuples (mostly to handle the super useful + # query "SELECT ;" + if res and ( + res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE + ): + return [Column(self, i) for i in range(res.nfields)] + else: + return None + + @property + def rowcount(self) -> int: + """Number of records affected by the precedent operation.""" + return self._rowcount + + @property + def rownumber(self) -> Optional[int]: + """Index of the next row to fetch in the current result. + + `!None` if there is no result to fetch. + """ + tuples = self.pgresult and self.pgresult.status == TUPLES_OK + return self._pos if tuples else None + + def setinputsizes(self, sizes: Sequence[Any]) -> None: + # no-op + pass + + def setoutputsize(self, size: Any, column: Optional[int] = None) -> None: + # no-op + pass + + def nextset(self) -> Optional[bool]: + """ + Move to the result set of the next query executed through `executemany()` + or to the next result set if `execute()` returned more than one. + + Return `!True` if a new result is available, which will be the one + methods `!fetch*()` will operate on. + """ + if self._iresult < len(self._results) - 1: + self._select_current_result(self._iresult + 1) + return True + else: + return None + + @property + def statusmessage(self) -> Optional[str]: + """ + The command status tag from the last SQL command executed. + + `!None` if the cursor doesn't have a result available. + """ + msg = self.pgresult.command_status if self.pgresult else None + return msg.decode() if msg else None + + def _make_row_maker(self) -> RowMaker[Row]: + raise NotImplementedError + + # + # Generators for the high level operations on the cursor + # + # Like for sync/async connections, these are implemented as generators + # so that different concurrency strategies (threads,asyncio) can use their + # own way of waiting (or better, `connection.wait()`). + # + + def _execute_gen( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator implementing `Cursor.execute()`.""" + yield from self._start_query(query) + pgq = self._convert_query(query, params) + results = yield from self._maybe_prepare_gen( + pgq, prepare=prepare, binary=binary + ) + if self._conn._pipeline: + yield from self._conn._pipeline._communicate_gen() + else: + assert results is not None + self._check_results(results) + self._results = results + self._select_current_result(0) + + self._last_query = query + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _executemany_gen_pipeline( + self, query: Query, params_seq: Iterable[Params], returning: bool + ) -> PQGen[None]: + """ + Generator implementing `Cursor.executemany()` with pipelines available. 
+ """ + pipeline = self._conn._pipeline + assert pipeline + + yield from self._start_query(query) + self._rowcount = 0 + + assert self._execmany_returning is None + self._execmany_returning = returning + + first = True + for params in params_seq: + if first: + pgq = self._convert_query(query, params) + self._query = pgq + first = False + else: + pgq.dump(params) + + yield from self._maybe_prepare_gen(pgq, prepare=True) + yield from pipeline._communicate_gen() + + self._last_query = query + + if returning: + yield from pipeline._fetch_gen(flush=True) + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _executemany_gen_no_pipeline( + self, query: Query, params_seq: Iterable[Params], returning: bool + ) -> PQGen[None]: + """ + Generator implementing `Cursor.executemany()` with pipelines not available. + """ + yield from self._start_query(query) + first = True + nrows = 0 + for params in params_seq: + if first: + pgq = self._convert_query(query, params) + self._query = pgq + first = False + else: + pgq.dump(params) + + results = yield from self._maybe_prepare_gen(pgq, prepare=True) + assert results is not None + self._check_results(results) + if returning: + self._results.extend(results) + + for res in results: + nrows += res.command_tuples or 0 + + if self._results: + self._select_current_result(0) + + # Override rowcount for the first result. Calls to nextset() will change + # it to the value of that result only, but we hope nobody will notice. + # You haven't read this comment. + self._rowcount = nrows + self._last_query = query + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _maybe_prepare_gen( + self, + pgq: PostgresQuery, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> PQGen[Optional[List["PGresult"]]]: + # Check if the query is prepared or needs preparing + prep, name = self._get_prepared(pgq, prepare) + if prep is Prepare.NO: + # The query must be executed without preparing + self._execute_send(pgq, binary=binary) + else: + # If the query is not already prepared, prepare it. + if prep is Prepare.SHOULD: + self._send_prepare(name, pgq) + if not self._conn._pipeline: + (result,) = yield from execute(self._pgconn) + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + # Then execute it. + self._send_query_prepared(name, pgq, binary=binary) + + # Update the prepare state of the query. + # If an operation requires to flush our prepared statements cache, + # it will be added to the maintenance commands to execute later. 
+ key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name) + + if self._conn._pipeline: + queued = None + if key is not None: + queued = (key, prep, name) + self._conn._pipeline.result_queue.append((self, queued)) + return None + + # run the query + results = yield from execute(self._pgconn) + + if key is not None: + self._conn._prepared.validate(key, prep, name, results) + + return results + + def _get_prepared( + self, pgq: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + return self._conn._prepared.get(pgq, prepare) + + def _stream_send_gen( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator to send the query for `Cursor.stream()`.""" + yield from self._start_query(query) + pgq = self._convert_query(query, params) + self._execute_send(pgq, binary=binary, force_extended=True) + self._pgconn.set_single_row_mode() + self._last_query = query + yield from send(self._pgconn) + + def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]: + res = yield from fetch(self._pgconn) + if res is None: + return None + + status = res.status + if status == SINGLE_TUPLE: + self.pgresult = res + self._tx.set_pgresult(res, set_loaders=first) + if first: + self._make_row = self._make_row_maker() + return res + + elif status == TUPLES_OK or status == COMMAND_OK: + # End of single row results + while res: + res = yield from fetch(self._pgconn) + if status != TUPLES_OK: + raise e.ProgrammingError( + "the operation in stream() didn't produce a result" + ) + return None + + else: + # Errors, unexpected values + return self._raise_for_result(res) + + def _start_query(self, query: Optional[Query] = None) -> PQGen[None]: + """Generator to start the processing of a query. + + It is implemented as generator because it may send additional queries, + such as `begin`. + """ + if self.closed: + raise e.InterfaceError("the cursor is closed") + + self._reset() + if not self._last_query or (self._last_query is not query): + self._last_query = None + self._tx = adapt.Transformer(self) + yield from self._conn._start_query() + + def _start_copy_gen( + self, statement: Query, params: Optional[Params] = None + ) -> PQGen[None]: + """Generator implementing sending a command for `Cursor.copy().""" + + # The connection gets in an unrecoverable state if we attempt COPY in + # pipeline mode. Forbid it explicitly. + if self._conn._pipeline: + raise e.NotSupportedError("COPY cannot be used in pipeline mode") + + yield from self._start_query() + + # Merge the params client-side + if params: + pgq = PostgresClientQuery(self._tx) + pgq.convert(statement, params) + statement = pgq.query + + query = self._convert_query(statement) + + self._execute_send(query, binary=False) + results = yield from execute(self._pgconn) + if len(results) != 1: + raise e.ProgrammingError("COPY cannot be mixed with other operations") + + self._check_copy_result(results[0]) + self._results = results + self._select_current_result(0) + + def _execute_send( + self, + query: PostgresQuery, + *, + force_extended: bool = False, + binary: Optional[bool] = None, + ) -> None: + """ + Implement part of execute() before waiting common to sync and async. + + This is not a generator, but a normal non-blocking function. 
+ """ + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + self._query = query + + if self._conn._pipeline: + # In pipeline mode always use PQsendQueryParams - see #314 + # Multiple statements in the same query are not allowed anyway. + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_query_params, + query.query, + query.params, + param_formats=query.formats, + param_types=query.types, + result_format=fmt, + ) + ) + elif force_extended or query.params or fmt == BINARY: + self._pgconn.send_query_params( + query.query, + query.params, + param_formats=query.formats, + param_types=query.types, + result_format=fmt, + ) + else: + # If we can, let's use simple query protocol, + # as it can execute more than one statement in a single query. + self._pgconn.send_query(query.query) + + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = PostgresQuery(self._tx) + pgq.convert(query, params) + return pgq + + def _check_results(self, results: List["PGresult"]) -> None: + """ + Verify that the results of a query are valid. + + Verify that the query returned at least one result and that they all + represent a valid result from the database. + """ + if not results: + raise e.InternalError("got no result from the query") + + for res in results: + status = res.status + if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY: + self._raise_for_result(res) + + def _raise_for_result(self, result: "PGresult") -> NoReturn: + """ + Raise an appropriate error message for an unexpected database result + """ + status = result.status + if status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + elif status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH: + raise e.ProgrammingError( + "COPY cannot be used with this method; use copy() instead" + ) + else: + raise e.InternalError( + "unexpected result status from query:" f" {pq.ExecStatus(status).name}" + ) + + def _select_current_result( + self, i: int, format: Optional[pq.Format] = None + ) -> None: + """ + Select one of the results in the cursor as the active one. + """ + self._iresult = i + res = self.pgresult = self._results[i] + + # Note: the only reason to override format is to correctly set + # binary loaders on server-side cursors, because send_describe_portal + # only returns a text result. + self._tx.set_pgresult(res, format=format) + + self._pos = 0 + + if res.status == TUPLES_OK: + self._rowcount = self.pgresult.ntuples + + # COPY_OUT has never info about nrows. We need such result for the + # columns in order to return a `description`, but not overwrite the + # cursor rowcount (which was set by the Copy object). + elif res.status != COPY_OUT: + nrows = self.pgresult.command_tuples + self._rowcount = nrows if nrows is not None else -1 + + self._make_row = self._make_row_maker() + + def _set_results_from_pipeline(self, results: List["PGresult"]) -> None: + self._check_results(results) + first_batch = not self._results + + if self._execmany_returning is None: + # Received from execute() + self._results.extend(results) + if first_batch: + self._select_current_result(0) + + else: + # Received from executemany() + if self._execmany_returning: + self._results.extend(results) + if first_batch: + self._select_current_result(0) + self._rowcount = 0 + + # Override rowcount for the first result. 
Calls to nextset() will + # change it to the value of that result only, but we hope nobody + # will notice. + # You haven't read this comment. + if self._rowcount < 0: + self._rowcount = 0 + for res in results: + self._rowcount += res.command_tuples or 0 + + def _send_prepare(self, name: bytes, query: PostgresQuery) -> None: + if self._conn._pipeline: + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_prepare, + name, + query.query, + param_types=query.types, + ) + ) + self._conn._pipeline.result_queue.append(None) + else: + self._pgconn.send_prepare(name, query.query, param_types=query.types) + + def _send_query_prepared( + self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None + ) -> None: + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + if self._conn._pipeline: + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_query_prepared, + name, + pgq.params, + param_formats=pgq.formats, + result_format=fmt, + ) + ) + else: + self._pgconn.send_query_prepared( + name, pgq.params, param_formats=pgq.formats, result_format=fmt + ) + + def _check_result_for_fetch(self) -> None: + if self.closed: + raise e.InterfaceError("the cursor is closed") + res = self.pgresult + if not res: + raise e.ProgrammingError("no result available") + + status = res.status + if status == TUPLES_OK: + return + elif status == FATAL_ERROR: + raise e.error_from_result(res, encoding=self._encoding) + elif status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + else: + raise e.ProgrammingError("the last operation didn't produce a result") + + def _check_copy_result(self, result: "PGresult") -> None: + """ + Check that the value returned in a copy() operation is a legit COPY. + """ + status = result.status + if status == COPY_IN or status == COPY_OUT: + return + elif status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + else: + raise e.ProgrammingError( + "copy() should be used only with COPY ... TO STDOUT or COPY ..." + f" FROM STDIN statements, got {pq.ExecStatus(status).name}" + ) + + def _scroll(self, value: int, mode: str) -> None: + self._check_result_for_fetch() + assert self.pgresult + if mode == "relative": + newpos = self._pos + value + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + if not 0 <= newpos < self.pgresult.ntuples: + raise IndexError("position out of bound") + self._pos = newpos + + def _close(self) -> None: + """Non-blocking part of closing. Common to sync/async.""" + # Don't reset the query because it may be useful to investigate after + # an error. + self._reset(reset_query=False) + self._closed = True + + @property + def _encoding(self) -> str: + return pgconn_encoding(self._pgconn) + + +class Cursor(BaseCursor["Connection[Any]", Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="Cursor[Any]") + + @overload + def __init__(self: "Cursor[Row]", connection: "Connection[Row]"): + ... + + @overload + def __init__( + self: "Cursor[Row]", + connection: "Connection[Any]", + *, + row_factory: RowFactory[Row], + ): + ... 
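A minimal usage sketch of how the row_factory overloads above play out in practice (the connection string and query are placeholders, not part of the library): a cursor created without an explicit factory inherits the connection's, while passing one overrides it for that cursor only.

import psycopg
from psycopg.rows import dict_row

with psycopg.connect("dbname=test") as conn:          # placeholder DSN
    with conn.cursor() as cur:                         # inherits the connection's row_factory
        cur.execute("SELECT 1 AS x, 'a' AS y")
        print(cur.fetchone())                          # (1, 'a') with the default tuple_row
    with conn.cursor(row_factory=dict_row) as cur:     # per-cursor override
        cur.execute("SELECT 1 AS x, 'a' AS y")
        print(cur.fetchone())                          # {'x': 1, 'y': 'a'}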
+ + def __init__( + self, + connection: "Connection[Any]", + *, + row_factory: Optional[RowFactory[Row]] = None, + ): + super().__init__(connection) + self._row_factory = row_factory or connection.row_factory + + def __enter__(self: _Self) -> _Self: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + self._close() + + @property + def row_factory(self) -> RowFactory[Row]: + """Writable attribute to control how result rows are formed.""" + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: RowFactory[Row]) -> None: + self._row_factory = row_factory + if self.pgresult: + self._make_row = row_factory(self) + + def _make_row_maker(self) -> RowMaker[Row]: + return self._row_factory(self) + + def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> _Self: + """ + Execute a query or command to the database. + """ + try: + with self._conn.lock: + self._conn.wait( + self._execute_gen(query, params, prepare=prepare, binary=binary) + ) + except e.Error as ex: + raise ex.with_traceback(None) + return self + + def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = False, + ) -> None: + """ + Execute the same command with a sequence of input data. + """ + try: + if Pipeline.is_supported(): + # If there is already a pipeline, ride it, in order to avoid + # sending unnecessary Sync. + with self._conn.lock: + p = self._conn._pipeline + if p: + self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + # Otherwise, make a new one + if not p: + with self._conn.pipeline(), self._conn.lock: + self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + else: + with self._conn.lock: + self._conn.wait( + self._executemany_gen_no_pipeline(query, params_seq, returning) + ) + except e.Error as ex: + raise ex.with_traceback(None) + + def stream( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> Iterator[Row]: + """ + Iterate row-by-row on a result from the database. + """ + if self._pgconn.pipeline_status: + raise e.ProgrammingError("stream() cannot be used in pipeline mode") + + with self._conn.lock: + + try: + self._conn.wait(self._stream_send_gen(query, params, binary=binary)) + first = True + while self._conn.wait(self._stream_fetchone_gen(first)): + # We know that, if we got a result, it has a single row. + rec: Row = self._tx.load_row(0, self._make_row) # type: ignore + yield rec + first = False + + except e.Error as ex: + raise ex.with_traceback(None) + + finally: + if self._pgconn.transaction_status == ACTIVE: + # Try to cancel the query, then consume the results + # already received. + self._conn.cancel() + try: + while self._conn.wait(self._stream_fetchone_gen(first=False)): + pass + except Exception: + pass + + # Try to get out of ACTIVE state. Just do a single attempt, which + # should work to recover from an error or query cancelled. + try: + self._conn.wait(self._stream_fetchone_gen(first=False)) + except Exception: + pass + + def fetchone(self) -> Optional[Row]: + """ + Return the next record from the current recordset. + + Return `!None` the recordset is finished. 
+ + :rtype: Optional[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + record = self._tx.load_row(self._pos, self._make_row) + if record is not None: + self._pos += 1 + return record + + def fetchmany(self, size: int = 0) -> List[Row]: + """ + Return the next `!size` records from the current recordset. + + `!size` default to `!self.arraysize` if not specified. + + :rtype: Sequence[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + + if not size: + size = self.arraysize + records = self._tx.load_rows( + self._pos, + min(self._pos + size, self.pgresult.ntuples), + self._make_row, + ) + self._pos += len(records) + return records + + def fetchall(self) -> List[Row]: + """ + Return all the remaining records from the current recordset. + + :rtype: Sequence[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) + self._pos = self.pgresult.ntuples + return records + + def __iter__(self) -> Iterator[Row]: + self._fetch_pipeline() + self._check_result_for_fetch() + + def load(pos: int) -> Optional[Row]: + return self._tx.load_row(pos, self._make_row) + + while True: + row = load(self._pos) + if row is None: + break + self._pos += 1 + yield row + + def scroll(self, value: int, mode: str = "relative") -> None: + """ + Move the cursor in the result set to a new position according to mode. + + If `!mode` is ``'relative'`` (default), `!value` is taken as offset to + the current position in the result set; if set to ``'absolute'``, + `!value` states an absolute target position. + + Raise `!IndexError` in case a scroll operation would leave the result + set. In this case the position will not change. + """ + self._fetch_pipeline() + self._scroll(value, mode) + + @contextmanager + def copy( + self, + statement: Query, + params: Optional[Params] = None, + *, + writer: Optional[CopyWriter] = None, + ) -> Iterator[Copy]: + """ + Initiate a :sql:`COPY` operation and return an object to manage it. + + :rtype: Copy + """ + try: + with self._conn.lock: + self._conn.wait(self._start_copy_gen(statement, params)) + + with Copy(self, writer=writer) as copy: + yield copy + except e.Error as ex: + raise ex.with_traceback(None) + + # If a fresher result has been set on the cursor by the Copy object, + # read its properties (especially rowcount). + self._select_current_result(0) + + def _fetch_pipeline(self) -> None: + if ( + self._execmany_returning is not False + and not self.pgresult + and self._conn._pipeline + ): + with self._conn.lock: + self._conn.wait(self._conn._pipeline._fetch_gen(flush=True)) diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py new file mode 100644 index 0000000..8971d40 --- /dev/null +++ b/psycopg/psycopg/cursor_async.py @@ -0,0 +1,250 @@ +""" +psycopg async cursor objects +""" + +# Copyright (C) 2020 The Psycopg Team + +from types import TracebackType +from typing import Any, AsyncIterator, Iterable, List +from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload +from contextlib import asynccontextmanager + +from . import pq +from . 
import errors as e +from .abc import Query, Params +from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter +from .rows import Row, RowMaker, AsyncRowFactory +from .cursor import BaseCursor +from ._pipeline import Pipeline + +if TYPE_CHECKING: + from .connection_async import AsyncConnection + +ACTIVE = pq.TransactionStatus.ACTIVE + + +class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="AsyncCursor[Any]") + + @overload + def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"): + ... + + @overload + def __init__( + self: "AsyncCursor[Row]", + connection: "AsyncConnection[Any]", + *, + row_factory: AsyncRowFactory[Row], + ): + ... + + def __init__( + self, + connection: "AsyncConnection[Any]", + *, + row_factory: Optional[AsyncRowFactory[Row]] = None, + ): + super().__init__(connection) + self._row_factory = row_factory or connection.row_factory + + async def __aenter__(self: _Self) -> _Self: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + async def close(self) -> None: + self._close() + + @property + def row_factory(self) -> AsyncRowFactory[Row]: + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None: + self._row_factory = row_factory + if self.pgresult: + self._make_row = row_factory(self) + + def _make_row_maker(self) -> RowMaker[Row]: + return self._row_factory(self) + + async def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> _Self: + try: + async with self._conn.lock: + await self._conn.wait( + self._execute_gen(query, params, prepare=prepare, binary=binary) + ) + except e.Error as ex: + raise ex.with_traceback(None) + return self + + async def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = False, + ) -> None: + try: + if Pipeline.is_supported(): + # If there is already a pipeline, ride it, in order to avoid + # sending unnecessary Sync. + async with self._conn.lock: + p = self._conn._pipeline + if p: + await self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + # Otherwise, make a new one + if not p: + async with self._conn.pipeline(), self._conn.lock: + await self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + else: + await self._conn.wait( + self._executemany_gen_no_pipeline(query, params_seq, returning) + ) + except e.Error as ex: + raise ex.with_traceback(None) + + async def stream( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> AsyncIterator[Row]: + if self._pgconn.pipeline_status: + raise e.ProgrammingError("stream() cannot be used in pipeline mode") + + async with self._conn.lock: + + try: + await self._conn.wait( + self._stream_send_gen(query, params, binary=binary) + ) + first = True + while await self._conn.wait(self._stream_fetchone_gen(first)): + # We know that, if we got a result, it has a single row. + rec: Row = self._tx.load_row(0, self._make_row) # type: ignore + yield rec + first = False + + except e.Error as ex: + raise ex.with_traceback(None) + + finally: + if self._pgconn.transaction_status == ACTIVE: + # Try to cancel the query, then consume the results + # already received. 
+ self._conn.cancel() + try: + while await self._conn.wait( + self._stream_fetchone_gen(first=False) + ): + pass + except Exception: + pass + + # Try to get out of ACTIVE state. Just do a single attempt, which + # should work to recover from an error or query cancelled. + try: + await self._conn.wait(self._stream_fetchone_gen(first=False)) + except Exception: + pass + + async def fetchone(self) -> Optional[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + rv = self._tx.load_row(self._pos, self._make_row) + if rv is not None: + self._pos += 1 + return rv + + async def fetchmany(self, size: int = 0) -> List[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + + if not size: + size = self.arraysize + records = self._tx.load_rows( + self._pos, + min(self._pos + size, self.pgresult.ntuples), + self._make_row, + ) + self._pos += len(records) + return records + + async def fetchall(self) -> List[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) + self._pos = self.pgresult.ntuples + return records + + async def __aiter__(self) -> AsyncIterator[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + + def load(pos: int) -> Optional[Row]: + return self._tx.load_row(pos, self._make_row) + + while True: + row = load(self._pos) + if row is None: + break + self._pos += 1 + yield row + + async def scroll(self, value: int, mode: str = "relative") -> None: + self._scroll(value, mode) + + @asynccontextmanager + async def copy( + self, + statement: Query, + params: Optional[Params] = None, + *, + writer: Optional[AsyncCopyWriter] = None, + ) -> AsyncIterator[AsyncCopy]: + """ + :rtype: AsyncCopy + """ + try: + async with self._conn.lock: + await self._conn.wait(self._start_copy_gen(statement, params)) + + async with AsyncCopy(self, writer=writer) as copy: + yield copy + except e.Error as ex: + raise ex.with_traceback(None) + + self._select_current_result(0) + + async def _fetch_pipeline(self) -> None: + if ( + self._execmany_returning is not False + and not self.pgresult + and self._conn._pipeline + ): + async with self._conn.lock: + await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True)) diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py new file mode 100644 index 0000000..3c3d8b7 --- /dev/null +++ b/psycopg/psycopg/dbapi20.py @@ -0,0 +1,112 @@ +""" +Compatibility objects with DBAPI 2.0 +""" + +# Copyright (C) 2020 The Psycopg Team + +import time +import datetime as dt +from math import floor +from typing import Any, Sequence, Union + +from . 
import postgres +from .abc import AdaptContext, Buffer +from .types.string import BytesDumper, BytesBinaryDumper + + +class DBAPITypeObject: + def __init__(self, name: str, type_names: Sequence[str]): + self.name = name + self.values = tuple(postgres.types[n].oid for n in type_names) + + def __repr__(self) -> str: + return f"psycopg.{self.name}" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, int): + return other in self.values + else: + return NotImplemented + + def __ne__(self, other: Any) -> bool: + if isinstance(other, int): + return other not in self.values + else: + return NotImplemented + + +BINARY = DBAPITypeObject("BINARY", ("bytea",)) +DATETIME = DBAPITypeObject( + "DATETIME", "timestamp timestamptz date time timetz interval".split() +) +NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split()) +ROWID = DBAPITypeObject("ROWID", ("oid",)) +STRING = DBAPITypeObject("STRING", "text varchar bpchar".split()) + + +class Binary: + def __init__(self, obj: Any): + self.obj = obj + + def __repr__(self) -> str: + sobj = repr(self.obj) + if len(sobj) > 40: + sobj = f"{sobj[:35]} ... ({len(sobj)} byteschars)" + return f"{self.__class__.__name__}({sobj})" + + +class BinaryBinaryDumper(BytesBinaryDumper): + def dump(self, obj: Union[Buffer, Binary]) -> Buffer: + if isinstance(obj, Binary): + return super().dump(obj.obj) + else: + return super().dump(obj) + + +class BinaryTextDumper(BytesDumper): + def dump(self, obj: Union[Buffer, Binary]) -> Buffer: + if isinstance(obj, Binary): + return super().dump(obj.obj) + else: + return super().dump(obj) + + +def Date(year: int, month: int, day: int) -> dt.date: + return dt.date(year, month, day) + + +def DateFromTicks(ticks: float) -> dt.date: + return TimestampFromTicks(ticks).date() + + +def Time(hour: int, minute: int, second: int) -> dt.time: + return dt.time(hour, minute, second) + + +def TimeFromTicks(ticks: float) -> dt.time: + return TimestampFromTicks(ticks).time() + + +def Timestamp( + year: int, month: int, day: int, hour: int, minute: int, second: int +) -> dt.datetime: + return dt.datetime(year, month, day, hour, minute, second) + + +def TimestampFromTicks(ticks: float) -> dt.datetime: + secs = floor(ticks) + frac = ticks - secs + t = time.localtime(ticks) + tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff)) + rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo) + return rv + + +def register_dbapi20_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Binary, BinaryTextDumper) + adapters.register_dumper(Binary, BinaryBinaryDumper) + + # Make them also the default dumpers when dumping by bytea oid + adapters.register_dumper(None, BinaryTextDumper) + adapters.register_dumper(None, BinaryBinaryDumper) diff --git a/psycopg/psycopg/errors.py b/psycopg/psycopg/errors.py new file mode 100644 index 0000000..e176954 --- /dev/null +++ b/psycopg/psycopg/errors.py @@ -0,0 +1,1535 @@ +""" +psycopg exceptions + +DBAPI-defined Exceptions are defined in the following hierarchy:: + + Exceptions + |__Warning + |__Error + |__InterfaceError + |__DatabaseError + |__DataError + |__OperationalError + |__IntegrityError + |__InternalError + |__ProgrammingError + |__NotSupportedError +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing_extensions import TypeAlias + +from .pq.abc import PGconn, PGresult +from .pq._enums import DiagnosticField +from ._compat import TypeGuard + +ErrorInfo: 
TypeAlias = Union[None, PGresult, Dict[int, Optional[bytes]]] + +_sqlcodes: Dict[str, "Type[Error]"] = {} + + +class Warning(Exception): + """ + Exception raised for important warnings. + + Defined for DBAPI compatibility, but never raised by ``psycopg``. + """ + + __module__ = "psycopg" + + +class Error(Exception): + """ + Base exception for all the errors psycopg will raise. + + Exception that is the base class of all other error exceptions. You can + use this to catch all errors with one single `!except` statement. + + This exception is guaranteed to be picklable. + """ + + __module__ = "psycopg" + + sqlstate: Optional[str] = None + + def __init__( + self, + *args: Sequence[Any], + info: ErrorInfo = None, + encoding: str = "utf-8", + pgconn: Optional[PGconn] = None + ): + super().__init__(*args) + self._info = info + self._encoding = encoding + self._pgconn = pgconn + + # Handle sqlstate codes for which we don't have a class. + if not self.sqlstate and info: + self.sqlstate = self.diag.sqlstate + + @property + def pgconn(self) -> Optional[PGconn]: + """The connection object, if the error was raised from a connection attempt. + + :rtype: Optional[psycopg.pq.PGconn] + """ + return self._pgconn if self._pgconn else None + + @property + def pgresult(self) -> Optional[PGresult]: + """The result object, if the exception was raised after a failed query. + + :rtype: Optional[psycopg.pq.PGresult] + """ + return self._info if _is_pgresult(self._info) else None + + @property + def diag(self) -> "Diagnostic": + """ + A `Diagnostic` object to inspect details of the errors from the database. + """ + return Diagnostic(self._info, encoding=self._encoding) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + res = super().__reduce__() + if isinstance(res, tuple) and len(res) >= 3: + # To make the exception picklable + res[2]["_info"] = _info_to_dict(self._info) + res[2]["_pgconn"] = None + + return res + + +class InterfaceError(Error): + """ + An error related to the database interface rather than the database itself. + """ + + __module__ = "psycopg" + + +class DatabaseError(Error): + """ + Exception raised for errors that are related to the database. + """ + + __module__ = "psycopg" + + def __init_subclass__(cls, code: Optional[str] = None, name: Optional[str] = None): + if code: + _sqlcodes[code] = cls + cls.sqlstate = code + if name: + _sqlcodes[name] = cls + + +class DataError(DatabaseError): + """ + An error caused by problems with the processed data. + + Examples may be division by zero, numeric value out of range, etc. + """ + + __module__ = "psycopg" + + +class OperationalError(DatabaseError): + """ + An error related to the database's operation. + + These errors are not necessarily under the control of the programmer, e.g. + an unexpected disconnect occurs, the data source name is not found, a + transaction could not be processed, a memory allocation error occurred + during processing, etc. + """ + + __module__ = "psycopg" + + +class IntegrityError(DatabaseError): + """ + An error caused when the relational integrity of the database is affected. + + An example may be a foreign key check failed. + """ + + __module__ = "psycopg" + + +class InternalError(DatabaseError): + """ + An error generated when the database encounters an internal error, + + Examples could be the cursor is not valid anymore, the transaction is out + of sync, etc. 
+ """ + + __module__ = "psycopg" + + +class ProgrammingError(DatabaseError): + """ + Exception raised for programming errors + + Examples may be table not found or already exists, syntax error in the SQL + statement, wrong number of parameters specified, etc. + """ + + __module__ = "psycopg" + + +class NotSupportedError(DatabaseError): + """ + A method or database API was used which is not supported by the database. + """ + + __module__ = "psycopg" + + +class ConnectionTimeout(OperationalError): + """ + Exception raised on timeout of the `~psycopg.Connection.connect()` method. + + The error is raised if the ``connect_timeout`` is specified and a + connection is not obtained in useful time. + + Subclass of `~psycopg.OperationalError`. + """ + + +class PipelineAborted(OperationalError): + """ + Raised when a operation fails because the current pipeline is in aborted state. + + Subclass of `~psycopg.OperationalError`. + """ + + +class Diagnostic: + """Details from a database error report.""" + + def __init__(self, info: ErrorInfo, encoding: str = "utf-8"): + self._info = info + self._encoding = encoding + + @property + def severity(self) -> Optional[str]: + return self._error_message(DiagnosticField.SEVERITY) + + @property + def severity_nonlocalized(self) -> Optional[str]: + return self._error_message(DiagnosticField.SEVERITY_NONLOCALIZED) + + @property + def sqlstate(self) -> Optional[str]: + return self._error_message(DiagnosticField.SQLSTATE) + + @property + def message_primary(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_PRIMARY) + + @property + def message_detail(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_DETAIL) + + @property + def message_hint(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_HINT) + + @property + def statement_position(self) -> Optional[str]: + return self._error_message(DiagnosticField.STATEMENT_POSITION) + + @property + def internal_position(self) -> Optional[str]: + return self._error_message(DiagnosticField.INTERNAL_POSITION) + + @property + def internal_query(self) -> Optional[str]: + return self._error_message(DiagnosticField.INTERNAL_QUERY) + + @property + def context(self) -> Optional[str]: + return self._error_message(DiagnosticField.CONTEXT) + + @property + def schema_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.SCHEMA_NAME) + + @property + def table_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.TABLE_NAME) + + @property + def column_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.COLUMN_NAME) + + @property + def datatype_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.DATATYPE_NAME) + + @property + def constraint_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.CONSTRAINT_NAME) + + @property + def source_file(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_FILE) + + @property + def source_line(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_LINE) + + @property + def source_function(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_FUNCTION) + + def _error_message(self, field: DiagnosticField) -> Optional[str]: + if self._info: + if isinstance(self._info, dict): + val = self._info.get(field) + else: + val = self._info.error_field(field) + + if val is not None: + return val.decode(self._encoding, "replace") + + return None + + def __reduce__(self) -> 
Union[str, Tuple[Any, ...]]: + res = super().__reduce__() + if isinstance(res, tuple) and len(res) >= 3: + res[2]["_info"] = _info_to_dict(self._info) + + return res + + +def _info_to_dict(info: ErrorInfo) -> ErrorInfo: + """ + Convert a PGresult to a dictionary to make the info picklable. + """ + # PGresult is a protocol, can't use isinstance + if _is_pgresult(info): + return {v: info.error_field(v) for v in DiagnosticField} + else: + return info + + +def lookup(sqlstate: str) -> Type[Error]: + """Lookup an error code or `constant name`__ and return its exception class. + + Raise `!KeyError` if the code is not found. + + .. __: https://www.postgresql.org/docs/current/errcodes-appendix.html + #ERRCODES-TABLE + """ + return _sqlcodes[sqlstate.upper()] + + +def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error: + from psycopg import pq + + state = result.error_field(DiagnosticField.SQLSTATE) or b"" + cls = _class_for_state(state.decode("ascii")) + return cls( + pq.error_message(result, encoding=encoding), + info=result, + encoding=encoding, + ) + + +def _is_pgresult(info: ErrorInfo) -> TypeGuard[PGresult]: + """Return True if an ErrorInfo is a PGresult instance.""" + # PGresult is a protocol, can't use isinstance + return hasattr(info, "error_field") + + +def _class_for_state(sqlstate: str) -> Type[Error]: + try: + return lookup(sqlstate) + except KeyError: + return get_base_exception(sqlstate) + + +def get_base_exception(sqlstate: str) -> Type[Error]: + return ( + _base_exc_map.get(sqlstate[:2]) + or _base_exc_map.get(sqlstate[:1]) + or DatabaseError + ) + + +_base_exc_map = { + "08": OperationalError, # Connection Exception + "0A": NotSupportedError, # Feature Not Supported + "20": ProgrammingError, # Case Not Foud + "21": ProgrammingError, # Cardinality Violation + "22": DataError, # Data Exception + "23": IntegrityError, # Integrity Constraint Violation + "24": InternalError, # Invalid Cursor State + "25": InternalError, # Invalid Transaction State + "26": ProgrammingError, # Invalid SQL Statement Name * + "27": OperationalError, # Triggered Data Change Violation + "28": OperationalError, # Invalid Authorization Specification + "2B": InternalError, # Dependent Privilege Descriptors Still Exist + "2D": InternalError, # Invalid Transaction Termination + "2F": OperationalError, # SQL Routine Exception * + "34": ProgrammingError, # Invalid Cursor Name * + "38": OperationalError, # External Routine Exception * + "39": OperationalError, # External Routine Invocation Exception * + "3B": OperationalError, # Savepoint Exception * + "3D": ProgrammingError, # Invalid Catalog Name + "3F": ProgrammingError, # Invalid Schema Name + "40": OperationalError, # Transaction Rollback + "42": ProgrammingError, # Syntax Error or Access Rule Violation + "44": ProgrammingError, # WITH CHECK OPTION Violation + "53": OperationalError, # Insufficient Resources + "54": OperationalError, # Program Limit Exceeded + "55": OperationalError, # Object Not In Prerequisite State + "57": OperationalError, # Operator Intervention + "58": OperationalError, # System Error (errors external to PostgreSQL itself) + "F": OperationalError, # Configuration File Error + "H": OperationalError, # Foreign Data Wrapper Error (SQL/MED) + "P": ProgrammingError, # PL/pgSQL Error + "X": InternalError, # Internal Error +} + + +# Error classes generated by tools/update_errors.py + +# fmt: off +# autogenerated: start + + +# Class 02 - No Data (this is also a warning class per the SQL standard) + +class NoData(DatabaseError, 
+ code='02000', name='NO_DATA'): + pass + +class NoAdditionalDynamicResultSetsReturned(DatabaseError, + code='02001', name='NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED'): + pass + + +# Class 03 - SQL Statement Not Yet Complete + +class SqlStatementNotYetComplete(DatabaseError, + code='03000', name='SQL_STATEMENT_NOT_YET_COMPLETE'): + pass + + +# Class 08 - Connection Exception + +class ConnectionException(OperationalError, + code='08000', name='CONNECTION_EXCEPTION'): + pass + +class SqlclientUnableToEstablishSqlconnection(OperationalError, + code='08001', name='SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION'): + pass + +class ConnectionDoesNotExist(OperationalError, + code='08003', name='CONNECTION_DOES_NOT_EXIST'): + pass + +class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError, + code='08004', name='SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION'): + pass + +class ConnectionFailure(OperationalError, + code='08006', name='CONNECTION_FAILURE'): + pass + +class TransactionResolutionUnknown(OperationalError, + code='08007', name='TRANSACTION_RESOLUTION_UNKNOWN'): + pass + +class ProtocolViolation(OperationalError, + code='08P01', name='PROTOCOL_VIOLATION'): + pass + + +# Class 09 - Triggered Action Exception + +class TriggeredActionException(DatabaseError, + code='09000', name='TRIGGERED_ACTION_EXCEPTION'): + pass + + +# Class 0A - Feature Not Supported + +class FeatureNotSupported(NotSupportedError, + code='0A000', name='FEATURE_NOT_SUPPORTED'): + pass + + +# Class 0B - Invalid Transaction Initiation + +class InvalidTransactionInitiation(DatabaseError, + code='0B000', name='INVALID_TRANSACTION_INITIATION'): + pass + + +# Class 0F - Locator Exception + +class LocatorException(DatabaseError, + code='0F000', name='LOCATOR_EXCEPTION'): + pass + +class InvalidLocatorSpecification(DatabaseError, + code='0F001', name='INVALID_LOCATOR_SPECIFICATION'): + pass + + +# Class 0L - Invalid Grantor + +class InvalidGrantor(DatabaseError, + code='0L000', name='INVALID_GRANTOR'): + pass + +class InvalidGrantOperation(DatabaseError, + code='0LP01', name='INVALID_GRANT_OPERATION'): + pass + + +# Class 0P - Invalid Role Specification + +class InvalidRoleSpecification(DatabaseError, + code='0P000', name='INVALID_ROLE_SPECIFICATION'): + pass + + +# Class 0Z - Diagnostics Exception + +class DiagnosticsException(DatabaseError, + code='0Z000', name='DIAGNOSTICS_EXCEPTION'): + pass + +class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError, + code='0Z002', name='STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER'): + pass + + +# Class 20 - Case Not Found + +class CaseNotFound(ProgrammingError, + code='20000', name='CASE_NOT_FOUND'): + pass + + +# Class 21 - Cardinality Violation + +class CardinalityViolation(ProgrammingError, + code='21000', name='CARDINALITY_VIOLATION'): + pass + + +# Class 22 - Data Exception + +class DataException(DataError, + code='22000', name='DATA_EXCEPTION'): + pass + +class StringDataRightTruncation(DataError, + code='22001', name='STRING_DATA_RIGHT_TRUNCATION'): + pass + +class NullValueNoIndicatorParameter(DataError, + code='22002', name='NULL_VALUE_NO_INDICATOR_PARAMETER'): + pass + +class NumericValueOutOfRange(DataError, + code='22003', name='NUMERIC_VALUE_OUT_OF_RANGE'): + pass + +class NullValueNotAllowed(DataError, + code='22004', name='NULL_VALUE_NOT_ALLOWED'): + pass + +class ErrorInAssignment(DataError, + code='22005', name='ERROR_IN_ASSIGNMENT'): + pass + +class InvalidDatetimeFormat(DataError, + code='22007', name='INVALID_DATETIME_FORMAT'): + pass + 
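Each generated class in this listing maps one SQLSTATE to a Python exception, so callers can catch the precise condition or resolve it dynamically with lookup(). A short sketch, assuming a reachable test database (the DSN is a placeholder):

import psycopg
from psycopg import errors

with psycopg.connect("dbname=test") as conn:       # placeholder DSN
    try:
        conn.execute("SELECT 1 / 0")
    except errors.DivisionByZero as ex:            # SQLSTATE 22012, a DataError subclass
        print(ex.diag.sqlstate)                    # '22012'

# The same class can be resolved by error code or by constant name.
assert errors.lookup("22012") is errors.DivisionByZero
assert errors.lookup("DIVISION_BY_ZERO") is errors.DivisionByZero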
+class DatetimeFieldOverflow(DataError, + code='22008', name='DATETIME_FIELD_OVERFLOW'): + pass + +class InvalidTimeZoneDisplacementValue(DataError, + code='22009', name='INVALID_TIME_ZONE_DISPLACEMENT_VALUE'): + pass + +class EscapeCharacterConflict(DataError, + code='2200B', name='ESCAPE_CHARACTER_CONFLICT'): + pass + +class InvalidUseOfEscapeCharacter(DataError, + code='2200C', name='INVALID_USE_OF_ESCAPE_CHARACTER'): + pass + +class InvalidEscapeOctet(DataError, + code='2200D', name='INVALID_ESCAPE_OCTET'): + pass + +class ZeroLengthCharacterString(DataError, + code='2200F', name='ZERO_LENGTH_CHARACTER_STRING'): + pass + +class MostSpecificTypeMismatch(DataError, + code='2200G', name='MOST_SPECIFIC_TYPE_MISMATCH'): + pass + +class SequenceGeneratorLimitExceeded(DataError, + code='2200H', name='SEQUENCE_GENERATOR_LIMIT_EXCEEDED'): + pass + +class NotAnXmlDocument(DataError, + code='2200L', name='NOT_AN_XML_DOCUMENT'): + pass + +class InvalidXmlDocument(DataError, + code='2200M', name='INVALID_XML_DOCUMENT'): + pass + +class InvalidXmlContent(DataError, + code='2200N', name='INVALID_XML_CONTENT'): + pass + +class InvalidXmlComment(DataError, + code='2200S', name='INVALID_XML_COMMENT'): + pass + +class InvalidXmlProcessingInstruction(DataError, + code='2200T', name='INVALID_XML_PROCESSING_INSTRUCTION'): + pass + +class InvalidIndicatorParameterValue(DataError, + code='22010', name='INVALID_INDICATOR_PARAMETER_VALUE'): + pass + +class SubstringError(DataError, + code='22011', name='SUBSTRING_ERROR'): + pass + +class DivisionByZero(DataError, + code='22012', name='DIVISION_BY_ZERO'): + pass + +class InvalidPrecedingOrFollowingSize(DataError, + code='22013', name='INVALID_PRECEDING_OR_FOLLOWING_SIZE'): + pass + +class InvalidArgumentForNtileFunction(DataError, + code='22014', name='INVALID_ARGUMENT_FOR_NTILE_FUNCTION'): + pass + +class IntervalFieldOverflow(DataError, + code='22015', name='INTERVAL_FIELD_OVERFLOW'): + pass + +class InvalidArgumentForNthValueFunction(DataError, + code='22016', name='INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION'): + pass + +class InvalidCharacterValueForCast(DataError, + code='22018', name='INVALID_CHARACTER_VALUE_FOR_CAST'): + pass + +class InvalidEscapeCharacter(DataError, + code='22019', name='INVALID_ESCAPE_CHARACTER'): + pass + +class InvalidRegularExpression(DataError, + code='2201B', name='INVALID_REGULAR_EXPRESSION'): + pass + +class InvalidArgumentForLogarithm(DataError, + code='2201E', name='INVALID_ARGUMENT_FOR_LOGARITHM'): + pass + +class InvalidArgumentForPowerFunction(DataError, + code='2201F', name='INVALID_ARGUMENT_FOR_POWER_FUNCTION'): + pass + +class InvalidArgumentForWidthBucketFunction(DataError, + code='2201G', name='INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION'): + pass + +class InvalidRowCountInLimitClause(DataError, + code='2201W', name='INVALID_ROW_COUNT_IN_LIMIT_CLAUSE'): + pass + +class InvalidRowCountInResultOffsetClause(DataError, + code='2201X', name='INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE'): + pass + +class CharacterNotInRepertoire(DataError, + code='22021', name='CHARACTER_NOT_IN_REPERTOIRE'): + pass + +class IndicatorOverflow(DataError, + code='22022', name='INDICATOR_OVERFLOW'): + pass + +class InvalidParameterValue(DataError, + code='22023', name='INVALID_PARAMETER_VALUE'): + pass + +class UnterminatedCString(DataError, + code='22024', name='UNTERMINATED_C_STRING'): + pass + +class InvalidEscapeSequence(DataError, + code='22025', name='INVALID_ESCAPE_SEQUENCE'): + pass + +class StringDataLengthMismatch(DataError, + code='22026', 
name='STRING_DATA_LENGTH_MISMATCH'): + pass + +class TrimError(DataError, + code='22027', name='TRIM_ERROR'): + pass + +class ArraySubscriptError(DataError, + code='2202E', name='ARRAY_SUBSCRIPT_ERROR'): + pass + +class InvalidTablesampleRepeat(DataError, + code='2202G', name='INVALID_TABLESAMPLE_REPEAT'): + pass + +class InvalidTablesampleArgument(DataError, + code='2202H', name='INVALID_TABLESAMPLE_ARGUMENT'): + pass + +class DuplicateJsonObjectKeyValue(DataError, + code='22030', name='DUPLICATE_JSON_OBJECT_KEY_VALUE'): + pass + +class InvalidArgumentForSqlJsonDatetimeFunction(DataError, + code='22031', name='INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION'): + pass + +class InvalidJsonText(DataError, + code='22032', name='INVALID_JSON_TEXT'): + pass + +class InvalidSqlJsonSubscript(DataError, + code='22033', name='INVALID_SQL_JSON_SUBSCRIPT'): + pass + +class MoreThanOneSqlJsonItem(DataError, + code='22034', name='MORE_THAN_ONE_SQL_JSON_ITEM'): + pass + +class NoSqlJsonItem(DataError, + code='22035', name='NO_SQL_JSON_ITEM'): + pass + +class NonNumericSqlJsonItem(DataError, + code='22036', name='NON_NUMERIC_SQL_JSON_ITEM'): + pass + +class NonUniqueKeysInAJsonObject(DataError, + code='22037', name='NON_UNIQUE_KEYS_IN_A_JSON_OBJECT'): + pass + +class SingletonSqlJsonItemRequired(DataError, + code='22038', name='SINGLETON_SQL_JSON_ITEM_REQUIRED'): + pass + +class SqlJsonArrayNotFound(DataError, + code='22039', name='SQL_JSON_ARRAY_NOT_FOUND'): + pass + +class SqlJsonMemberNotFound(DataError, + code='2203A', name='SQL_JSON_MEMBER_NOT_FOUND'): + pass + +class SqlJsonNumberNotFound(DataError, + code='2203B', name='SQL_JSON_NUMBER_NOT_FOUND'): + pass + +class SqlJsonObjectNotFound(DataError, + code='2203C', name='SQL_JSON_OBJECT_NOT_FOUND'): + pass + +class TooManyJsonArrayElements(DataError, + code='2203D', name='TOO_MANY_JSON_ARRAY_ELEMENTS'): + pass + +class TooManyJsonObjectMembers(DataError, + code='2203E', name='TOO_MANY_JSON_OBJECT_MEMBERS'): + pass + +class SqlJsonScalarRequired(DataError, + code='2203F', name='SQL_JSON_SCALAR_REQUIRED'): + pass + +class SqlJsonItemCannotBeCastToTargetType(DataError, + code='2203G', name='SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE'): + pass + +class FloatingPointException(DataError, + code='22P01', name='FLOATING_POINT_EXCEPTION'): + pass + +class InvalidTextRepresentation(DataError, + code='22P02', name='INVALID_TEXT_REPRESENTATION'): + pass + +class InvalidBinaryRepresentation(DataError, + code='22P03', name='INVALID_BINARY_REPRESENTATION'): + pass + +class BadCopyFileFormat(DataError, + code='22P04', name='BAD_COPY_FILE_FORMAT'): + pass + +class UntranslatableCharacter(DataError, + code='22P05', name='UNTRANSLATABLE_CHARACTER'): + pass + +class NonstandardUseOfEscapeCharacter(DataError, + code='22P06', name='NONSTANDARD_USE_OF_ESCAPE_CHARACTER'): + pass + + +# Class 23 - Integrity Constraint Violation + +class IntegrityConstraintViolation(IntegrityError, + code='23000', name='INTEGRITY_CONSTRAINT_VIOLATION'): + pass + +class RestrictViolation(IntegrityError, + code='23001', name='RESTRICT_VIOLATION'): + pass + +class NotNullViolation(IntegrityError, + code='23502', name='NOT_NULL_VIOLATION'): + pass + +class ForeignKeyViolation(IntegrityError, + code='23503', name='FOREIGN_KEY_VIOLATION'): + pass + +class UniqueViolation(IntegrityError, + code='23505', name='UNIQUE_VIOLATION'): + pass + +class CheckViolation(IntegrityError, + code='23514', name='CHECK_VIOLATION'): + pass + +class ExclusionViolation(IntegrityError, + code='23P01', 
name='EXCLUSION_VIOLATION'): + pass + + +# Class 24 - Invalid Cursor State + +class InvalidCursorState(InternalError, + code='24000', name='INVALID_CURSOR_STATE'): + pass + + +# Class 25 - Invalid Transaction State + +class InvalidTransactionState(InternalError, + code='25000', name='INVALID_TRANSACTION_STATE'): + pass + +class ActiveSqlTransaction(InternalError, + code='25001', name='ACTIVE_SQL_TRANSACTION'): + pass + +class BranchTransactionAlreadyActive(InternalError, + code='25002', name='BRANCH_TRANSACTION_ALREADY_ACTIVE'): + pass + +class InappropriateAccessModeForBranchTransaction(InternalError, + code='25003', name='INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION'): + pass + +class InappropriateIsolationLevelForBranchTransaction(InternalError, + code='25004', name='INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION'): + pass + +class NoActiveSqlTransactionForBranchTransaction(InternalError, + code='25005', name='NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION'): + pass + +class ReadOnlySqlTransaction(InternalError, + code='25006', name='READ_ONLY_SQL_TRANSACTION'): + pass + +class SchemaAndDataStatementMixingNotSupported(InternalError, + code='25007', name='SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED'): + pass + +class HeldCursorRequiresSameIsolationLevel(InternalError, + code='25008', name='HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL'): + pass + +class NoActiveSqlTransaction(InternalError, + code='25P01', name='NO_ACTIVE_SQL_TRANSACTION'): + pass + +class InFailedSqlTransaction(InternalError, + code='25P02', name='IN_FAILED_SQL_TRANSACTION'): + pass + +class IdleInTransactionSessionTimeout(InternalError, + code='25P03', name='IDLE_IN_TRANSACTION_SESSION_TIMEOUT'): + pass + + +# Class 26 - Invalid SQL Statement Name + +class InvalidSqlStatementName(ProgrammingError, + code='26000', name='INVALID_SQL_STATEMENT_NAME'): + pass + + +# Class 27 - Triggered Data Change Violation + +class TriggeredDataChangeViolation(OperationalError, + code='27000', name='TRIGGERED_DATA_CHANGE_VIOLATION'): + pass + + +# Class 28 - Invalid Authorization Specification + +class InvalidAuthorizationSpecification(OperationalError, + code='28000', name='INVALID_AUTHORIZATION_SPECIFICATION'): + pass + +class InvalidPassword(OperationalError, + code='28P01', name='INVALID_PASSWORD'): + pass + + +# Class 2B - Dependent Privilege Descriptors Still Exist + +class DependentPrivilegeDescriptorsStillExist(InternalError, + code='2B000', name='DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST'): + pass + +class DependentObjectsStillExist(InternalError, + code='2BP01', name='DEPENDENT_OBJECTS_STILL_EXIST'): + pass + + +# Class 2D - Invalid Transaction Termination + +class InvalidTransactionTermination(InternalError, + code='2D000', name='INVALID_TRANSACTION_TERMINATION'): + pass + + +# Class 2F - SQL Routine Exception + +class SqlRoutineException(OperationalError, + code='2F000', name='SQL_ROUTINE_EXCEPTION'): + pass + +class ModifyingSqlDataNotPermitted(OperationalError, + code='2F002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'): + pass + +class ProhibitedSqlStatementAttempted(OperationalError, + code='2F003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'): + pass + +class ReadingSqlDataNotPermitted(OperationalError, + code='2F004', name='READING_SQL_DATA_NOT_PERMITTED'): + pass + +class FunctionExecutedNoReturnStatement(OperationalError, + code='2F005', name='FUNCTION_EXECUTED_NO_RETURN_STATEMENT'): + pass + + +# Class 34 - Invalid Cursor Name + +class InvalidCursorName(ProgrammingError, + code='34000', 
name='INVALID_CURSOR_NAME'): + pass + + +# Class 38 - External Routine Exception + +class ExternalRoutineException(OperationalError, + code='38000', name='EXTERNAL_ROUTINE_EXCEPTION'): + pass + +class ContainingSqlNotPermitted(OperationalError, + code='38001', name='CONTAINING_SQL_NOT_PERMITTED'): + pass + +class ModifyingSqlDataNotPermittedExt(OperationalError, + code='38002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'): + pass + +class ProhibitedSqlStatementAttemptedExt(OperationalError, + code='38003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'): + pass + +class ReadingSqlDataNotPermittedExt(OperationalError, + code='38004', name='READING_SQL_DATA_NOT_PERMITTED'): + pass + + +# Class 39 - External Routine Invocation Exception + +class ExternalRoutineInvocationException(OperationalError, + code='39000', name='EXTERNAL_ROUTINE_INVOCATION_EXCEPTION'): + pass + +class InvalidSqlstateReturned(OperationalError, + code='39001', name='INVALID_SQLSTATE_RETURNED'): + pass + +class NullValueNotAllowedExt(OperationalError, + code='39004', name='NULL_VALUE_NOT_ALLOWED'): + pass + +class TriggerProtocolViolated(OperationalError, + code='39P01', name='TRIGGER_PROTOCOL_VIOLATED'): + pass + +class SrfProtocolViolated(OperationalError, + code='39P02', name='SRF_PROTOCOL_VIOLATED'): + pass + +class EventTriggerProtocolViolated(OperationalError, + code='39P03', name='EVENT_TRIGGER_PROTOCOL_VIOLATED'): + pass + + +# Class 3B - Savepoint Exception + +class SavepointException(OperationalError, + code='3B000', name='SAVEPOINT_EXCEPTION'): + pass + +class InvalidSavepointSpecification(OperationalError, + code='3B001', name='INVALID_SAVEPOINT_SPECIFICATION'): + pass + + +# Class 3D - Invalid Catalog Name + +class InvalidCatalogName(ProgrammingError, + code='3D000', name='INVALID_CATALOG_NAME'): + pass + + +# Class 3F - Invalid Schema Name + +class InvalidSchemaName(ProgrammingError, + code='3F000', name='INVALID_SCHEMA_NAME'): + pass + + +# Class 40 - Transaction Rollback + +class TransactionRollback(OperationalError, + code='40000', name='TRANSACTION_ROLLBACK'): + pass + +class SerializationFailure(OperationalError, + code='40001', name='SERIALIZATION_FAILURE'): + pass + +class TransactionIntegrityConstraintViolation(OperationalError, + code='40002', name='TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION'): + pass + +class StatementCompletionUnknown(OperationalError, + code='40003', name='STATEMENT_COMPLETION_UNKNOWN'): + pass + +class DeadlockDetected(OperationalError, + code='40P01', name='DEADLOCK_DETECTED'): + pass + + +# Class 42 - Syntax Error or Access Rule Violation + +class SyntaxErrorOrAccessRuleViolation(ProgrammingError, + code='42000', name='SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION'): + pass + +class InsufficientPrivilege(ProgrammingError, + code='42501', name='INSUFFICIENT_PRIVILEGE'): + pass + +class SyntaxError(ProgrammingError, + code='42601', name='SYNTAX_ERROR'): + pass + +class InvalidName(ProgrammingError, + code='42602', name='INVALID_NAME'): + pass + +class InvalidColumnDefinition(ProgrammingError, + code='42611', name='INVALID_COLUMN_DEFINITION'): + pass + +class NameTooLong(ProgrammingError, + code='42622', name='NAME_TOO_LONG'): + pass + +class DuplicateColumn(ProgrammingError, + code='42701', name='DUPLICATE_COLUMN'): + pass + +class AmbiguousColumn(ProgrammingError, + code='42702', name='AMBIGUOUS_COLUMN'): + pass + +class UndefinedColumn(ProgrammingError, + code='42703', name='UNDEFINED_COLUMN'): + pass + +class UndefinedObject(ProgrammingError, + code='42704', name='UNDEFINED_OBJECT'): + pass + 
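The Class 42 exceptions listed here all derive from ProgrammingError (see _base_exc_map earlier in this module), so a broad DBAPI-style handler keeps working while more specific handlers remain possible. A hedged sketch, with placeholder connection details:

import psycopg
from psycopg import errors

with psycopg.connect("dbname=test") as conn:       # placeholder DSN
    try:
        conn.execute("SELECT no_such_column")      # raises UndefinedColumn (42703)
    except errors.UndefinedColumn as ex:
        print(ex.diag.message_primary)             # server's primary error message
    except psycopg.ProgrammingError:
        pass                                       # any other syntax/access-rule error

assert issubclass(errors.UndefinedColumn, psycopg.ProgrammingError)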
+class DuplicateObject(ProgrammingError, + code='42710', name='DUPLICATE_OBJECT'): + pass + +class DuplicateAlias(ProgrammingError, + code='42712', name='DUPLICATE_ALIAS'): + pass + +class DuplicateFunction(ProgrammingError, + code='42723', name='DUPLICATE_FUNCTION'): + pass + +class AmbiguousFunction(ProgrammingError, + code='42725', name='AMBIGUOUS_FUNCTION'): + pass + +class GroupingError(ProgrammingError, + code='42803', name='GROUPING_ERROR'): + pass + +class DatatypeMismatch(ProgrammingError, + code='42804', name='DATATYPE_MISMATCH'): + pass + +class WrongObjectType(ProgrammingError, + code='42809', name='WRONG_OBJECT_TYPE'): + pass + +class InvalidForeignKey(ProgrammingError, + code='42830', name='INVALID_FOREIGN_KEY'): + pass + +class CannotCoerce(ProgrammingError, + code='42846', name='CANNOT_COERCE'): + pass + +class UndefinedFunction(ProgrammingError, + code='42883', name='UNDEFINED_FUNCTION'): + pass + +class GeneratedAlways(ProgrammingError, + code='428C9', name='GENERATED_ALWAYS'): + pass + +class ReservedName(ProgrammingError, + code='42939', name='RESERVED_NAME'): + pass + +class UndefinedTable(ProgrammingError, + code='42P01', name='UNDEFINED_TABLE'): + pass + +class UndefinedParameter(ProgrammingError, + code='42P02', name='UNDEFINED_PARAMETER'): + pass + +class DuplicateCursor(ProgrammingError, + code='42P03', name='DUPLICATE_CURSOR'): + pass + +class DuplicateDatabase(ProgrammingError, + code='42P04', name='DUPLICATE_DATABASE'): + pass + +class DuplicatePreparedStatement(ProgrammingError, + code='42P05', name='DUPLICATE_PREPARED_STATEMENT'): + pass + +class DuplicateSchema(ProgrammingError, + code='42P06', name='DUPLICATE_SCHEMA'): + pass + +class DuplicateTable(ProgrammingError, + code='42P07', name='DUPLICATE_TABLE'): + pass + +class AmbiguousParameter(ProgrammingError, + code='42P08', name='AMBIGUOUS_PARAMETER'): + pass + +class AmbiguousAlias(ProgrammingError, + code='42P09', name='AMBIGUOUS_ALIAS'): + pass + +class InvalidColumnReference(ProgrammingError, + code='42P10', name='INVALID_COLUMN_REFERENCE'): + pass + +class InvalidCursorDefinition(ProgrammingError, + code='42P11', name='INVALID_CURSOR_DEFINITION'): + pass + +class InvalidDatabaseDefinition(ProgrammingError, + code='42P12', name='INVALID_DATABASE_DEFINITION'): + pass + +class InvalidFunctionDefinition(ProgrammingError, + code='42P13', name='INVALID_FUNCTION_DEFINITION'): + pass + +class InvalidPreparedStatementDefinition(ProgrammingError, + code='42P14', name='INVALID_PREPARED_STATEMENT_DEFINITION'): + pass + +class InvalidSchemaDefinition(ProgrammingError, + code='42P15', name='INVALID_SCHEMA_DEFINITION'): + pass + +class InvalidTableDefinition(ProgrammingError, + code='42P16', name='INVALID_TABLE_DEFINITION'): + pass + +class InvalidObjectDefinition(ProgrammingError, + code='42P17', name='INVALID_OBJECT_DEFINITION'): + pass + +class IndeterminateDatatype(ProgrammingError, + code='42P18', name='INDETERMINATE_DATATYPE'): + pass + +class InvalidRecursion(ProgrammingError, + code='42P19', name='INVALID_RECURSION'): + pass + +class WindowingError(ProgrammingError, + code='42P20', name='WINDOWING_ERROR'): + pass + +class CollationMismatch(ProgrammingError, + code='42P21', name='COLLATION_MISMATCH'): + pass + +class IndeterminateCollation(ProgrammingError, + code='42P22', name='INDETERMINATE_COLLATION'): + pass + + +# Class 44 - WITH CHECK OPTION Violation + +class WithCheckOptionViolation(ProgrammingError, + code='44000', name='WITH_CHECK_OPTION_VIOLATION'): + pass + + +# Class 53 - Insufficient Resources + 
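The resource and shutdown errors that follow (classes 53 through 58) all derive from OperationalError, so they can also be caught as a family rather than one by one; for instance, given an open connection ``conn``:

    import psycopg

    try:
        conn.execute("SELECT pg_sleep(60)")   # placeholder long-running query
    except psycopg.OperationalError:
        pass   # e.g. QueryCanceled, DiskFull, AdminShutdown, ...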
+class InsufficientResources(OperationalError, + code='53000', name='INSUFFICIENT_RESOURCES'): + pass + +class DiskFull(OperationalError, + code='53100', name='DISK_FULL'): + pass + +class OutOfMemory(OperationalError, + code='53200', name='OUT_OF_MEMORY'): + pass + +class TooManyConnections(OperationalError, + code='53300', name='TOO_MANY_CONNECTIONS'): + pass + +class ConfigurationLimitExceeded(OperationalError, + code='53400', name='CONFIGURATION_LIMIT_EXCEEDED'): + pass + + +# Class 54 - Program Limit Exceeded + +class ProgramLimitExceeded(OperationalError, + code='54000', name='PROGRAM_LIMIT_EXCEEDED'): + pass + +class StatementTooComplex(OperationalError, + code='54001', name='STATEMENT_TOO_COMPLEX'): + pass + +class TooManyColumns(OperationalError, + code='54011', name='TOO_MANY_COLUMNS'): + pass + +class TooManyArguments(OperationalError, + code='54023', name='TOO_MANY_ARGUMENTS'): + pass + + +# Class 55 - Object Not In Prerequisite State + +class ObjectNotInPrerequisiteState(OperationalError, + code='55000', name='OBJECT_NOT_IN_PREREQUISITE_STATE'): + pass + +class ObjectInUse(OperationalError, + code='55006', name='OBJECT_IN_USE'): + pass + +class CantChangeRuntimeParam(OperationalError, + code='55P02', name='CANT_CHANGE_RUNTIME_PARAM'): + pass + +class LockNotAvailable(OperationalError, + code='55P03', name='LOCK_NOT_AVAILABLE'): + pass + +class UnsafeNewEnumValueUsage(OperationalError, + code='55P04', name='UNSAFE_NEW_ENUM_VALUE_USAGE'): + pass + + +# Class 57 - Operator Intervention + +class OperatorIntervention(OperationalError, + code='57000', name='OPERATOR_INTERVENTION'): + pass + +class QueryCanceled(OperationalError, + code='57014', name='QUERY_CANCELED'): + pass + +class AdminShutdown(OperationalError, + code='57P01', name='ADMIN_SHUTDOWN'): + pass + +class CrashShutdown(OperationalError, + code='57P02', name='CRASH_SHUTDOWN'): + pass + +class CannotConnectNow(OperationalError, + code='57P03', name='CANNOT_CONNECT_NOW'): + pass + +class DatabaseDropped(OperationalError, + code='57P04', name='DATABASE_DROPPED'): + pass + +class IdleSessionTimeout(OperationalError, + code='57P05', name='IDLE_SESSION_TIMEOUT'): + pass + + +# Class 58 - System Error (errors external to PostgreSQL itself) + +class SystemError(OperationalError, + code='58000', name='SYSTEM_ERROR'): + pass + +class IoError(OperationalError, + code='58030', name='IO_ERROR'): + pass + +class UndefinedFile(OperationalError, + code='58P01', name='UNDEFINED_FILE'): + pass + +class DuplicateFile(OperationalError, + code='58P02', name='DUPLICATE_FILE'): + pass + + +# Class 72 - Snapshot Failure + +class SnapshotTooOld(DatabaseError, + code='72000', name='SNAPSHOT_TOO_OLD'): + pass + + +# Class F0 - Configuration File Error + +class ConfigFileError(OperationalError, + code='F0000', name='CONFIG_FILE_ERROR'): + pass + +class LockFileExists(OperationalError, + code='F0001', name='LOCK_FILE_EXISTS'): + pass + + +# Class HV - Foreign Data Wrapper Error (SQL/MED) + +class FdwError(OperationalError, + code='HV000', name='FDW_ERROR'): + pass + +class FdwOutOfMemory(OperationalError, + code='HV001', name='FDW_OUT_OF_MEMORY'): + pass + +class FdwDynamicParameterValueNeeded(OperationalError, + code='HV002', name='FDW_DYNAMIC_PARAMETER_VALUE_NEEDED'): + pass + +class FdwInvalidDataType(OperationalError, + code='HV004', name='FDW_INVALID_DATA_TYPE'): + pass + +class FdwColumnNameNotFound(OperationalError, + code='HV005', name='FDW_COLUMN_NAME_NOT_FOUND'): + pass + +class FdwInvalidDataTypeDescriptors(OperationalError, + 
code='HV006', name='FDW_INVALID_DATA_TYPE_DESCRIPTORS'): + pass + +class FdwInvalidColumnName(OperationalError, + code='HV007', name='FDW_INVALID_COLUMN_NAME'): + pass + +class FdwInvalidColumnNumber(OperationalError, + code='HV008', name='FDW_INVALID_COLUMN_NUMBER'): + pass + +class FdwInvalidUseOfNullPointer(OperationalError, + code='HV009', name='FDW_INVALID_USE_OF_NULL_POINTER'): + pass + +class FdwInvalidStringFormat(OperationalError, + code='HV00A', name='FDW_INVALID_STRING_FORMAT'): + pass + +class FdwInvalidHandle(OperationalError, + code='HV00B', name='FDW_INVALID_HANDLE'): + pass + +class FdwInvalidOptionIndex(OperationalError, + code='HV00C', name='FDW_INVALID_OPTION_INDEX'): + pass + +class FdwInvalidOptionName(OperationalError, + code='HV00D', name='FDW_INVALID_OPTION_NAME'): + pass + +class FdwOptionNameNotFound(OperationalError, + code='HV00J', name='FDW_OPTION_NAME_NOT_FOUND'): + pass + +class FdwReplyHandle(OperationalError, + code='HV00K', name='FDW_REPLY_HANDLE'): + pass + +class FdwUnableToCreateExecution(OperationalError, + code='HV00L', name='FDW_UNABLE_TO_CREATE_EXECUTION'): + pass + +class FdwUnableToCreateReply(OperationalError, + code='HV00M', name='FDW_UNABLE_TO_CREATE_REPLY'): + pass + +class FdwUnableToEstablishConnection(OperationalError, + code='HV00N', name='FDW_UNABLE_TO_ESTABLISH_CONNECTION'): + pass + +class FdwNoSchemas(OperationalError, + code='HV00P', name='FDW_NO_SCHEMAS'): + pass + +class FdwSchemaNotFound(OperationalError, + code='HV00Q', name='FDW_SCHEMA_NOT_FOUND'): + pass + +class FdwTableNotFound(OperationalError, + code='HV00R', name='FDW_TABLE_NOT_FOUND'): + pass + +class FdwFunctionSequenceError(OperationalError, + code='HV010', name='FDW_FUNCTION_SEQUENCE_ERROR'): + pass + +class FdwTooManyHandles(OperationalError, + code='HV014', name='FDW_TOO_MANY_HANDLES'): + pass + +class FdwInconsistentDescriptorInformation(OperationalError, + code='HV021', name='FDW_INCONSISTENT_DESCRIPTOR_INFORMATION'): + pass + +class FdwInvalidAttributeValue(OperationalError, + code='HV024', name='FDW_INVALID_ATTRIBUTE_VALUE'): + pass + +class FdwInvalidStringLengthOrBufferLength(OperationalError, + code='HV090', name='FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH'): + pass + +class FdwInvalidDescriptorFieldIdentifier(OperationalError, + code='HV091', name='FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER'): + pass + + +# Class P0 - PL/pgSQL Error + +class PlpgsqlError(ProgrammingError, + code='P0000', name='PLPGSQL_ERROR'): + pass + +class RaiseException(ProgrammingError, + code='P0001', name='RAISE_EXCEPTION'): + pass + +class NoDataFound(ProgrammingError, + code='P0002', name='NO_DATA_FOUND'): + pass + +class TooManyRows(ProgrammingError, + code='P0003', name='TOO_MANY_ROWS'): + pass + +class AssertFailure(ProgrammingError, + code='P0004', name='ASSERT_FAILURE'): + pass + + +# Class XX - Internal Error + +class InternalError_(InternalError, + code='XX000', name='INTERNAL_ERROR'): + pass + +class DataCorrupted(InternalError, + code='XX001', name='DATA_CORRUPTED'): + pass + +class IndexCorrupted(InternalError, + code='XX002', name='INDEX_CORRUPTED'): + pass + + +# autogenerated: end +# fmt: on diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py new file mode 100644 index 0000000..584fe47 --- /dev/null +++ b/psycopg/psycopg/generators.py @@ -0,0 +1,320 @@ +""" +Generators implementing communication protocols with the libpq + +Certain operations (connection, querying) are an interleave of libpq calls and +waiting for the socket to be ready. 
This module contains the code to execute +the operations, yielding a polling state whenever there is to wait. The +functions in the `waiting` module are the ones who wait more or less +cooperatively for the socket to be ready and make these generators continue. + +All these generators yield pairs (fileno, `Wait`) whenever an operation would +block. The generator can be restarted sending the appropriate `Ready` state +when the file descriptor is ready. + +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +from typing import List, Optional, Union + +from . import pq +from . import errors as e +from .abc import Buffer, PipelineCommand, PQGen, PQGenConn +from .pq.abc import PGconn, PGresult +from .waiting import Wait, Ready +from ._compat import Deque +from ._cmodule import _psycopg +from ._encodings import pgconn_encoding, conninfo_encoding + +OK = pq.ConnStatus.OK +BAD = pq.ConnStatus.BAD + +POLL_OK = pq.PollingStatus.OK +POLL_READING = pq.PollingStatus.READING +POLL_WRITING = pq.PollingStatus.WRITING +POLL_FAILED = pq.PollingStatus.FAILED + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +COPY_OUT = pq.ExecStatus.COPY_OUT +COPY_IN = pq.ExecStatus.COPY_IN +COPY_BOTH = pq.ExecStatus.COPY_BOTH +PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC + +WAIT_R = Wait.R +WAIT_W = Wait.W +WAIT_RW = Wait.RW +READY_R = Ready.R +READY_W = Ready.W +READY_RW = Ready.RW + +logger = logging.getLogger(__name__) + + +def _connect(conninfo: str) -> PQGenConn[PGconn]: + """ + Generator to create a database connection without blocking. + + """ + conn = pq.PGconn.connect_start(conninfo.encode()) + while True: + if conn.status == BAD: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection is bad: {pq.error_message(conn, encoding=encoding)}", + pgconn=conn, + ) + + status = conn.connect_poll() + if status == POLL_OK: + break + elif status == POLL_READING: + yield conn.socket, WAIT_R + elif status == POLL_WRITING: + yield conn.socket, WAIT_W + elif status == POLL_FAILED: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection failed: {pq.error_message(conn, encoding=encoding)}", + pgconn=conn, + ) + else: + raise e.InternalError(f"unexpected poll status: {status}", pgconn=conn) + + conn.nonblocking = 1 + return conn + + +def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]: + """ + Generator sending a query and returning results without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + Return the list of results returned by the database (whether success + or error). + """ + yield from _send(pgconn) + rv = yield from _fetch_many(pgconn) + return rv + + +def _send(pgconn: PGconn) -> PQGen[None]: + """ + Generator to send a query to the server without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + After this generator has finished you may want to cycle using `fetch()` + to retrieve the results available. + """ + while True: + f = pgconn.flush() + if f == 0: + break + + ready = yield WAIT_RW + if ready & READY_R: + # This call may read notifies: they will be saved in the + # PGconn buffer and passed to Python later, in `fetch()`. + pgconn.consume_input() + + +def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]: + """ + Generator retrieving results from the database without blocking. 
+ + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return the list of results returned by the database (whether success + or error). + """ + results: List[PGresult] = [] + while True: + res = yield from _fetch(pgconn) + if not res: + break + + results.append(res) + status = res.status + if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH: + # After entering copy mode the libpq will create a phony result + # for every request so let's break the endless loop. + break + + if status == PIPELINE_SYNC: + # PIPELINE_SYNC is not followed by a NULL, but we return it alone + # similarly to other result sets. + assert len(results) == 1, results + break + + return results + + +def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]: + """ + Generator retrieving a single result from the database without blocking. + + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return a result from the database (whether success or error). + """ + if pgconn.is_busy(): + yield WAIT_R + while True: + pgconn.consume_input() + if not pgconn.is_busy(): + break + yield WAIT_R + + _consume_notifies(pgconn) + + return pgconn.get_result() + + +def _pipeline_communicate( + pgconn: PGconn, commands: Deque[PipelineCommand] +) -> PQGen[List[List[PGresult]]]: + """Generator to send queries from a connection in pipeline mode while also + receiving results. + + Return a list results, including single PIPELINE_SYNC elements. + """ + results = [] + + while True: + ready = yield WAIT_RW + + if ready & READY_R: + pgconn.consume_input() + _consume_notifies(pgconn) + + res: List[PGresult] = [] + while not pgconn.is_busy(): + r = pgconn.get_result() + if r is None: + if not res: + break + results.append(res) + res = [] + elif r.status == PIPELINE_SYNC: + assert not res + results.append([r]) + else: + res.append(r) + + if ready & READY_W: + pgconn.flush() + if not commands: + break + commands.popleft()() + + return results + + +def _consume_notifies(pgconn: PGconn) -> None: + # Consume notifies + while True: + n = pgconn.notifies() + if not n: + break + if pgconn.notify_handler: + pgconn.notify_handler(n) + + +def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]: + yield WAIT_R + pgconn.consume_input() + + ns = [] + while True: + n = pgconn.notifies() + if n: + ns.append(n) + else: + break + + return ns + + +def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]: + while True: + nbytes, data = pgconn.get_copy_data(1) + if nbytes != 0: + break + + # would block + yield WAIT_R + pgconn.consume_input() + + if nbytes > 0: + # some data + return data + + # Retrieve the final result of copy + results = yield from _fetch_many(pgconn) + if len(results) > 1: + # TODO: too brutal? Copy worked. + raise e.ProgrammingError("you cannot mix COPY with other operations") + result = results[0] + if result.status != COMMAND_OK: + encoding = pgconn_encoding(pgconn) + raise e.error_from_result(result, encoding=encoding) + + return result + + +def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]: + # Retry enqueuing data until successful. + # + # WARNING! This can cause an infinite loop if the buffer is too large. (see + # ticket #255). We avoid it in the Copy object by splitting a large buffer + # into smaller ones. We prefer to do it there instead of here in order to + # do it upstream the queue decoupling the writer task from the producer one. 
+ while pgconn.put_copy_data(buffer) == 0: + yield WAIT_W + + +def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]: + # Retry enqueuing end copy message until successful + while pgconn.put_copy_end(error) == 0: + yield WAIT_W + + # Repeat until it the message is flushed to the server + while True: + yield WAIT_W + f = pgconn.flush() + if f == 0: + break + + # Retrieve the final result of copy + (result,) = yield from _fetch_many(pgconn) + if result.status != COMMAND_OK: + encoding = pgconn_encoding(pgconn) + raise e.error_from_result(result, encoding=encoding) + + return result + + +# Override functions with fast versions if available +if _psycopg: + connect = _psycopg.connect + execute = _psycopg.execute + send = _psycopg.send + fetch_many = _psycopg.fetch_many + fetch = _psycopg.fetch + pipeline_communicate = _psycopg.pipeline_communicate + +else: + connect = _connect + execute = _execute + send = _send + fetch_many = _fetch_many + fetch = _fetch + pipeline_communicate = _pipeline_communicate diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py new file mode 100644 index 0000000..792a9c8 --- /dev/null +++ b/psycopg/psycopg/postgres.py @@ -0,0 +1,125 @@ +""" +Types configuration specific to PostgreSQL. +""" + +# Copyright (C) 2020 The Psycopg Team + +from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry +from .abc import AdaptContext +from ._adapters_map import AdaptersMap + +# Global objects with PostgreSQL builtins and globally registered user types. +types = TypesRegistry() + +# Global adapter maps with PostgreSQL types configuration +adapters = AdaptersMap(types=types) + +# Use tools/update_oids.py to update this data. +for t in [ + TypeInfo('"char"', 18, 1002), + # autogenerated: start + # Generated from PostgreSQL 15.0 + TypeInfo("aclitem", 1033, 1034), + TypeInfo("bit", 1560, 1561), + TypeInfo("bool", 16, 1000, regtype="boolean"), + TypeInfo("box", 603, 1020, delimiter=";"), + TypeInfo("bpchar", 1042, 1014, regtype="character"), + TypeInfo("bytea", 17, 1001), + TypeInfo("cid", 29, 1012), + TypeInfo("cidr", 650, 651), + TypeInfo("circle", 718, 719), + TypeInfo("date", 1082, 1182), + TypeInfo("float4", 700, 1021, regtype="real"), + TypeInfo("float8", 701, 1022, regtype="double precision"), + TypeInfo("gtsvector", 3642, 3644), + TypeInfo("inet", 869, 1041), + TypeInfo("int2", 21, 1005, regtype="smallint"), + TypeInfo("int2vector", 22, 1006), + TypeInfo("int4", 23, 1007, regtype="integer"), + TypeInfo("int8", 20, 1016, regtype="bigint"), + TypeInfo("interval", 1186, 1187), + TypeInfo("json", 114, 199), + TypeInfo("jsonb", 3802, 3807), + TypeInfo("jsonpath", 4072, 4073), + TypeInfo("line", 628, 629), + TypeInfo("lseg", 601, 1018), + TypeInfo("macaddr", 829, 1040), + TypeInfo("macaddr8", 774, 775), + TypeInfo("money", 790, 791), + TypeInfo("name", 19, 1003), + TypeInfo("numeric", 1700, 1231), + TypeInfo("oid", 26, 1028), + TypeInfo("oidvector", 30, 1013), + TypeInfo("path", 602, 1019), + TypeInfo("pg_lsn", 3220, 3221), + TypeInfo("point", 600, 1017), + TypeInfo("polygon", 604, 1027), + TypeInfo("record", 2249, 2287), + TypeInfo("refcursor", 1790, 2201), + TypeInfo("regclass", 2205, 2210), + TypeInfo("regcollation", 4191, 4192), + TypeInfo("regconfig", 3734, 3735), + TypeInfo("regdictionary", 3769, 3770), + TypeInfo("regnamespace", 4089, 4090), + TypeInfo("regoper", 2203, 2208), + TypeInfo("regoperator", 2204, 2209), + TypeInfo("regproc", 24, 1008), + TypeInfo("regprocedure", 2202, 2207), + TypeInfo("regrole", 4096, 4097), + 
TypeInfo("regtype", 2206, 2211), + TypeInfo("text", 25, 1009), + TypeInfo("tid", 27, 1010), + TypeInfo("time", 1083, 1183, regtype="time without time zone"), + TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"), + TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"), + TypeInfo("timetz", 1266, 1270, regtype="time with time zone"), + TypeInfo("tsquery", 3615, 3645), + TypeInfo("tsvector", 3614, 3643), + TypeInfo("txid_snapshot", 2970, 2949), + TypeInfo("uuid", 2950, 2951), + TypeInfo("varbit", 1562, 1563, regtype="bit varying"), + TypeInfo("varchar", 1043, 1015, regtype="character varying"), + TypeInfo("xid", 28, 1011), + TypeInfo("xid8", 5069, 271), + TypeInfo("xml", 142, 143), + RangeInfo("daterange", 3912, 3913, subtype_oid=1082), + RangeInfo("int4range", 3904, 3905, subtype_oid=23), + RangeInfo("int8range", 3926, 3927, subtype_oid=20), + RangeInfo("numrange", 3906, 3907, subtype_oid=1700), + RangeInfo("tsrange", 3908, 3909, subtype_oid=1114), + RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184), + MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082), + MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23), + MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20), + MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700), + MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114), + MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184), + # autogenerated: end +]: + types.add(t) + + +# A few oids used a bit everywhere +INVALID_OID = 0 +TEXT_OID = types["text"].oid +TEXT_ARRAY_OID = types["text"].array_oid + + +def register_default_adapters(context: AdaptContext) -> None: + + from .types import array, bool, composite, datetime, enum, json, multirange + from .types import net, none, numeric, range, string, uuid + + array.register_default_adapters(context) + bool.register_default_adapters(context) + composite.register_default_adapters(context) + datetime.register_default_adapters(context) + enum.register_default_adapters(context) + json.register_default_adapters(context) + multirange.register_default_adapters(context) + net.register_default_adapters(context) + none.register_default_adapters(context) + numeric.register_default_adapters(context) + range.register_default_adapters(context) + string.register_default_adapters(context) + uuid.register_default_adapters(context) diff --git a/psycopg/psycopg/pq/__init__.py b/psycopg/psycopg/pq/__init__.py new file mode 100644 index 0000000..d5180b1 --- /dev/null +++ b/psycopg/psycopg/pq/__init__.py @@ -0,0 +1,133 @@ +""" +psycopg libpq wrapper + +This package exposes the libpq functionalities as Python objects and functions. + +The real implementation (the binding to the C library) is +implementation-dependant but all the implementations share the same interface. +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import logging +from typing import Callable, List, Type + +from . import abc +from .misc import ConninfoOption, PGnotify, PGresAttDesc +from .misc import error_message +from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace +from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus + +logger = logging.getLogger(__name__) + +__impl__: str +"""The currently loaded implementation of the `!psycopg.pq` package. + +Possible values include ``python``, ``c``, ``binary``. 
+""" + +__build_version__: int +"""The libpq version the C package was built with. + +A number in the same format of `~psycopg.ConnectionInfo.server_version` +representing the libpq used to build the speedup module (``c``, ``binary``) if +available. + +Certain features might not be available if the built version is too old. +""" + +version: Callable[[], int] +PGconn: Type[abc.PGconn] +PGresult: Type[abc.PGresult] +Conninfo: Type[abc.Conninfo] +Escaping: Type[abc.Escaping] +PGcancel: Type[abc.PGcancel] + + +def import_from_libpq() -> None: + """ + Import pq objects implementation from the best libpq wrapper available. + + If an implementation is requested try to import only it, otherwise + try to import the best implementation available. + """ + # import these names into the module on success as side effect + global __impl__, version, __build_version__ + global PGconn, PGresult, Conninfo, Escaping, PGcancel + + impl = os.environ.get("PSYCOPG_IMPL", "").lower() + module = None + attempts: List[str] = [] + + def handle_error(name: str, e: Exception) -> None: + if not impl: + msg = f"couldn't import psycopg '{name}' implementation: {e}" + logger.debug(msg) + attempts.append(msg) + else: + msg = f"couldn't import requested psycopg '{name}' implementation: {e}" + raise ImportError(msg) from e + + # The best implementation: fast but requires the system libpq installed + if not impl or impl == "c": + try: + from psycopg_c import pq as module # type: ignore + except Exception as e: + handle_error("c", e) + + # Second best implementation: fast and stand-alone + if not module and (not impl or impl == "binary"): + try: + from psycopg_binary import pq as module # type: ignore + except Exception as e: + handle_error("binary", e) + + # Pure Python implementation, slow and requires the system libpq installed. + if not module and (not impl or impl == "python"): + try: + from . import pq_ctypes as module # type: ignore[no-redef] + except Exception as e: + handle_error("python", e) + + if module: + __impl__ = module.__impl__ + version = module.version + PGconn = module.PGconn + PGresult = module.PGresult + Conninfo = module.Conninfo + Escaping = module.Escaping + PGcancel = module.PGcancel + __build_version__ = module.__build_version__ + elif impl: + raise ImportError(f"requested psycopg implementation '{impl}' unknown") + else: + sattempts = "\n".join(f"- {attempt}" for attempt in attempts) + raise ImportError( + f"""\ +no pq wrapper available. +Attempts made: +{sattempts}""" + ) + + +import_from_libpq() + +__all__ = ( + "ConnStatus", + "PipelineStatus", + "PollingStatus", + "TransactionStatus", + "ExecStatus", + "Ping", + "DiagnosticField", + "Format", + "Trace", + "PGconn", + "PGnotify", + "Conninfo", + "PGresAttDesc", + "error_message", + "ConninfoOption", + "version", +) diff --git a/psycopg/psycopg/pq/_debug.py b/psycopg/psycopg/pq/_debug.py new file mode 100644 index 0000000..f35d09f --- /dev/null +++ b/psycopg/psycopg/pq/_debug.py @@ -0,0 +1,106 @@ +""" +libpq debugging tools + +These functionalities are exposed here for convenience, but are not part of +the public interface and are subject to change at any moment. 
+ +Suggested usage:: + + import logging + import psycopg + from psycopg import pq + from psycopg.pq._debug import PGconnDebug + + logging.basicConfig(level=logging.INFO, format="%(message)s") + logger = logging.getLogger("psycopg.debug") + logger.setLevel(logging.INFO) + + assert pq.__impl__ == "python" + pq.PGconn = PGconnDebug + + with psycopg.connect("") as conn: + conn.pgconn.trace(2) + conn.pgconn.set_trace_flags( + pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + ... + +""" + +# Copyright (C) 2022 The Psycopg Team + +import inspect +import logging +from typing import Any, Callable, Type, TypeVar, TYPE_CHECKING +from functools import wraps + +from . import PGconn +from .misc import connection_summary + +if TYPE_CHECKING: + from . import abc + +Func = TypeVar("Func", bound=Callable[..., Any]) + +logger = logging.getLogger("psycopg.debug") + + +class PGconnDebug: + """Wrapper for a PQconn logging all its access.""" + + _Self = TypeVar("_Self", bound="PGconnDebug") + _pgconn: "abc.PGconn" + + def __init__(self, pgconn: "abc.PGconn"): + super().__setattr__("_pgconn", pgconn) + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self._pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + def __getattr__(self, attr: str) -> Any: + value = getattr(self._pgconn, attr) + if callable(value): + return debugging(value) + else: + logger.info("PGconn.%s -> %s", attr, value) + return value + + def __setattr__(self, attr: str, value: Any) -> None: + setattr(self._pgconn, attr, value) + logger.info("PGconn.%s <- %s", attr, value) + + @classmethod + def connect(cls: Type[_Self], conninfo: bytes) -> _Self: + return cls(debugging(PGconn.connect)(conninfo)) + + @classmethod + def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self: + return cls(debugging(PGconn.connect_start)(conninfo)) + + @classmethod + def ping(self, conninfo: bytes) -> int: + return debugging(PGconn.ping)(conninfo) + + +def debugging(f: Func) -> Func: + """Wrap a function in order to log its arguments and return value on call.""" + + @wraps(f) + def debugging_(*args: Any, **kwargs: Any) -> Any: + reprs = [] + for arg in args: + reprs.append(f"{arg!r}") + for (k, v) in kwargs.items(): + reprs.append(f"{k}={v!r}") + + logger.info("PGconn.%s(%s)", f.__name__, ", ".join(reprs)) + rv = f(*args, **kwargs) + # Display the return value only if the function is declared to return + # something else than None. + ra = inspect.signature(f).return_annotation + if ra is not None or rv is not None: + logger.info(" <- %r", rv) + return rv + + return debugging_ # type: ignore diff --git a/psycopg/psycopg/pq/_enums.py b/psycopg/psycopg/pq/_enums.py new file mode 100644 index 0000000..e0d4018 --- /dev/null +++ b/psycopg/psycopg/pq/_enums.py @@ -0,0 +1,249 @@ +""" +libpq enum definitions for psycopg +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import IntEnum, IntFlag, auto + + +class ConnStatus(IntEnum): + """ + Current status of the connection. + """ + + __module__ = "psycopg.pq" + + OK = 0 + """The connection is in a working state.""" + BAD = auto() + """The connection is closed.""" + + STARTED = auto() + MADE = auto() + AWAITING_RESPONSE = auto() + AUTH_OK = auto() + SETENV = auto() + SSL_STARTUP = auto() + NEEDED = auto() + CHECK_WRITABLE = auto() + CONSUME = auto() + GSS_STARTUP = auto() + CHECK_TARGET = auto() + CHECK_STANDBY = auto() + + +class PollingStatus(IntEnum): + """ + The status of the socket during a connection. 
+ + If ``READING`` or ``WRITING`` you may select before polling again. + """ + + __module__ = "psycopg.pq" + + FAILED = 0 + """Connection attempt failed.""" + READING = auto() + """Will have to wait before reading new data.""" + WRITING = auto() + """Will have to wait before writing new data.""" + OK = auto() + """Connection completed.""" + + ACTIVE = auto() + + +class ExecStatus(IntEnum): + """ + The status of a command. + """ + + __module__ = "psycopg.pq" + + EMPTY_QUERY = 0 + """The string sent to the server was empty.""" + + COMMAND_OK = auto() + """Successful completion of a command returning no data.""" + + TUPLES_OK = auto() + """ + Successful completion of a command returning data (such as a SELECT or SHOW). + """ + + COPY_OUT = auto() + """Copy Out (from server) data transfer started.""" + + COPY_IN = auto() + """Copy In (to server) data transfer started.""" + + BAD_RESPONSE = auto() + """The server's response was not understood.""" + + NONFATAL_ERROR = auto() + """A nonfatal error (a notice or warning) occurred.""" + + FATAL_ERROR = auto() + """A fatal error occurred.""" + + COPY_BOTH = auto() + """ + Copy In/Out (to and from server) data transfer started. + + This feature is currently used only for streaming replication, so this + status should not occur in ordinary applications. + """ + + SINGLE_TUPLE = auto() + """ + The PGresult contains a single result tuple from the current command. + + This status occurs only when single-row mode has been selected for the + query. + """ + + PIPELINE_SYNC = auto() + """ + The PGresult represents a synchronization point in pipeline mode, + requested by PQpipelineSync. + + This status occurs only when pipeline mode has been selected. + """ + + PIPELINE_ABORTED = auto() + """ + The PGresult represents a pipeline that has received an error from the server. + + PQgetResult must be called repeatedly, and each time it will return this + status code until the end of the current pipeline, at which point it will + return PGRES_PIPELINE_SYNC and normal processing can resume. + """ + + +class TransactionStatus(IntEnum): + """ + The transaction status of a connection. + """ + + __module__ = "psycopg.pq" + + IDLE = 0 + """Connection ready, no transaction active.""" + + ACTIVE = auto() + """A command is in progress.""" + + INTRANS = auto() + """Connection idle in an open transaction.""" + + INERROR = auto() + """An error happened in the current transaction.""" + + UNKNOWN = auto() + """Unknown connection state, broken connection.""" + + +class Ping(IntEnum): + """Response from a ping attempt.""" + + __module__ = "psycopg.pq" + + OK = 0 + """ + The server is running and appears to be accepting connections. + """ + + REJECT = auto() + """ + The server is running but is in a state that disallows connections. + """ + + NO_RESPONSE = auto() + """ + The server could not be contacted. + """ + + NO_ATTEMPT = auto() + """ + No attempt was made to contact the server. + """ + + +class PipelineStatus(IntEnum): + """Pipeline mode status of the libpq connection.""" + + __module__ = "psycopg.pq" + + OFF = 0 + """ + The libpq connection is *not* in pipeline mode. + """ + ON = auto() + """ + The libpq connection is in pipeline mode. + """ + ABORTED = auto() + """ + The libpq connection is in pipeline mode and an error occurred while + processing the current pipeline. The aborted flag is cleared when + PQgetResult returns a result of type PGRES_PIPELINE_SYNC. + """ + + +class DiagnosticField(IntEnum): + """ + Fields in an error report. 
+ """ + + __module__ = "psycopg.pq" + + # from postgres_ext.h + SEVERITY = ord("S") + SEVERITY_NONLOCALIZED = ord("V") + SQLSTATE = ord("C") + MESSAGE_PRIMARY = ord("M") + MESSAGE_DETAIL = ord("D") + MESSAGE_HINT = ord("H") + STATEMENT_POSITION = ord("P") + INTERNAL_POSITION = ord("p") + INTERNAL_QUERY = ord("q") + CONTEXT = ord("W") + SCHEMA_NAME = ord("s") + TABLE_NAME = ord("t") + COLUMN_NAME = ord("c") + DATATYPE_NAME = ord("d") + CONSTRAINT_NAME = ord("n") + SOURCE_FILE = ord("F") + SOURCE_LINE = ord("L") + SOURCE_FUNCTION = ord("R") + + +class Format(IntEnum): + """ + Enum representing the format of a query argument or return value. + + These values are only the ones managed by the libpq. `~psycopg` may also + support automatically-chosen values: see `psycopg.adapt.PyFormat`. + """ + + __module__ = "psycopg.pq" + + TEXT = 0 + """Text parameter.""" + BINARY = 1 + """Binary parameter.""" + + +class Trace(IntFlag): + """ + Enum to control tracing of the client/server communication. + """ + + __module__ = "psycopg.pq" + + SUPPRESS_TIMESTAMPS = 1 + """Do not include timestamps in messages.""" + + REGRESS_MODE = 2 + """Redact some fields, e.g. OIDs, from messages.""" diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py new file mode 100644 index 0000000..9ca1d12 --- /dev/null +++ b/psycopg/psycopg/pq/_pq_ctypes.py @@ -0,0 +1,804 @@ +""" +libpq access using ctypes +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import ctypes +import ctypes.util +from ctypes import Structure, CFUNCTYPE, POINTER +from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p +from typing import List, Optional, Tuple + +from .misc import find_libpq_full_path +from ..errors import NotSupportedError + +libname = find_libpq_full_path() +if not libname: + raise ImportError("libpq library not found") + +pq = ctypes.cdll.LoadLibrary(libname) + + +class FILE(Structure): + pass + + +FILE_ptr = POINTER(FILE) + +if sys.platform == "linux": + libcname = ctypes.util.find_library("c") + assert libcname + libc = ctypes.cdll.LoadLibrary(libcname) + + fdopen = libc.fdopen + fdopen.argtypes = (c_int, c_char_p) + fdopen.restype = FILE_ptr + + +# Get the libpq version to define what functions are available. 
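The version retrieved just below is used to gate bindings that only exist in newer libpq releases (PQhostaddr, PQsetTraceFlags, the pipeline functions, ...). These raw bindings are normally reached through the ``PGconn``/``PGresult`` wrappers in ``pq_ctypes.py``, but they can also be exercised directly, which can be handy to sanity-check the ctypes layer; a sketch, with a placeholder conninfo string:

    from psycopg.pq._pq_ctypes import (
        PQconnectdb, PQerrorMessage, PQfinish, PQstatus)

    pgconn = PQconnectdb(b"dbname=test")
    if PQstatus(pgconn) != 0:                  # 0 == ConnStatus.OK
        print(PQerrorMessage(pgconn).decode())
    PQfinish(pgconn)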
+ +PQlibVersion = pq.PQlibVersion +PQlibVersion.argtypes = [] +PQlibVersion.restype = c_int + +libpq_version = PQlibVersion() + + +# libpq data types + + +Oid = c_uint + + +class PGconn_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PGresult_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PQconninfoOption_struct(Structure): + _fields_ = [ + ("keyword", c_char_p), + ("envvar", c_char_p), + ("compiled", c_char_p), + ("val", c_char_p), + ("label", c_char_p), + ("dispchar", c_char_p), + ("dispsize", c_int), + ] + + +class PGnotify_struct(Structure): + _fields_ = [ + ("relname", c_char_p), + ("be_pid", c_int), + ("extra", c_char_p), + ] + + +class PGcancel_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PGresAttDesc_struct(Structure): + _fields_ = [ + ("name", c_char_p), + ("tableid", Oid), + ("columnid", c_int), + ("format", c_int), + ("typid", Oid), + ("typlen", c_int), + ("atttypmod", c_int), + ] + + +PGconn_ptr = POINTER(PGconn_struct) +PGresult_ptr = POINTER(PGresult_struct) +PQconninfoOption_ptr = POINTER(PQconninfoOption_struct) +PGnotify_ptr = POINTER(PGnotify_struct) +PGcancel_ptr = POINTER(PGcancel_struct) +PGresAttDesc_ptr = POINTER(PGresAttDesc_struct) + + +# Function definitions as explained in PostgreSQL 12 documentation + +# 33.1. Database Connection Control Functions + +# PQconnectdbParams: doesn't seem useful, won't wrap for now + +PQconnectdb = pq.PQconnectdb +PQconnectdb.argtypes = [c_char_p] +PQconnectdb.restype = PGconn_ptr + +# PQsetdbLogin: not useful +# PQsetdb: not useful + +# PQconnectStartParams: not useful + +PQconnectStart = pq.PQconnectStart +PQconnectStart.argtypes = [c_char_p] +PQconnectStart.restype = PGconn_ptr + +PQconnectPoll = pq.PQconnectPoll +PQconnectPoll.argtypes = [PGconn_ptr] +PQconnectPoll.restype = c_int + +PQconndefaults = pq.PQconndefaults +PQconndefaults.argtypes = [] +PQconndefaults.restype = PQconninfoOption_ptr + +PQconninfoFree = pq.PQconninfoFree +PQconninfoFree.argtypes = [PQconninfoOption_ptr] +PQconninfoFree.restype = None + +PQconninfo = pq.PQconninfo +PQconninfo.argtypes = [PGconn_ptr] +PQconninfo.restype = PQconninfoOption_ptr + +PQconninfoParse = pq.PQconninfoParse +PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)] +PQconninfoParse.restype = PQconninfoOption_ptr + +PQfinish = pq.PQfinish +PQfinish.argtypes = [PGconn_ptr] +PQfinish.restype = None + +PQreset = pq.PQreset +PQreset.argtypes = [PGconn_ptr] +PQreset.restype = None + +PQresetStart = pq.PQresetStart +PQresetStart.argtypes = [PGconn_ptr] +PQresetStart.restype = c_int + +PQresetPoll = pq.PQresetPoll +PQresetPoll.argtypes = [PGconn_ptr] +PQresetPoll.restype = c_int + +PQping = pq.PQping +PQping.argtypes = [c_char_p] +PQping.restype = c_int + + +# 33.2. 
Connection Status Functions + +PQdb = pq.PQdb +PQdb.argtypes = [PGconn_ptr] +PQdb.restype = c_char_p + +PQuser = pq.PQuser +PQuser.argtypes = [PGconn_ptr] +PQuser.restype = c_char_p + +PQpass = pq.PQpass +PQpass.argtypes = [PGconn_ptr] +PQpass.restype = c_char_p + +PQhost = pq.PQhost +PQhost.argtypes = [PGconn_ptr] +PQhost.restype = c_char_p + +_PQhostaddr = None + +if libpq_version >= 120000: + _PQhostaddr = pq.PQhostaddr + _PQhostaddr.argtypes = [PGconn_ptr] + _PQhostaddr.restype = c_char_p + + +def PQhostaddr(pgconn: PGconn_struct) -> bytes: + if not _PQhostaddr: + raise NotSupportedError( + "PQhostaddr requires libpq from PostgreSQL 12," + f" {libpq_version} available instead" + ) + + return _PQhostaddr(pgconn) + + +PQport = pq.PQport +PQport.argtypes = [PGconn_ptr] +PQport.restype = c_char_p + +PQtty = pq.PQtty +PQtty.argtypes = [PGconn_ptr] +PQtty.restype = c_char_p + +PQoptions = pq.PQoptions +PQoptions.argtypes = [PGconn_ptr] +PQoptions.restype = c_char_p + +PQstatus = pq.PQstatus +PQstatus.argtypes = [PGconn_ptr] +PQstatus.restype = c_int + +PQtransactionStatus = pq.PQtransactionStatus +PQtransactionStatus.argtypes = [PGconn_ptr] +PQtransactionStatus.restype = c_int + +PQparameterStatus = pq.PQparameterStatus +PQparameterStatus.argtypes = [PGconn_ptr, c_char_p] +PQparameterStatus.restype = c_char_p + +PQprotocolVersion = pq.PQprotocolVersion +PQprotocolVersion.argtypes = [PGconn_ptr] +PQprotocolVersion.restype = c_int + +PQserverVersion = pq.PQserverVersion +PQserverVersion.argtypes = [PGconn_ptr] +PQserverVersion.restype = c_int + +PQerrorMessage = pq.PQerrorMessage +PQerrorMessage.argtypes = [PGconn_ptr] +PQerrorMessage.restype = c_char_p + +PQsocket = pq.PQsocket +PQsocket.argtypes = [PGconn_ptr] +PQsocket.restype = c_int + +PQbackendPID = pq.PQbackendPID +PQbackendPID.argtypes = [PGconn_ptr] +PQbackendPID.restype = c_int + +PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword +PQconnectionNeedsPassword.argtypes = [PGconn_ptr] +PQconnectionNeedsPassword.restype = c_int + +PQconnectionUsedPassword = pq.PQconnectionUsedPassword +PQconnectionUsedPassword.argtypes = [PGconn_ptr] +PQconnectionUsedPassword.restype = c_int + +PQsslInUse = pq.PQsslInUse +PQsslInUse.argtypes = [PGconn_ptr] +PQsslInUse.restype = c_int + +# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl + + +# 33.3. 
Command Execution Functions + +PQexec = pq.PQexec +PQexec.argtypes = [PGconn_ptr, c_char_p] +PQexec.restype = PGresult_ptr + +PQexecParams = pq.PQexecParams +PQexecParams.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(Oid), + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQexecParams.restype = PGresult_ptr + +PQprepare = pq.PQprepare +PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)] +PQprepare.restype = PGresult_ptr + +PQexecPrepared = pq.PQexecPrepared +PQexecPrepared.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQexecPrepared.restype = PGresult_ptr + +PQdescribePrepared = pq.PQdescribePrepared +PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p] +PQdescribePrepared.restype = PGresult_ptr + +PQdescribePortal = pq.PQdescribePortal +PQdescribePortal.argtypes = [PGconn_ptr, c_char_p] +PQdescribePortal.restype = PGresult_ptr + +PQresultStatus = pq.PQresultStatus +PQresultStatus.argtypes = [PGresult_ptr] +PQresultStatus.restype = c_int + +# PQresStatus: not needed, we have pretty enums + +PQresultErrorMessage = pq.PQresultErrorMessage +PQresultErrorMessage.argtypes = [PGresult_ptr] +PQresultErrorMessage.restype = c_char_p + +# TODO: PQresultVerboseErrorMessage + +PQresultErrorField = pq.PQresultErrorField +PQresultErrorField.argtypes = [PGresult_ptr, c_int] +PQresultErrorField.restype = c_char_p + +PQclear = pq.PQclear +PQclear.argtypes = [PGresult_ptr] +PQclear.restype = None + + +# 33.3.2. Retrieving Query Result Information + +PQntuples = pq.PQntuples +PQntuples.argtypes = [PGresult_ptr] +PQntuples.restype = c_int + +PQnfields = pq.PQnfields +PQnfields.argtypes = [PGresult_ptr] +PQnfields.restype = c_int + +PQfname = pq.PQfname +PQfname.argtypes = [PGresult_ptr, c_int] +PQfname.restype = c_char_p + +# PQfnumber: useless and hard to use + +PQftable = pq.PQftable +PQftable.argtypes = [PGresult_ptr, c_int] +PQftable.restype = Oid + +PQftablecol = pq.PQftablecol +PQftablecol.argtypes = [PGresult_ptr, c_int] +PQftablecol.restype = c_int + +PQfformat = pq.PQfformat +PQfformat.argtypes = [PGresult_ptr, c_int] +PQfformat.restype = c_int + +PQftype = pq.PQftype +PQftype.argtypes = [PGresult_ptr, c_int] +PQftype.restype = Oid + +PQfmod = pq.PQfmod +PQfmod.argtypes = [PGresult_ptr, c_int] +PQfmod.restype = c_int + +PQfsize = pq.PQfsize +PQfsize.argtypes = [PGresult_ptr, c_int] +PQfsize.restype = c_int + +PQbinaryTuples = pq.PQbinaryTuples +PQbinaryTuples.argtypes = [PGresult_ptr] +PQbinaryTuples.restype = c_int + +PQgetvalue = pq.PQgetvalue +PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int] +PQgetvalue.restype = POINTER(c_char) # not a null-terminated string + +PQgetisnull = pq.PQgetisnull +PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int] +PQgetisnull.restype = c_int + +PQgetlength = pq.PQgetlength +PQgetlength.argtypes = [PGresult_ptr, c_int, c_int] +PQgetlength.restype = c_int + +PQnparams = pq.PQnparams +PQnparams.argtypes = [PGresult_ptr] +PQnparams.restype = c_int + +PQparamtype = pq.PQparamtype +PQparamtype.argtypes = [PGresult_ptr, c_int] +PQparamtype.restype = Oid + +# PQprint: pretty useless + +# 33.3.3. Retrieving Other Result Information + +PQcmdStatus = pq.PQcmdStatus +PQcmdStatus.argtypes = [PGresult_ptr] +PQcmdStatus.restype = c_char_p + +PQcmdTuples = pq.PQcmdTuples +PQcmdTuples.argtypes = [PGresult_ptr] +PQcmdTuples.restype = c_char_p + +PQoidValue = pq.PQoidValue +PQoidValue.argtypes = [PGresult_ptr] +PQoidValue.restype = Oid + + +# 33.3.4. 
Escaping Strings for Inclusion in SQL Commands + +PQescapeLiteral = pq.PQescapeLiteral +PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t] +PQescapeLiteral.restype = POINTER(c_char) + +PQescapeIdentifier = pq.PQescapeIdentifier +PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t] +PQescapeIdentifier.restype = POINTER(c_char) + +PQescapeStringConn = pq.PQescapeStringConn +# TODO: raises "wrong type" error +# PQescapeStringConn.argtypes = [ +# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int) +# ] +PQescapeStringConn.restype = c_size_t + +PQescapeString = pq.PQescapeString +# TODO: raises "wrong type" error +# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t] +PQescapeString.restype = c_size_t + +PQescapeByteaConn = pq.PQescapeByteaConn +PQescapeByteaConn.argtypes = [ + PGconn_ptr, + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + c_size_t, + POINTER(c_size_t), +] +PQescapeByteaConn.restype = POINTER(c_ubyte) + +PQescapeBytea = pq.PQescapeBytea +PQescapeBytea.argtypes = [ + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + c_size_t, + POINTER(c_size_t), +] +PQescapeBytea.restype = POINTER(c_ubyte) + + +PQunescapeBytea = pq.PQunescapeBytea +PQunescapeBytea.argtypes = [ + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + POINTER(c_size_t), +] +PQunescapeBytea.restype = POINTER(c_ubyte) + + +# 33.4. Asynchronous Command Processing + +PQsendQuery = pq.PQsendQuery +PQsendQuery.argtypes = [PGconn_ptr, c_char_p] +PQsendQuery.restype = c_int + +PQsendQueryParams = pq.PQsendQueryParams +PQsendQueryParams.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(Oid), + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQsendQueryParams.restype = c_int + +PQsendPrepare = pq.PQsendPrepare +PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)] +PQsendPrepare.restype = c_int + +PQsendQueryPrepared = pq.PQsendQueryPrepared +PQsendQueryPrepared.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQsendQueryPrepared.restype = c_int + +PQsendDescribePrepared = pq.PQsendDescribePrepared +PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p] +PQsendDescribePrepared.restype = c_int + +PQsendDescribePortal = pq.PQsendDescribePortal +PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p] +PQsendDescribePortal.restype = c_int + +PQgetResult = pq.PQgetResult +PQgetResult.argtypes = [PGconn_ptr] +PQgetResult.restype = PGresult_ptr + +PQconsumeInput = pq.PQconsumeInput +PQconsumeInput.argtypes = [PGconn_ptr] +PQconsumeInput.restype = c_int + +PQisBusy = pq.PQisBusy +PQisBusy.argtypes = [PGconn_ptr] +PQisBusy.restype = c_int + +PQsetnonblocking = pq.PQsetnonblocking +PQsetnonblocking.argtypes = [PGconn_ptr, c_int] +PQsetnonblocking.restype = c_int + +PQisnonblocking = pq.PQisnonblocking +PQisnonblocking.argtypes = [PGconn_ptr] +PQisnonblocking.restype = c_int + +PQflush = pq.PQflush +PQflush.argtypes = [PGconn_ptr] +PQflush.restype = c_int + + +# 33.5. Retrieving Query Results Row-by-Row +PQsetSingleRowMode = pq.PQsetSingleRowMode +PQsetSingleRowMode.argtypes = [PGconn_ptr] +PQsetSingleRowMode.restype = c_int + + +# 33.6. 
Canceling Queries in Progress + +PQgetCancel = pq.PQgetCancel +PQgetCancel.argtypes = [PGconn_ptr] +PQgetCancel.restype = PGcancel_ptr + +PQfreeCancel = pq.PQfreeCancel +PQfreeCancel.argtypes = [PGcancel_ptr] +PQfreeCancel.restype = None + +PQcancel = pq.PQcancel +# TODO: raises "wrong type" error +# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int] +PQcancel.restype = c_int + + +# 33.8. Asynchronous Notification + +PQnotifies = pq.PQnotifies +PQnotifies.argtypes = [PGconn_ptr] +PQnotifies.restype = PGnotify_ptr + + +# 33.9. Functions Associated with the COPY Command + +PQputCopyData = pq.PQputCopyData +PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int] +PQputCopyData.restype = c_int + +PQputCopyEnd = pq.PQputCopyEnd +PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p] +PQputCopyEnd.restype = c_int + +PQgetCopyData = pq.PQgetCopyData +PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int] +PQgetCopyData.restype = c_int + + +# 33.10. Control Functions + +PQtrace = pq.PQtrace +PQtrace.argtypes = [PGconn_ptr, FILE_ptr] +PQtrace.restype = None + +_PQsetTraceFlags = None + +if libpq_version >= 140000: + _PQsetTraceFlags = pq.PQsetTraceFlags + _PQsetTraceFlags.argtypes = [PGconn_ptr, c_int] + _PQsetTraceFlags.restype = None + + +def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None: + if not _PQsetTraceFlags: + raise NotSupportedError( + "PQsetTraceFlags requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + + _PQsetTraceFlags(pgconn, flags) + + +PQuntrace = pq.PQuntrace +PQuntrace.argtypes = [PGconn_ptr] +PQuntrace.restype = None + +# 33.11. Miscellaneous Functions + +PQfreemem = pq.PQfreemem +PQfreemem.argtypes = [c_void_p] +PQfreemem.restype = None + +if libpq_version >= 100000: + _PQencryptPasswordConn = pq.PQencryptPasswordConn + _PQencryptPasswordConn.argtypes = [ + PGconn_ptr, + c_char_p, + c_char_p, + c_char_p, + ] + _PQencryptPasswordConn.restype = POINTER(c_char) + + +def PQencryptPasswordConn( + pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes +) -> Optional[bytes]: + if not _PQencryptPasswordConn: + raise NotSupportedError( + "PQencryptPasswordConn requires libpq from PostgreSQL 10," + f" {libpq_version} available instead" + ) + + return _PQencryptPasswordConn(pgconn, passwd, user, algorithm) + + +PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult +PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int] +PQmakeEmptyPGresult.restype = PGresult_ptr + +PQsetResultAttrs = pq.PQsetResultAttrs +PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr] +PQsetResultAttrs.restype = c_int + + +# 33.12. 
Notice Processing + +PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr) + +PQsetNoticeReceiver = pq.PQsetNoticeReceiver +PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p] +PQsetNoticeReceiver.restype = PQnoticeReceiver + +# 34.5 Pipeline Mode + +_PQpipelineStatus = None +_PQenterPipelineMode = None +_PQexitPipelineMode = None +_PQpipelineSync = None +_PQsendFlushRequest = None + +if libpq_version >= 140000: + _PQpipelineStatus = pq.PQpipelineStatus + _PQpipelineStatus.argtypes = [PGconn_ptr] + _PQpipelineStatus.restype = c_int + + _PQenterPipelineMode = pq.PQenterPipelineMode + _PQenterPipelineMode.argtypes = [PGconn_ptr] + _PQenterPipelineMode.restype = c_int + + _PQexitPipelineMode = pq.PQexitPipelineMode + _PQexitPipelineMode.argtypes = [PGconn_ptr] + _PQexitPipelineMode.restype = c_int + + _PQpipelineSync = pq.PQpipelineSync + _PQpipelineSync.argtypes = [PGconn_ptr] + _PQpipelineSync.restype = c_int + + _PQsendFlushRequest = pq.PQsendFlushRequest + _PQsendFlushRequest.argtypes = [PGconn_ptr] + _PQsendFlushRequest.restype = c_int + + +def PQpipelineStatus(pgconn: PGconn_struct) -> int: + if not _PQpipelineStatus: + raise NotSupportedError( + "PQpipelineStatus requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQpipelineStatus(pgconn) + + +def PQenterPipelineMode(pgconn: PGconn_struct) -> int: + if not _PQenterPipelineMode: + raise NotSupportedError( + "PQenterPipelineMode requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQenterPipelineMode(pgconn) + + +def PQexitPipelineMode(pgconn: PGconn_struct) -> int: + if not _PQexitPipelineMode: + raise NotSupportedError( + "PQexitPipelineMode requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQexitPipelineMode(pgconn) + + +def PQpipelineSync(pgconn: PGconn_struct) -> int: + if not _PQpipelineSync: + raise NotSupportedError( + "PQpipelineSync requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQpipelineSync(pgconn) + + +def PQsendFlushRequest(pgconn: PGconn_struct) -> int: + if not _PQsendFlushRequest: + raise NotSupportedError( + "PQsendFlushRequest requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQsendFlushRequest(pgconn) + + +# 33.18. 
SSL Support + +PQinitOpenSSL = pq.PQinitOpenSSL +PQinitOpenSSL.argtypes = [c_int, c_int] +PQinitOpenSSL.restype = None + + +def generate_stub() -> None: + import re + from ctypes import _CFuncPtr # type: ignore + + def type2str(fname, narg, t): + if t is None: + return "None" + elif t is c_void_p: + return "Any" + elif t is c_int or t is c_uint or t is c_size_t: + return "int" + elif t is c_char_p or t.__name__ == "LP_c_char": + if narg is not None: + return "bytes" + else: + return "Optional[bytes]" + + elif t.__name__ in ( + "LP_PGconn_struct", + "LP_PGresult_struct", + "LP_PGcancel_struct", + ): + if narg is not None: + return f"Optional[{t.__name__[3:]}]" + else: + return t.__name__[3:] + + elif t.__name__ in ("LP_PQconninfoOption_struct",): + return f"Sequence[{t.__name__[3:]}]" + + elif t.__name__ in ( + "LP_c_ubyte", + "LP_c_char_p", + "LP_c_int", + "LP_c_uint", + "LP_c_ulong", + "LP_FILE", + ): + return f"_Pointer[{t.__name__[3:]}]" + + else: + assert False, f"can't deal with {t} in {fname}" + + fn = __file__ + "i" + with open(fn) as f: + lines = f.read().splitlines() + + istart, iend = ( + i + for i, line in enumerate(lines) + if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line) + ) + + known = { + line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ") + } + + signatures = [] + + for name, obj in globals().items(): + if name in known: + continue + if not isinstance(obj, _CFuncPtr): + continue + + params = [] + for i, t in enumerate(obj.argtypes): + params.append(f"arg{i + 1}: {type2str(name, i, t)}") + + resname = type2str(name, None, obj.restype) + + signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...") + + lines[istart + 1 : iend] = signatures + + with open(fn, "w") as f: + f.write("\n".join(lines)) + f.write("\n") + + +if __name__ == "__main__": + generate_stub() diff --git a/psycopg/psycopg/pq/_pq_ctypes.pyi b/psycopg/psycopg/pq/_pq_ctypes.pyi new file mode 100644 index 0000000..5d2ee3f --- /dev/null +++ b/psycopg/psycopg/pq/_pq_ctypes.pyi @@ -0,0 +1,216 @@ +""" +types stub for ctypes functions +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, Optional, Sequence +from ctypes import Array, pointer, _Pointer +from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong + +class FILE: ... + +def fdopen(fd: int, mode: bytes) -> _Pointer[FILE]: ... # type: ignore[type-var] + +Oid = c_uint + +class PGconn_struct: ... +class PGresult_struct: ... +class PGcancel_struct: ... + +class PQconninfoOption_struct: + keyword: bytes + envvar: bytes + compiled: bytes + val: bytes + label: bytes + dispchar: bytes + dispsize: int + +class PGnotify_struct: + be_pid: int + relname: bytes + extra: bytes + +class PGresAttDesc_struct: + name: bytes + tableid: int + columnid: int + format: int + typid: int + typlen: int + atttypmod: int + +def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQexecPrepared( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: int, + arg4: Optional[Array[c_char_p]], + arg5: Optional[Array[c_int]], + arg6: Optional[Array[c_int]], + arg7: int, +) -> PGresult_struct: ... +def PQprepare( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: int, + arg5: Optional[Array[c_uint]], +) -> PGresult_struct: ... +def PQgetvalue( + arg1: Optional[PGresult_struct], arg2: int, arg3: int +) -> _Pointer[c_char]: ... 
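The ``# autogenerated`` block further down in this stub is rewritten in place by ``generate_stub()`` from the ctypes module above (also invoked when that module is run as ``__main__``); assuming the ``psycopg`` package is importable, roughly:

    from psycopg.pq import _pq_ctypes

    _pq_ctypes.generate_stub()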
+def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQescapeStringConn( + arg1: Optional[PGconn_struct], + arg2: c_char_p, + arg3: bytes, + arg4: int, + arg5: _Pointer[c_int], +) -> int: ... +def PQescapeString(arg1: c_char_p, arg2: bytes, arg3: int) -> int: ... +def PQsendPrepare( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: int, + arg5: Optional[Array[c_uint]], +) -> int: ... +def PQsendQueryPrepared( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: int, + arg4: Optional[Array[c_char_p]], + arg5: Optional[Array[c_int]], + arg6: Optional[Array[c_int]], + arg7: int, +) -> int: ... +def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ... +def PQsetNoticeReceiver( + arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any +) -> Callable[[Any], PGresult_struct]: ... + +# TODO: Ignoring type as getting an error on mypy/ctypes: +# Type argument "psycopg.pq._pq_ctypes.PGnotify_struct" of "pointer" must be +# a subtype of "ctypes._CData" +def PQnotifies( + arg1: Optional[PGconn_struct], +) -> Optional[_Pointer[PGnotify_struct]]: ... # type: ignore +def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ... + +# Arg 2 is a _Pointer, reported as _CArgObject by mypy +def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ... +def PQsetResultAttrs( + arg1: Optional[PGresult_struct], + arg2: int, + arg3: Array[PGresAttDesc_struct], # type: ignore +) -> int: ... +def PQtrace( + arg1: Optional[PGconn_struct], + arg2: _Pointer[FILE], # type: ignore[type-var] +) -> None: ... +def PQencryptPasswordConn( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: Optional[bytes], +) -> bytes: ... +def PQpipelineStatus(pgconn: Optional[PGconn_struct]) -> int: ... +def PQenterPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ... +def PQexitPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ... +def PQpipelineSync(pgconn: Optional[PGconn_struct]) -> int: ... +def PQsendFlushRequest(pgconn: Optional[PGconn_struct]) -> int: ... + +# fmt: off +# autogenerated: start +def PQlibVersion() -> int: ... +def PQconnectdb(arg1: bytes) -> PGconn_struct: ... +def PQconnectStart(arg1: bytes) -> PGconn_struct: ... +def PQconnectPoll(arg1: Optional[PGconn_struct]) -> int: ... +def PQconndefaults() -> Sequence[PQconninfoOption_struct]: ... +def PQconninfoFree(arg1: Sequence[PQconninfoOption_struct]) -> None: ... +def PQconninfo(arg1: Optional[PGconn_struct]) -> Sequence[PQconninfoOption_struct]: ... +def PQconninfoParse(arg1: bytes, arg2: _Pointer[c_char_p]) -> Sequence[PQconninfoOption_struct]: ... +def PQfinish(arg1: Optional[PGconn_struct]) -> None: ... +def PQreset(arg1: Optional[PGconn_struct]) -> None: ... +def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ... +def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ... +def PQping(arg1: bytes) -> int: ... +def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQstatus(arg1: Optional[PGconn_struct]) -> int: ... 
+def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ... +def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ... +def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ... +def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ... +def PQsocket(arg1: Optional[PGconn_struct]) -> int: ... +def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ... +def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ... +def PQconnectionUsedPassword(arg1: Optional[PGconn_struct]) -> int: ... +def PQsslInUse(arg1: Optional[PGconn_struct]) -> int: ... +def PQexec(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> PGresult_struct: ... +def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ... +def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... +def PQclear(arg1: Optional[PGresult_struct]) -> None: ... +def PQntuples(arg1: Optional[PGresult_struct]) -> int: ... +def PQnfields(arg1: Optional[PGresult_struct]) -> int: ... +def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... +def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ... +def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... +def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... +def PQnparams(arg1: Optional[PGresult_struct]) -> int: ... +def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ... +def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ... +def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ... +def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ... +def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQescapeBytea(arg1: bytes, arg2: int, arg3: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQunescapeBytea(arg1: bytes, arg2: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> int: ... +def PQsendDescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQsendDescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ... +def PQconsumeInput(arg1: Optional[PGconn_struct]) -> int: ... 
+def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ... +def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ... +def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ... +def PQflush(arg1: Optional[PGconn_struct]) -> int: ... +def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ... +def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ... +def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ... +def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ... +def PQsetTraceFlags(arg1: Optional[PGconn_struct], arg2: int) -> None: ... +def PQuntrace(arg1: Optional[PGconn_struct]) -> None: ... +def PQfreemem(arg1: Any) -> None: ... +def _PQencryptPasswordConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: bytes, arg4: bytes) -> Optional[bytes]: ... +def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ... +def _PQpipelineStatus(arg1: Optional[PGconn_struct]) -> int: ... +def _PQenterPipelineMode(arg1: Optional[PGconn_struct]) -> int: ... +def _PQexitPipelineMode(arg1: Optional[PGconn_struct]) -> int: ... +def _PQpipelineSync(arg1: Optional[PGconn_struct]) -> int: ... +def _PQsendFlushRequest(arg1: Optional[PGconn_struct]) -> int: ... +def PQinitOpenSSL(arg1: int, arg2: int) -> None: ... +# autogenerated: end +# fmt: on + +# vim: set syntax=python: diff --git a/psycopg/psycopg/pq/abc.py b/psycopg/psycopg/pq/abc.py new file mode 100644 index 0000000..9c45f64 --- /dev/null +++ b/psycopg/psycopg/pq/abc.py @@ -0,0 +1,385 @@ +""" +Protocol objects to represent objects exposed by different pq implementations. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from ._enums import Format, Trace +from .._compat import Protocol + +if TYPE_CHECKING: + from .misc import PGnotify, ConninfoOption, PGresAttDesc + +# An object implementing the buffer protocol (ish) +Buffer: TypeAlias = Union[bytes, bytearray, memoryview] + + +class PGconn(Protocol): + + notice_handler: Optional[Callable[["PGresult"], None]] + notify_handler: Optional[Callable[["PGnotify"], None]] + + @classmethod + def connect(cls, conninfo: bytes) -> "PGconn": + ... + + @classmethod + def connect_start(cls, conninfo: bytes) -> "PGconn": + ... + + def connect_poll(self) -> int: + ... + + def finish(self) -> None: + ... + + @property + def info(self) -> List["ConninfoOption"]: + ... + + def reset(self) -> None: + ... + + def reset_start(self) -> None: + ... + + def reset_poll(self) -> int: + ... + + @classmethod + def ping(self, conninfo: bytes) -> int: + ... + + @property + def db(self) -> bytes: + ... + + @property + def user(self) -> bytes: + ... + + @property + def password(self) -> bytes: + ... + + @property + def host(self) -> bytes: + ... + + @property + def hostaddr(self) -> bytes: + ... + + @property + def port(self) -> bytes: + ... + + @property + def tty(self) -> bytes: + ... + + @property + def options(self) -> bytes: + ... + + @property + def status(self) -> int: + ... + + @property + def transaction_status(self) -> int: + ... + + def parameter_status(self, name: bytes) -> Optional[bytes]: + ... + + @property + def error_message(self) -> bytes: + ... + + @property + def server_version(self) -> int: + ... + + @property + def socket(self) -> int: + ... + + @property + def backend_pid(self) -> int: + ... + + @property + def needs_password(self) -> bool: + ... 
+ + @property + def used_password(self) -> bool: + ... + + @property + def ssl_in_use(self) -> bool: + ... + + def exec_(self, command: bytes) -> "PGresult": + ... + + def send_query(self, command: bytes) -> None: + ... + + def exec_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> "PGresult": + ... + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + ... + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + ... + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + ... + + def prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> "PGresult": + ... + + def exec_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Buffer]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = 0, + ) -> "PGresult": + ... + + def describe_prepared(self, name: bytes) -> "PGresult": + ... + + def send_describe_prepared(self, name: bytes) -> None: + ... + + def describe_portal(self, name: bytes) -> "PGresult": + ... + + def send_describe_portal(self, name: bytes) -> None: + ... + + def get_result(self) -> Optional["PGresult"]: + ... + + def consume_input(self) -> None: + ... + + def is_busy(self) -> int: + ... + + @property + def nonblocking(self) -> int: + ... + + @nonblocking.setter + def nonblocking(self, arg: int) -> None: + ... + + def flush(self) -> int: + ... + + def set_single_row_mode(self) -> None: + ... + + def get_cancel(self) -> "PGcancel": + ... + + def notifies(self) -> Optional["PGnotify"]: + ... + + def put_copy_data(self, buffer: Buffer) -> int: + ... + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + ... + + def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: + ... + + def trace(self, fileno: int) -> None: + ... + + def set_trace_flags(self, flags: Trace) -> None: + ... + + def untrace(self) -> None: + ... + + def encrypt_password( + self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None + ) -> bytes: + ... + + def make_empty_result(self, exec_status: int) -> "PGresult": + ... + + @property + def pipeline_status(self) -> int: + ... + + def enter_pipeline_mode(self) -> None: + ... + + def exit_pipeline_mode(self) -> None: + ... + + def pipeline_sync(self) -> None: + ... + + def send_flush_request(self) -> None: + ... + + +class PGresult(Protocol): + def clear(self) -> None: + ... + + @property + def status(self) -> int: + ... + + @property + def error_message(self) -> bytes: + ... + + def error_field(self, fieldcode: int) -> Optional[bytes]: + ... + + @property + def ntuples(self) -> int: + ... + + @property + def nfields(self) -> int: + ... + + def fname(self, column_number: int) -> Optional[bytes]: + ... + + def ftable(self, column_number: int) -> int: + ... + + def ftablecol(self, column_number: int) -> int: + ... + + def fformat(self, column_number: int) -> int: + ... + + def ftype(self, column_number: int) -> int: + ... + + def fmod(self, column_number: int) -> int: + ... 
+ + def fsize(self, column_number: int) -> int: + ... + + @property + def binary_tuples(self) -> int: + ... + + def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: + ... + + @property + def nparams(self) -> int: + ... + + def param_type(self, param_number: int) -> int: + ... + + @property + def command_status(self) -> Optional[bytes]: + ... + + @property + def command_tuples(self) -> Optional[int]: + ... + + @property + def oid_value(self) -> int: + ... + + def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None: + ... + + +class PGcancel(Protocol): + def free(self) -> None: + ... + + def cancel(self) -> None: + ... + + +class Conninfo(Protocol): + @classmethod + def get_defaults(cls) -> List["ConninfoOption"]: + ... + + @classmethod + def parse(cls, conninfo: bytes) -> List["ConninfoOption"]: + ... + + @classmethod + def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]: + ... + + +class Escaping(Protocol): + def __init__(self, conn: Optional[PGconn] = None): + ... + + def escape_literal(self, data: Buffer) -> bytes: + ... + + def escape_identifier(self, data: Buffer) -> bytes: + ... + + def escape_string(self, data: Buffer) -> bytes: + ... + + def escape_bytea(self, data: Buffer) -> bytes: + ... + + def unescape_bytea(self, data: Buffer) -> bytes: + ... diff --git a/psycopg/psycopg/pq/misc.py b/psycopg/psycopg/pq/misc.py new file mode 100644 index 0000000..3a43133 --- /dev/null +++ b/psycopg/psycopg/pq/misc.py @@ -0,0 +1,146 @@ +""" +Various functionalities to make easier to work with the libpq. +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import sys +import logging +import ctypes.util +from typing import cast, NamedTuple, Optional, Union + +from .abc import PGconn, PGresult +from ._enums import ConnStatus, TransactionStatus, PipelineStatus +from .._compat import cache +from .._encodings import pgconn_encoding + +logger = logging.getLogger("psycopg.pq") + +OK = ConnStatus.OK + + +class PGnotify(NamedTuple): + relname: bytes + be_pid: int + extra: bytes + + +class ConninfoOption(NamedTuple): + keyword: bytes + envvar: Optional[bytes] + compiled: Optional[bytes] + val: Optional[bytes] + label: bytes + dispchar: bytes + dispsize: int + + +class PGresAttDesc(NamedTuple): + name: bytes + tableid: int + columnid: int + format: int + typid: int + typlen: int + atttypmod: int + + +@cache +def find_libpq_full_path() -> Optional[str]: + if sys.platform == "win32": + libname = ctypes.util.find_library("libpq.dll") + + elif sys.platform == "darwin": + libname = ctypes.util.find_library("libpq.dylib") + # (hopefully) temporary hack: libpq not in a standard place + # https://github.com/orgs/Homebrew/discussions/3595 + # If pg_config is available and agrees, let's use its indications. + if not libname: + try: + import subprocess as sp + + libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode() + libname = os.path.join(libdir, "libpq.dylib") + if not os.path.exists(libname): + libname = None + except Exception as ex: + logger.debug("couldn't use pg_config to find libpq: %s", ex) + + else: + libname = ctypes.util.find_library("pq") + + return libname + + +def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str: + """ + Return an error message from a `PGconn` or `PGresult`. + + The return value is a `!str` (unlike pq data which is usually `!bytes`): + use the connection encoding if available, otherwise the `!encoding` + parameter as a fallback for decoding. 
Don't raise exceptions on decoding + errors. + + """ + bmsg: bytes + + if hasattr(obj, "error_field"): + # obj is a PGresult + obj = cast(PGresult, obj) + bmsg = obj.error_message + + # strip severity and whitespaces + if bmsg: + bmsg = bmsg.split(b":", 1)[-1].strip() + + elif hasattr(obj, "error_message"): + # obj is a PGconn + if obj.status == OK: + encoding = pgconn_encoding(obj) + bmsg = obj.error_message + + # strip severity and whitespaces + if bmsg: + bmsg = bmsg.split(b":", 1)[-1].strip() + + else: + raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}") + + if bmsg: + msg = bmsg.decode(encoding, "replace") + else: + msg = "no details available" + + return msg + + +def connection_summary(pgconn: PGconn) -> str: + """ + Return summary information on a connection. + + Useful for __repr__ + """ + parts = [] + if pgconn.status == OK: + # Put together the [STATUS] + status = TransactionStatus(pgconn.transaction_status).name + if pgconn.pipeline_status: + status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}" + + # Put together the (CONNECTION) + if not pgconn.host.startswith(b"/"): + parts.append(("host", pgconn.host.decode())) + if pgconn.port != b"5432": + parts.append(("port", pgconn.port.decode())) + if pgconn.user != pgconn.db: + parts.append(("user", pgconn.user.decode())) + parts.append(("database", pgconn.db.decode())) + + else: + status = ConnStatus(pgconn.status).name + + sparts = " ".join("%s=%s" % part for part in parts) + if sparts: + sparts = f" ({sparts})" + return f"[{status}]{sparts}" diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py new file mode 100644 index 0000000..8b87c19 --- /dev/null +++ b/psycopg/psycopg/pq/pq_ctypes.py @@ -0,0 +1,1086 @@ +""" +libpq Python wrapper using ctypes bindings. + +Clients shouldn't use this module directly, unless for testing: they should use +the `pq` module instead, which is in charge of choosing the best +implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import logging +from os import getpid +from weakref import ref + +from ctypes import Array, POINTER, cast, string_at, create_string_buffer, byref +from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong, c_void_p, py_object +from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import cast as t_cast, TYPE_CHECKING + +from .. import errors as e +from . import _pq_ctypes as impl +from .misc import PGnotify, ConninfoOption, PGresAttDesc +from .misc import error_message, connection_summary +from ._enums import Format, ExecStatus, Trace + +# Imported locally to call them from __del__ methods +from ._pq_ctypes import PQclear, PQfinish, PQfreeCancel, PQstatus + +if TYPE_CHECKING: + from . import abc + +__impl__ = "python" + +logger = logging.getLogger("psycopg") + + +def version() -> int: + """Return the version number of the libpq currently loaded. + + The number is in the same format of `~psycopg.ConnectionInfo.server_version`. + + Certain features might not be available if the libpq library used is too old. 
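    For illustration: since PostgreSQL 10 the number packs major and minor
    release as ``major * 10000 + minor``, so libpq 14.2 is reported as
    140002. A caller guarding pipeline-mode support (a usage sketch, not a
    requirement of this function) could therefore write::

        if version() >= 140000:
            ...  # the PQpipeline* wrappers are usable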
+ """ + return impl.PQlibVersion() + + +@impl.PQnoticeReceiver # type: ignore +def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None: + pgconn = cast(arg, POINTER(py_object)).contents.value() + if not (pgconn and pgconn.notice_handler): + return + + res = PGresult(result_ptr) + try: + pgconn.notice_handler(res) + except Exception as exc: + logger.exception("error in notice receiver: %s", exc) + finally: + res._pgresult_ptr = None # avoid destroying the pgresult_ptr + + +class PGconn: + """ + Python representation of a libpq connection. + """ + + __slots__ = ( + "_pgconn_ptr", + "notice_handler", + "notify_handler", + "_self_ptr", + "_procpid", + "__weakref__", + ) + + def __init__(self, pgconn_ptr: impl.PGconn_struct): + self._pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr + self.notice_handler: Optional[Callable[["abc.PGresult"], None]] = None + self.notify_handler: Optional[Callable[[PGnotify], None]] = None + + # Keep alive for the lifetime of PGconn + self._self_ptr = py_object(ref(self)) + impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, byref(self._self_ptr)) + + self._procpid = getpid() + + def __del__(self) -> None: + # Close the connection only if it was created in this process, + # not if this object is being GC'd after fork. + if getpid() == self._procpid: + self.finish() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self) + return f"<{cls} {info} at 0x{id(self):x}>" + + @classmethod + def connect(cls, conninfo: bytes) -> "PGconn": + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + pgconn_ptr = impl.PQconnectdb(conninfo) + if not pgconn_ptr: + raise MemoryError("couldn't allocate PGconn") + return cls(pgconn_ptr) + + @classmethod + def connect_start(cls, conninfo: bytes) -> "PGconn": + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + pgconn_ptr = impl.PQconnectStart(conninfo) + if not pgconn_ptr: + raise MemoryError("couldn't allocate PGconn") + return cls(pgconn_ptr) + + def connect_poll(self) -> int: + return self._call_int(impl.PQconnectPoll) + + def finish(self) -> None: + self._pgconn_ptr, p = None, self._pgconn_ptr + if p: + PQfinish(p) + + @property + def pgconn_ptr(self) -> Optional[int]: + """The pointer to the underlying `!PGconn` structure, as integer. + + `!None` if the connection is closed. + + The value can be used to pass the structure to libpq functions which + psycopg doesn't (currently) wrap, either in C or in Python using FFI + libraries such as `ctypes`. 
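        A minimal sketch of such a use, assuming the system libpq is
        discoverable by ctypes and using :pq:`PQclientEncoding`, a real
        libpq function not otherwise wrapped here (``pgconn`` is an
        instance of this class)::

            import ctypes, ctypes.util

            libpq = ctypes.CDLL(ctypes.util.find_library("pq"))
            libpq.PQclientEncoding.argtypes = [ctypes.c_void_p]
            libpq.PQclientEncoding.restype = ctypes.c_int
            enc_id = libpq.PQclientEncoding(pgconn.pgconn_ptr)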
+ """ + if self._pgconn_ptr is None: + return None + + return addressof(self._pgconn_ptr.contents) # type: ignore[attr-defined] + + @property + def info(self) -> List["ConninfoOption"]: + self._ensure_pgconn() + opts = impl.PQconninfo(self._pgconn_ptr) + if not opts: + raise MemoryError("couldn't allocate connection info") + try: + return Conninfo._options_from_array(opts) + finally: + impl.PQconninfoFree(opts) + + def reset(self) -> None: + self._ensure_pgconn() + impl.PQreset(self._pgconn_ptr) + + def reset_start(self) -> None: + if not impl.PQresetStart(self._pgconn_ptr): + raise e.OperationalError("couldn't reset connection") + + def reset_poll(self) -> int: + return self._call_int(impl.PQresetPoll) + + @classmethod + def ping(self, conninfo: bytes) -> int: + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + return impl.PQping(conninfo) + + @property + def db(self) -> bytes: + return self._call_bytes(impl.PQdb) + + @property + def user(self) -> bytes: + return self._call_bytes(impl.PQuser) + + @property + def password(self) -> bytes: + return self._call_bytes(impl.PQpass) + + @property + def host(self) -> bytes: + return self._call_bytes(impl.PQhost) + + @property + def hostaddr(self) -> bytes: + return self._call_bytes(impl.PQhostaddr) + + @property + def port(self) -> bytes: + return self._call_bytes(impl.PQport) + + @property + def tty(self) -> bytes: + return self._call_bytes(impl.PQtty) + + @property + def options(self) -> bytes: + return self._call_bytes(impl.PQoptions) + + @property + def status(self) -> int: + return PQstatus(self._pgconn_ptr) + + @property + def transaction_status(self) -> int: + return impl.PQtransactionStatus(self._pgconn_ptr) + + def parameter_status(self, name: bytes) -> Optional[bytes]: + self._ensure_pgconn() + return impl.PQparameterStatus(self._pgconn_ptr, name) + + @property + def error_message(self) -> bytes: + return impl.PQerrorMessage(self._pgconn_ptr) + + @property + def protocol_version(self) -> int: + return self._call_int(impl.PQprotocolVersion) + + @property + def server_version(self) -> int: + return self._call_int(impl.PQserverVersion) + + @property + def socket(self) -> int: + rv = self._call_int(impl.PQsocket) + if rv == -1: + raise e.OperationalError("the connection is lost") + return rv + + @property + def backend_pid(self) -> int: + return self._call_int(impl.PQbackendPID) + + @property + def needs_password(self) -> bool: + """True if the connection authentication method required a password, + but none was available. + + See :pq:`PQconnectionNeedsPassword` for details. + """ + return bool(impl.PQconnectionNeedsPassword(self._pgconn_ptr)) + + @property + def used_password(self) -> bool: + """True if the connection authentication method used a password. + + See :pq:`PQconnectionUsedPassword` for details. 
+ """ + return bool(impl.PQconnectionUsedPassword(self._pgconn_ptr)) + + @property + def ssl_in_use(self) -> bool: + return self._call_bool(impl.PQsslInUse) + + def exec_(self, command: bytes) -> "PGresult": + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + self._ensure_pgconn() + rv = impl.PQexec(self._pgconn_ptr, command) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_query(self, command: bytes) -> None: + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + self._ensure_pgconn() + if not impl.PQsendQuery(self._pgconn_ptr, command): + raise e.OperationalError(f"sending query failed: {error_message(self)}") + + def exec_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> "PGresult": + args = self._query_params_args( + command, param_values, param_types, param_formats, result_format + ) + self._ensure_pgconn() + rv = impl.PQexecParams(*args) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + args = self._query_params_args( + command, param_values, param_types, param_formats, result_format + ) + self._ensure_pgconn() + if not impl.PQsendQueryParams(*args): + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + atypes: Optional[Array[impl.Oid]] + if not param_types: + nparams = 0 + atypes = None + else: + nparams = len(param_types) + atypes = (impl.Oid * nparams)(*param_types) + + self._ensure_pgconn() + if not impl.PQsendPrepare(self._pgconn_ptr, name, command, nparams, atypes): + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + # repurpose this function with a cheeky replacement of query with name, + # drop the param_types from the result + args = self._query_params_args( + name, param_values, None, param_formats, result_format + ) + args = args[:3] + args[4:] + + self._ensure_pgconn() + if not impl.PQsendQueryPrepared(*args): + raise e.OperationalError( + f"sending prepared query failed: {error_message(self)}" + ) + + def _query_params_args( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> Any: + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + + aparams: Optional[Array[c_char_p]] + alenghts: Optional[Array[c_int]] + if param_values: + nparams = len(param_values) + aparams = (c_char_p * nparams)( + *( + # convert bytearray/memoryview to bytes + b if b is None or isinstance(b, bytes) else bytes(b) + for b in param_values + ) + ) + 
alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values)) + else: + nparams = 0 + aparams = alenghts = None + + atypes: Optional[Array[impl.Oid]] + if not param_types: + atypes = None + else: + if len(param_types) != nparams: + raise ValueError( + "got %d param_values but %d param_types" + % (nparams, len(param_types)) + ) + atypes = (impl.Oid * nparams)(*param_types) + + if not param_formats: + aformats = None + else: + if len(param_formats) != nparams: + raise ValueError( + "got %d param_values but %d param_formats" + % (nparams, len(param_formats)) + ) + aformats = (c_int * nparams)(*param_formats) + + return ( + self._pgconn_ptr, + command, + nparams, + atypes, + aparams, + alenghts, + aformats, + result_format, + ) + + def prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + + if not isinstance(command, bytes): + raise TypeError(f"'command' must be bytes, got {type(command)} instead") + + if not param_types: + nparams = 0 + atypes = None + else: + nparams = len(param_types) + atypes = (impl.Oid * nparams)(*param_types) + + self._ensure_pgconn() + rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def exec_prepared( + self, + name: bytes, + param_values: Optional[Sequence["abc.Buffer"]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = 0, + ) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + + aparams: Optional[Array[c_char_p]] + alenghts: Optional[Array[c_int]] + if param_values: + nparams = len(param_values) + aparams = (c_char_p * nparams)( + *( + # convert bytearray/memoryview to bytes + b if b is None or isinstance(b, bytes) else bytes(b) + for b in param_values + ) + ) + alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values)) + else: + nparams = 0 + aparams = alenghts = None + + if not param_formats: + aformats = None + else: + if len(param_formats) != nparams: + raise ValueError( + "got %d param_values but %d param_types" + % (nparams, len(param_formats)) + ) + aformats = (c_int * nparams)(*param_formats) + + self._ensure_pgconn() + rv = impl.PQexecPrepared( + self._pgconn_ptr, + name, + nparams, + aparams, + alenghts, + aformats, + result_format, + ) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def describe_prepared(self, name: bytes) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + self._ensure_pgconn() + rv = impl.PQdescribePrepared(self._pgconn_ptr, name) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_describe_prepared(self, name: bytes) -> None: + if not isinstance(name, bytes): + raise TypeError(f"bytes expected, got {type(name)} instead") + self._ensure_pgconn() + if not impl.PQsendDescribePrepared(self._pgconn_ptr, name): + raise e.OperationalError( + f"sending describe prepared failed: {error_message(self)}" + ) + + def describe_portal(self, name: bytes) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + self._ensure_pgconn() + rv = impl.PQdescribePortal(self._pgconn_ptr, name) + if not rv: + raise MemoryError("couldn't allocate PGresult") + 
return PGresult(rv) + + def send_describe_portal(self, name: bytes) -> None: + if not isinstance(name, bytes): + raise TypeError(f"bytes expected, got {type(name)} instead") + self._ensure_pgconn() + if not impl.PQsendDescribePortal(self._pgconn_ptr, name): + raise e.OperationalError( + f"sending describe portal failed: {error_message(self)}" + ) + + def get_result(self) -> Optional["PGresult"]: + rv = impl.PQgetResult(self._pgconn_ptr) + return PGresult(rv) if rv else None + + def consume_input(self) -> None: + if 1 != impl.PQconsumeInput(self._pgconn_ptr): + raise e.OperationalError(f"consuming input failed: {error_message(self)}") + + def is_busy(self) -> int: + return impl.PQisBusy(self._pgconn_ptr) + + @property + def nonblocking(self) -> int: + return impl.PQisnonblocking(self._pgconn_ptr) + + @nonblocking.setter + def nonblocking(self, arg: int) -> None: + if 0 > impl.PQsetnonblocking(self._pgconn_ptr, arg): + raise e.OperationalError( + f"setting nonblocking failed: {error_message(self)}" + ) + + def flush(self) -> int: + # PQflush segfaults if it receives a NULL connection + if not self._pgconn_ptr: + raise e.OperationalError("flushing failed: the connection is closed") + rv: int = impl.PQflush(self._pgconn_ptr) + if rv < 0: + raise e.OperationalError(f"flushing failed: {error_message(self)}") + return rv + + def set_single_row_mode(self) -> None: + if not impl.PQsetSingleRowMode(self._pgconn_ptr): + raise e.OperationalError("setting single row mode failed") + + def get_cancel(self) -> "PGcancel": + """ + Create an object with the information needed to cancel a command. + + See :pq:`PQgetCancel` for details. + """ + rv = impl.PQgetCancel(self._pgconn_ptr) + if not rv: + raise e.OperationalError("couldn't create cancel object") + return PGcancel(rv) + + def notifies(self) -> Optional[PGnotify]: + ptr = impl.PQnotifies(self._pgconn_ptr) + if ptr: + c = ptr.contents + return PGnotify(c.relname, c.be_pid, c.extra) + impl.PQfreemem(ptr) + else: + return None + + def put_copy_data(self, buffer: "abc.Buffer") -> int: + if not isinstance(buffer, bytes): + buffer = bytes(buffer) + rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer)) + if rv < 0: + raise e.OperationalError(f"sending copy data failed: {error_message(self)}") + return rv + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + rv = impl.PQputCopyEnd(self._pgconn_ptr, error) + if rv < 0: + raise e.OperationalError(f"sending copy end failed: {error_message(self)}") + return rv + + def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: + buffer_ptr = c_char_p() + nbytes = impl.PQgetCopyData(self._pgconn_ptr, byref(buffer_ptr), async_) + if nbytes == -2: + raise e.OperationalError( + f"receiving copy data failed: {error_message(self)}" + ) + if buffer_ptr: + # TODO: do it without copy + data = string_at(buffer_ptr, nbytes) + impl.PQfreemem(buffer_ptr) + return nbytes, memoryview(data) + else: + return nbytes, memoryview(b"") + + def trace(self, fileno: int) -> None: + """ + Enable tracing of the client/server communication to a file stream. + + See :pq:`PQtrace` for details. + """ + if sys.platform != "linux": + raise e.NotSupportedError("currently only supported on Linux") + stream = impl.fdopen(fileno, b"w") + impl.PQtrace(self._pgconn_ptr, stream) + + def set_trace_flags(self, flags: Trace) -> None: + """ + Configure tracing behavior of client/server communication. + + :param flags: operating mode of tracing. + + See :pq:`PQsetTraceFlags` for details. 
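        A minimal sketch, assuming an open connection ``pgconn`` on Linux
        and that the flags below exist in `_enums.Trace` (they mirror
        :pq:`PQTRACE_SUPPRESS_TIMESTAMPS` and :pq:`PQTRACE_REGRESS_MODE`);
        the file object must stay referenced while tracing is active::

            f = open("trace.log", "w")
            pgconn.trace(f.fileno())
            pgconn.set_trace_flags(Trace.SUPPRESS_TIMESTAMPS | Trace.REGRESS_MODE)
            # ... run some queries ...
            pgconn.untrace()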
+ """ + impl.PQsetTraceFlags(self._pgconn_ptr, flags) + + def untrace(self) -> None: + """ + Disable tracing, previously enabled through `trace()`. + + See :pq:`PQuntrace` for details. + """ + impl.PQuntrace(self._pgconn_ptr) + + def encrypt_password( + self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None + ) -> bytes: + """ + Return the encrypted form of a PostgreSQL password. + + See :pq:`PQencryptPasswordConn` for details. + """ + out = impl.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, algorithm) + if not out: + raise e.OperationalError( + f"password encryption failed: {error_message(self)}" + ) + + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def make_empty_result(self, exec_status: int) -> "PGresult": + rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status) + if not rv: + raise MemoryError("couldn't allocate empty PGresult") + return PGresult(rv) + + @property + def pipeline_status(self) -> int: + if version() < 140000: + return 0 + return impl.PQpipelineStatus(self._pgconn_ptr) + + def enter_pipeline_mode(self) -> None: + """Enter pipeline mode. + + :raises ~e.OperationalError: in case of failure to enter the pipeline + mode. + """ + if impl.PQenterPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError("failed to enter pipeline mode") + + def exit_pipeline_mode(self) -> None: + """Exit pipeline mode. + + :raises ~e.OperationalError: in case of failure to exit the pipeline + mode. + """ + if impl.PQexitPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError(error_message(self)) + + def pipeline_sync(self) -> None: + """Mark a synchronization point in a pipeline. + + :raises ~e.OperationalError: if the connection is not in pipeline mode + or if sync failed. + """ + rv = impl.PQpipelineSync(self._pgconn_ptr) + if rv == 0: + raise e.OperationalError("connection not in pipeline mode") + if rv != 1: + raise e.OperationalError("failed to sync pipeline") + + def send_flush_request(self) -> None: + """Sends a request for the server to flush its output buffer. + + :raises ~e.OperationalError: if the flush request failed. + """ + if impl.PQsendFlushRequest(self._pgconn_ptr) == 0: + raise e.OperationalError(f"flush request failed: {error_message(self)}") + + def _call_bytes( + self, func: Callable[[impl.PGconn_struct], Optional[bytes]] + ) -> bytes: + """ + Call one of the pgconn libpq functions returning a bytes pointer. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + rv = func(self._pgconn_ptr) + assert rv is not None + return rv + + def _call_int(self, func: Callable[[impl.PGconn_struct], int]) -> int: + """ + Call one of the pgconn libpq functions returning an int. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + return func(self._pgconn_ptr) + + def _call_bool(self, func: Callable[[impl.PGconn_struct], int]) -> bool: + """ + Call one of the pgconn libpq functions returning a logical value. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + return bool(func(self._pgconn_ptr)) + + def _ensure_pgconn(self) -> None: + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + + +class PGresult: + """ + Python representation of a libpq result. 
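    A minimal usage sketch (illustrative only), assuming ``pgconn`` is a
    connected `PGconn`::

        res = pgconn.exec_(b"SELECT 'hello', NULL")
        assert res.status == ExecStatus.TUPLES_OK
        res.ntuples, res.nfields    # (1, 2)
        res.get_value(0, 0)         # b'hello'
        res.get_value(0, 1)         # None (SQL NULL)
        res.clear()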
+ """ + + __slots__ = ("_pgresult_ptr",) + + def __init__(self, pgresult_ptr: impl.PGresult_struct): + self._pgresult_ptr: Optional[impl.PGresult_struct] = pgresult_ptr + + def __del__(self) -> None: + self.clear() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + status = ExecStatus(self.status) + return f"<{cls} [{status.name}] at 0x{id(self):x}>" + + def clear(self) -> None: + self._pgresult_ptr, p = None, self._pgresult_ptr + if p: + PQclear(p) + + @property + def pgresult_ptr(self) -> Optional[int]: + """The pointer to the underlying `!PGresult` structure, as integer. + + `!None` if the result was cleared. + + The value can be used to pass the structure to libpq functions which + psycopg doesn't (currently) wrap, either in C or in Python using FFI + libraries such as `ctypes`. + """ + if self._pgresult_ptr is None: + return None + + return addressof(self._pgresult_ptr.contents) # type: ignore[attr-defined] + + @property + def status(self) -> int: + return impl.PQresultStatus(self._pgresult_ptr) + + @property + def error_message(self) -> bytes: + return impl.PQresultErrorMessage(self._pgresult_ptr) + + def error_field(self, fieldcode: int) -> Optional[bytes]: + return impl.PQresultErrorField(self._pgresult_ptr, fieldcode) + + @property + def ntuples(self) -> int: + return impl.PQntuples(self._pgresult_ptr) + + @property + def nfields(self) -> int: + return impl.PQnfields(self._pgresult_ptr) + + def fname(self, column_number: int) -> Optional[bytes]: + return impl.PQfname(self._pgresult_ptr, column_number) + + def ftable(self, column_number: int) -> int: + return impl.PQftable(self._pgresult_ptr, column_number) + + def ftablecol(self, column_number: int) -> int: + return impl.PQftablecol(self._pgresult_ptr, column_number) + + def fformat(self, column_number: int) -> int: + return impl.PQfformat(self._pgresult_ptr, column_number) + + def ftype(self, column_number: int) -> int: + return impl.PQftype(self._pgresult_ptr, column_number) + + def fmod(self, column_number: int) -> int: + return impl.PQfmod(self._pgresult_ptr, column_number) + + def fsize(self, column_number: int) -> int: + return impl.PQfsize(self._pgresult_ptr, column_number) + + @property + def binary_tuples(self) -> int: + return impl.PQbinaryTuples(self._pgresult_ptr) + + def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: + length: int = impl.PQgetlength(self._pgresult_ptr, row_number, column_number) + if length: + v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number) + return string_at(v, length) + else: + if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number): + return None + else: + return b"" + + @property + def nparams(self) -> int: + return impl.PQnparams(self._pgresult_ptr) + + def param_type(self, param_number: int) -> int: + return impl.PQparamtype(self._pgresult_ptr, param_number) + + @property + def command_status(self) -> Optional[bytes]: + return impl.PQcmdStatus(self._pgresult_ptr) + + @property + def command_tuples(self) -> Optional[int]: + rv = impl.PQcmdTuples(self._pgresult_ptr) + return int(rv) if rv else None + + @property + def oid_value(self) -> int: + return impl.PQoidValue(self._pgresult_ptr) + + def set_attributes(self, descriptions: List[PGresAttDesc]) -> None: + structs = [ + impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore + ] + array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore + rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array) + 
if rv == 0: + raise e.OperationalError("PQsetResultAttrs failed") + + +class PGcancel: + """ + Token to cancel the current operation on a connection. + + Created by `PGconn.get_cancel()`. + """ + + __slots__ = ("pgcancel_ptr",) + + def __init__(self, pgcancel_ptr: impl.PGcancel_struct): + self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr + + def __del__(self) -> None: + self.free() + + def free(self) -> None: + """ + Free the data structure created by :pq:`PQgetCancel()`. + + Automatically invoked by `!__del__()`. + + See :pq:`PQfreeCancel()` for details. + """ + self.pgcancel_ptr, p = None, self.pgcancel_ptr + if p: + PQfreeCancel(p) + + def cancel(self) -> None: + """Requests that the server abandon processing of the current command. + + See :pq:`PQcancel()` for details. + """ + buf = create_string_buffer(256) + res = impl.PQcancel( + self.pgcancel_ptr, + byref(buf), # type: ignore[arg-type] + len(buf), + ) + if not res: + raise e.OperationalError( + f"cancel failed: {buf.value.decode('utf8', 'ignore')}" + ) + + +class Conninfo: + """ + Utility object to manipulate connection strings. + """ + + @classmethod + def get_defaults(cls) -> List[ConninfoOption]: + opts = impl.PQconndefaults() + if not opts: + raise MemoryError("couldn't allocate connection defaults") + try: + return cls._options_from_array(opts) + finally: + impl.PQconninfoFree(opts) + + @classmethod + def parse(cls, conninfo: bytes) -> List[ConninfoOption]: + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + errmsg = c_char_p() + rv = impl.PQconninfoParse(conninfo, byref(errmsg)) # type: ignore[arg-type] + if not rv: + if not errmsg: + raise MemoryError("couldn't allocate on conninfo parse") + else: + exc = e.OperationalError( + (errmsg.value or b"").decode("utf8", "replace") + ) + impl.PQfreemem(errmsg) + raise exc + + try: + return cls._options_from_array(rv) + finally: + impl.PQconninfoFree(rv) + + @classmethod + def _options_from_array( + cls, opts: Sequence[impl.PQconninfoOption_struct] + ) -> List[ConninfoOption]: + rv = [] + skws = "keyword envvar compiled val label dispchar".split() + for opt in opts: + if not opt.keyword: + break + d = {kw: getattr(opt, kw) for kw in skws} + d["dispsize"] = opt.dispsize + rv.append(ConninfoOption(**d)) + + return rv + + +class Escaping: + """ + Utility object to escape strings for SQL interpolation. 
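    A minimal usage sketch (illustrative only; the exact quoting may vary
    with server settings), assuming ``pgconn`` is a connected `PGconn`::

        esc = Escaping(pgconn)
        esc.escape_literal(b"O'Reilly")     # e.g. b"'O''Reilly'"
        esc.escape_identifier(b"select")    # b'"select"'
        esc.escape_string(b"O'Reilly")      # b"O''Reilly"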
+ """ + + def __init__(self, conn: Optional[PGconn] = None): + self.conn = conn + + def escape_literal(self, data: "abc.Buffer") -> bytes: + if not self.conn: + raise e.OperationalError("escape_literal failed: no connection provided") + + self.conn._ensure_pgconn() + # TODO: might be done without copy (however C does that) + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data)) + if not out: + raise e.OperationalError( + f"escape_literal failed: {error_message(self.conn)} bytes" + ) + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def escape_identifier(self, data: "abc.Buffer") -> bytes: + if not self.conn: + raise e.OperationalError("escape_identifier failed: no connection provided") + + self.conn._ensure_pgconn() + + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data)) + if not out: + raise e.OperationalError( + f"escape_identifier failed: {error_message(self.conn)} bytes" + ) + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def escape_string(self, data: "abc.Buffer") -> bytes: + if not isinstance(data, bytes): + data = bytes(data) + + if self.conn: + self.conn._ensure_pgconn() + error = c_int() + out = create_string_buffer(len(data) * 2 + 1) + impl.PQescapeStringConn( + self.conn._pgconn_ptr, + byref(out), # type: ignore[arg-type] + data, + len(data), + byref(error), # type: ignore[arg-type] + ) + + if error: + raise e.OperationalError( + f"escape_string failed: {error_message(self.conn)} bytes" + ) + + else: + out = create_string_buffer(len(data) * 2 + 1) + impl.PQescapeString( + byref(out), # type: ignore[arg-type] + data, + len(data), + ) + + return out.value + + def escape_bytea(self, data: "abc.Buffer") -> bytes: + len_out = c_size_t() + # TODO: might be able to do without a copy but it's a mess. + # the C library does it better anyway, so maybe not worth optimising + # https://mail.python.org/pipermail/python-dev/2012-September/121780.html + if not isinstance(data, bytes): + data = bytes(data) + if self.conn: + self.conn._ensure_pgconn() + out = impl.PQescapeByteaConn( + self.conn._pgconn_ptr, + data, + len(data), + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + else: + out = impl.PQescapeBytea( + data, + len(data), + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + if not out: + raise MemoryError( + f"couldn't allocate for escape_bytea of {len(data)} bytes" + ) + + rv = string_at(out, len_out.value - 1) # out includes final 0 + impl.PQfreemem(out) + return rv + + def unescape_bytea(self, data: "abc.Buffer") -> bytes: + # not needed, but let's keep it symmetric with the escaping: + # if a connection is passed in, it must be valid. 
+ if self.conn: + self.conn._ensure_pgconn() + + len_out = c_size_t() + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQunescapeBytea( + data, + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + if not out: + raise MemoryError( + f"couldn't allocate for unescape_bytea of {len(data)} bytes" + ) + + rv = string_at(out, len_out.value) + impl.PQfreemem(out) + return rv + + +# importing the ssl module sets up Python's libcrypto callbacks +import ssl # noqa + +# disable libcrypto setup in libpq, so it won't stomp on the callbacks +# that have already been set up +impl.PQinitOpenSSL(1, 0) + +__build_version__ = version() diff --git a/psycopg/psycopg/py.typed b/psycopg/psycopg/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/psycopg/psycopg/py.typed diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py new file mode 100644 index 0000000..cb28b57 --- /dev/null +++ b/psycopg/psycopg/rows.py @@ -0,0 +1,256 @@ +""" +psycopg row factories +""" + +# Copyright (C) 2021 The Psycopg Team + +import functools +from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn +from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar +from collections import namedtuple +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e +from ._compat import Protocol +from ._encodings import _as_python_identifier + +if TYPE_CHECKING: + from .cursor import BaseCursor, Cursor + from .cursor_async import AsyncCursor + from psycopg.pq.abc import PGresult + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE + +T = TypeVar("T", covariant=True) + +# Row factories + +Row = TypeVar("Row", covariant=True) + + +class RowMaker(Protocol[Row]): + """ + Callable protocol taking a sequence of value and returning an object. + + The sequence of value is what is returned from a database query, already + adapted to the right Python types. The return value is the object that your + program would like to receive: by default (`tuple_row()`) it is a simple + tuple, but it may be any type of object. + + Typically, `!RowMaker` functions are returned by `RowFactory`. + """ + + def __call__(self, __values: Sequence[Any]) -> Row: + ... + + +class RowFactory(Protocol[Row]): + """ + Callable protocol taking a `~psycopg.Cursor` and returning a `RowMaker`. + + A `!RowFactory` is typically called when a `!Cursor` receives a result. + This way it can inspect the cursor state (for instance the + `~psycopg.Cursor.description` attribute) and help a `!RowMaker` to create + a complete object. + + For instance the `dict_row()` `!RowFactory` uses the names of the column to + define the dictionary key and returns a `!RowMaker` function which would + use the values to create a dictionary for each record. + """ + + def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]: + ... + + +class AsyncRowFactory(Protocol[Row]): + """ + Like `RowFactory`, taking an async cursor as argument. + """ + + def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]: + ... + + +class BaseRowFactory(Protocol[Row]): + """ + Like `RowFactory`, taking either type of cursor as argument. + """ + + def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]: + ... + + +TupleRow: TypeAlias = Tuple[Any, ...] +""" +An alias for the type returned by `tuple_row()` (i.e. a tuple of any content). 
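For illustration, a custom factory satisfying the protocols above can be a
plain function; the following sketch (names are hypothetical) returns rows
as dictionaries keyed by upper-cased column names::

    def upper_dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[Dict[str, Any]]":
        # Inspect the cursor once per result...
        names = [c.name.upper() for c in (cursor.description or [])]

        def make_row(values: Sequence[Any]) -> Dict[str, Any]:
            # ...then build one object per record.
            return dict(zip(names, values))

        return make_row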
+""" + + +DictRow: TypeAlias = Dict[str, Any] +""" +An alias for the type returned by `dict_row()` + +A `!DictRow` is a dictionary with keys as string and any value returned by the +database. +""" + + +def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]": + r"""Row factory to represent rows as simple tuples. + + This is the default factory, used when `~psycopg.Connection.connect()` or + `~psycopg.Connection.cursor()` are called without a `!row_factory` + parameter. + + """ + # Implementation detail: make sure this is the tuple type itself, not an + # equivalent function, because the C code fast-paths on it. + return tuple + + +def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]": + """Row factory to represent rows as dictionaries. + + The dictionary keys are taken from the column names of the returned columns. + """ + names = _get_names(cursor) + if names is None: + return no_result + + def dict_row_(values: Sequence[Any]) -> Dict[str, Any]: + # https://github.com/python/mypy/issues/2608 + return dict(zip(names, values)) # type: ignore[arg-type] + + return dict_row_ + + +def namedtuple_row( + cursor: "BaseCursor[Any, Any]", +) -> "RowMaker[NamedTuple]": + """Row factory to represent rows as `~collections.namedtuple`. + + The field names are taken from the column names of the returned columns, + with some mangling to deal with invalid names. + """ + res = cursor.pgresult + if not res: + return no_result + + nfields = _get_nfields(res) + if nfields is None: + return no_result + + nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields))) + return nt._make + + +@functools.lru_cache(512) +def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]: + snames = tuple(_as_python_identifier(n.decode(enc)) for n in names) + return namedtuple("Row", snames) # type: ignore[return-value] + + +def class_row(cls: Type[T]) -> BaseRowFactory[T]: + r"""Generate a row factory to represent rows as instances of the class `!cls`. + + The class must support every output column name as a keyword parameter. + + :param cls: The class to return for each row. It must support the fields + returned by the query as keyword arguments. + :rtype: `!Callable[[Cursor],` `RowMaker`\[~T]] + """ + + def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]": + names = _get_names(cursor) + if names is None: + return no_result + + def class_row__(values: Sequence[Any]) -> T: + return cls(**dict(zip(names, values))) # type: ignore[arg-type] + + return class_row__ + + return class_row_ + + +def args_row(func: Callable[..., T]) -> BaseRowFactory[T]: + """Generate a row factory calling `!func` with positional parameters for every row. + + :param func: The function to call for each row. It must support the fields + returned by the query as positional arguments. + """ + + def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]": + def args_row__(values: Sequence[Any]) -> T: + return func(*values) + + return args_row__ + + return args_row_ + + +def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]: + """Generate a row factory calling `!func` with keyword parameters for every row. + + :param func: The function to call for each row. It must support the fields + returned by the query as keyword arguments. 
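    A minimal sketch, assuming a ``Person`` dataclass (hypothetical) whose
    ``name`` and ``age`` fields match the query columns::

        cur = conn.cursor(row_factory=kwargs_row(Person))
        cur.execute("SELECT 'Alice' AS name, 42 AS age")
        cur.fetchone()    # Person(name='Alice', age=42)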
+ """ + + def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]": + names = _get_names(cursor) + if names is None: + return no_result + + def kwargs_row__(values: Sequence[Any]) -> T: + return func(**dict(zip(names, values))) # type: ignore[arg-type] + + return kwargs_row__ + + return kwargs_row_ + + +def no_result(values: Sequence[Any]) -> NoReturn: + """A `RowMaker` that always fail. + + It can be used as return value for a `RowFactory` called with no result. + Note that the `!RowFactory` *will* be called with no result, but the + resulting `!RowMaker` never should. + """ + raise e.InterfaceError("the cursor doesn't have a result") + + +def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]: + res = cursor.pgresult + if not res: + return None + + nfields = _get_nfields(res) + if nfields is None: + return None + + enc = cursor._encoding + return [ + res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr] + ] + + +def _get_nfields(res: "PGresult") -> Optional[int]: + """ + Return the number of columns in a result, if it returns tuples else None + + Take into account the special case of results with zero columns. + """ + nfields = res.nfields + + if ( + res.status == TUPLES_OK + or res.status == SINGLE_TUPLE + # "describe" in named cursors + or (res.status == COMMAND_OK and nfields) + ): + return nfields + else: + return None diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py new file mode 100644 index 0000000..b890d77 --- /dev/null +++ b/psycopg/psycopg/server_cursor.py @@ -0,0 +1,479 @@ +""" +psycopg server-side cursor objects. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, AsyncIterator, List, Iterable, Iterator +from typing import Optional, TypeVar, TYPE_CHECKING, overload +from warnings import warn + +from . import pq +from . import sql +from . import errors as e +from .abc import ConnectionType, Query, Params, PQGen +from .rows import Row, RowFactory, AsyncRowFactory +from .cursor import BaseCursor, Cursor +from .generators import execute +from .cursor_async import AsyncCursor + +if TYPE_CHECKING: + from .connection import Connection + from .connection_async import AsyncConnection + +DEFAULT_ITERSIZE = 100 + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + + +class ServerCursorMixin(BaseCursor[ConnectionType, Row]): + """Mixin to add ServerCursor behaviour and implementation a BaseCursor.""" + + __slots__ = "_name _scrollable _withhold _described itersize _format".split() + + def __init__( + self, + name: str, + scrollable: Optional[bool], + withhold: bool, + ): + self._name = name + self._scrollable = scrollable + self._withhold = withhold + self._described = False + self.itersize: int = DEFAULT_ITERSIZE + self._format = TEXT + + def __repr__(self) -> str: + # Insert the name as the second word + parts = super().__repr__().split(None, 1) + parts.insert(1, f"{self._name!r}") + return " ".join(parts) + + @property + def name(self) -> str: + """The name of the cursor.""" + return self._name + + @property + def scrollable(self) -> Optional[bool]: + """ + Whether the cursor is scrollable or not. + + If `!None` leave the choice to the server. Use `!True` if you want to + use `scroll()` on the cursor. 
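        A minimal sketch (illustrative only), assuming an open connection
        ``conn``::

            with conn.cursor(name="curs", scrollable=True) as cur:
                cur.execute("SELECT generate_series(1, 10)")
                cur.fetchall()                    # rows 1..10
                cur.scroll(0, mode="absolute")    # rewind before the first row
                cur.fetchone()                    # (1,)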
+ """ + return self._scrollable + + @property + def withhold(self) -> bool: + """ + If the cursor can be used after the creating transaction has committed. + """ + return self._withhold + + @property + def rownumber(self) -> Optional[int]: + """Index of the next row to fetch in the current result. + + `!None` if there is no result to fetch. + """ + res = self.pgresult + # command_status is empty if the result comes from + # describe_portal, which means that we have just executed the DECLARE, + # so we can assume we are at the first row. + tuples = res and (res.status == TUPLES_OK or res.command_status == b"") + return self._pos if tuples else None + + def _declare_gen( + self, + query: Query, + params: Optional[Params] = None, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator implementing `ServerCursor.execute()`.""" + + query = self._make_declare_statement(query) + + # If the cursor is being reused, the previous one must be closed. + if self._described: + yield from self._close_gen() + self._described = False + + yield from self._start_query(query) + pgq = self._convert_query(query, params) + self._execute_send(pgq, force_extended=True) + results = yield from execute(self._conn.pgconn) + if results[-1].status != COMMAND_OK: + self._raise_for_result(results[-1]) + + # Set the format, which will be used by describe and fetch operations + if binary is None: + self._format = self.format + else: + self._format = BINARY if binary else TEXT + + # The above result only returned COMMAND_OK. Get the cursor shape + yield from self._describe_gen() + + def _describe_gen(self) -> PQGen[None]: + self._pgconn.send_describe_portal(self._name.encode(self._encoding)) + results = yield from execute(self._pgconn) + self._check_results(results) + self._results = results + self._select_current_result(0, format=self._format) + self._described = True + + def _close_gen(self) -> PQGen[None]: + ts = self._conn.pgconn.transaction_status + + # if the connection is not in a sane state, don't even try + if ts != IDLE and ts != INTRANS: + return + + # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already. + if not self._withhold and ts == IDLE: + return + + # if we didn't declare the cursor ourselves we still have to close it + # but we must make sure it exists. + if not self._described: + query = sql.SQL( + "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}" + ).format(sql.Literal(self._name)) + res = yield from self._conn._exec_command(query) + # pipeline mode otherwise, unsupported here. + assert res is not None + if res.ntuples == 0: + return + + query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name)) + yield from self._conn._exec_command(query) + + def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]: + if self.closed: + raise e.InterfaceError("the cursor is closed") + # If we are stealing the cursor, make sure we know its shape + if not self._described: + yield from self._start_query() + yield from self._describe_gen() + + query = sql.SQL("FETCH FORWARD {} FROM {}").format( + sql.SQL("ALL") if num is None else sql.Literal(num), + sql.Identifier(self._name), + ) + res = yield from self._conn._exec_command(query, result_format=self._format) + # pipeline mode otherwise, unsupported here. 
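+        # (_exec_command can return None only in pipeline mode, which is not
+        # supported for server-side cursors, hence the assert below)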
+ assert res is not None + + self.pgresult = res + self._tx.set_pgresult(res, set_loaders=False) + return self._tx.load_rows(0, res.ntuples, self._make_row) + + def _scroll_gen(self, value: int, mode: str) -> PQGen[None]: + if mode not in ("relative", "absolute"): + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + query = sql.SQL("MOVE{} {} FROM {}").format( + sql.SQL(" ABSOLUTE" if mode == "absolute" else ""), + sql.Literal(value), + sql.Identifier(self._name), + ) + yield from self._conn._exec_command(query) + + def _make_declare_statement(self, query: Query) -> sql.Composed: + + if isinstance(query, bytes): + query = query.decode(self._encoding) + if not isinstance(query, sql.Composable): + query = sql.SQL(query) + + parts = [ + sql.SQL("DECLARE"), + sql.Identifier(self._name), + ] + if self._scrollable is not None: + parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL")) + parts.append(sql.SQL("CURSOR")) + if self._withhold: + parts.append(sql.SQL("WITH HOLD")) + parts.append(sql.SQL("FOR")) + parts.append(query) + + return sql.SQL(" ").join(parts) + + +class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="ServerCursor[Any]") + + @overload + def __init__( + self: "ServerCursor[Row]", + connection: "Connection[Row]", + name: str, + *, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + @overload + def __init__( + self: "ServerCursor[Row]", + connection: "Connection[Any]", + name: str, + *, + row_factory: RowFactory[Row], + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + def __init__( + self, + connection: "Connection[Any]", + name: str, + *, + row_factory: Optional[RowFactory[Row]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + Cursor.__init__( + self, connection, row_factory=row_factory or connection.row_factory + ) + ServerCursorMixin.__init__(self, name, scrollable, withhold) + + def __del__(self) -> None: + if not self.closed: + warn( + f"the server-side cursor {self} was deleted while still open." + " Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + with self._conn.lock: + if self.closed: + return + if not self._conn.closed: + self._conn.wait(self._close_gen()) + super().close() + + def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + **kwargs: Any, + ) -> _Self: + """ + Open a cursor to execute a query to the database. 
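+
+        A minimal sketch (assuming an open connection `!conn`; the query is
+        illustrative)::
+
+            >>> with conn.cursor("big") as cur:
+            ...     cur.execute("SELECT generate_series(1, %s)", [1000])
+            ...     rows = cur.fetchmany(10)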
+ """ + if kwargs: + raise TypeError(f"keyword not supported: {list(kwargs)[0]}") + if self._pgconn.pipeline_status: + raise e.NotSupportedError( + "server-side cursors not supported in pipeline mode" + ) + + try: + with self._conn.lock: + self._conn.wait(self._declare_gen(query, params, binary)) + except e.Error as ex: + raise ex.with_traceback(None) + + return self + + def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = True, + ) -> None: + """Method not implemented for server-side cursors.""" + raise e.NotSupportedError("executemany not supported on server-side cursors") + + def fetchone(self) -> Optional[Row]: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(1)) + if recs: + self._pos += 1 + return recs[0] + else: + return None + + def fetchmany(self, size: int = 0) -> List[Row]: + if not size: + size = self.arraysize + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(size)) + self._pos += len(recs) + return recs + + def fetchall(self) -> List[Row]: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(None)) + self._pos += len(recs) + return recs + + def __iter__(self) -> Iterator[Row]: + while True: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(self.itersize)) + for rec in recs: + self._pos += 1 + yield rec + if len(recs) < self.itersize: + break + + def scroll(self, value: int, mode: str = "relative") -> None: + with self._conn.lock: + self._conn.wait(self._scroll_gen(value, mode)) + # Postgres doesn't have a reliable way to report a cursor out of bound + if mode == "relative": + self._pos += value + else: + self._pos = value + + +class AsyncServerCursor( + ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row] +): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="AsyncServerCursor[Any]") + + @overload + def __init__( + self: "AsyncServerCursor[Row]", + connection: "AsyncConnection[Row]", + name: str, + *, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + @overload + def __init__( + self: "AsyncServerCursor[Row]", + connection: "AsyncConnection[Any]", + name: str, + *, + row_factory: AsyncRowFactory[Row], + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + def __init__( + self, + connection: "AsyncConnection[Any]", + name: str, + *, + row_factory: Optional[AsyncRowFactory[Row]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + AsyncCursor.__init__( + self, connection, row_factory=row_factory or connection.row_factory + ) + ServerCursorMixin.__init__(self, name, scrollable, withhold) + + def __del__(self) -> None: + if not self.closed: + warn( + f"the server-side cursor {self} was deleted while still open." 
+ " Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + async def close(self) -> None: + async with self._conn.lock: + if self.closed: + return + if not self._conn.closed: + await self._conn.wait(self._close_gen()) + await super().close() + + async def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + **kwargs: Any, + ) -> _Self: + if kwargs: + raise TypeError(f"keyword not supported: {list(kwargs)[0]}") + if self._pgconn.pipeline_status: + raise e.NotSupportedError( + "server-side cursors not supported in pipeline mode" + ) + + try: + async with self._conn.lock: + await self._conn.wait(self._declare_gen(query, params, binary)) + except e.Error as ex: + raise ex.with_traceback(None) + + return self + + async def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = True, + ) -> None: + raise e.NotSupportedError("executemany not supported on server-side cursors") + + async def fetchone(self) -> Optional[Row]: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(1)) + if recs: + self._pos += 1 + return recs[0] + else: + return None + + async def fetchmany(self, size: int = 0) -> List[Row]: + if not size: + size = self.arraysize + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(size)) + self._pos += len(recs) + return recs + + async def fetchall(self) -> List[Row]: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(None)) + self._pos += len(recs) + return recs + + async def __aiter__(self) -> AsyncIterator[Row]: + while True: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(self.itersize)) + for rec in recs: + self._pos += 1 + yield rec + if len(recs) < self.itersize: + break + + async def scroll(self, value: int, mode: str = "relative") -> None: + async with self._conn.lock: + await self._conn.wait(self._scroll_gen(value, mode)) diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py new file mode 100644 index 0000000..099a01c --- /dev/null +++ b/psycopg/psycopg/sql.py @@ -0,0 +1,467 @@ +""" +SQL composition utility module +""" + +# Copyright (C) 2020 The Psycopg Team + +import codecs +import string +from abc import ABC, abstractmethod +from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union + +from .pq import Escaping +from .abc import AdaptContext +from .adapt import Transformer, PyFormat +from ._compat import LiteralString +from ._encodings import conn_encoding + + +def quote(obj: Any, context: Optional[AdaptContext] = None) -> str: + """ + Adapt a Python object to a quoted SQL string. + + Use this function only if you absolutely want to convert a Python string to + an SQL quoted literal to use e.g. to generate batch SQL and you won't have + a connection available when you will need to use it. + + This function is relatively inefficient, because it doesn't cache the + adaptation rules. If you pass a `!context` you can adapt the adaptation + rules used, otherwise only global rules are used. + + """ + return Literal(obj).as_string(context) + + +class Composable(ABC): + """ + Abstract base class for objects that can be used to compose an SQL string. + + `!Composable` objects can be passed directly to + `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`, + `~psycopg.Cursor.copy()` in place of the query string. 
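+
+    A minimal sketch (assuming an open connection `!conn` and a table
+    ``mytable``; both are illustrative)::
+
+        >>> cur = conn.execute(
+        ...     sql.SQL("SELECT * FROM {}").format(sql.Identifier("mytable")))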
+ + `!Composable` objects can be joined using the ``+`` operator: the result + will be a `Composed` instance containing the objects joined. The operator + ``*`` is also supported with an integer argument: the result is a + `!Composed` instance containing the left argument repeated as many times as + requested. + """ + + def __init__(self, obj: Any): + self._obj = obj + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._obj!r})" + + @abstractmethod + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + """ + Return the value of the object as bytes. + + :param context: the context to evaluate the object into. + :type context: `connection` or `cursor` + + The method is automatically invoked by `~psycopg.Cursor.execute()`, + `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a + `!Composable` is passed instead of the query string. + + """ + raise NotImplementedError + + def as_string(self, context: Optional[AdaptContext]) -> str: + """ + Return the value of the object as string. + + :param context: the context to evaluate the string into. + :type context: `connection` or `cursor` + + """ + conn = context.connection if context else None + enc = conn_encoding(conn) + b = self.as_bytes(context) + if isinstance(b, bytes): + return b.decode(enc) + else: + # buffer object + return codecs.lookup(enc).decode(b)[0] + + def __add__(self, other: "Composable") -> "Composed": + if isinstance(other, Composed): + return Composed([self]) + other + if isinstance(other, Composable): + return Composed([self]) + Composed([other]) + else: + return NotImplemented + + def __mul__(self, n: int) -> "Composed": + return Composed([self] * n) + + def __eq__(self, other: Any) -> bool: + return type(self) is type(other) and self._obj == other._obj + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +class Composed(Composable): + """ + A `Composable` object made of a sequence of `!Composable`. + + The object is usually created using `!Composable` operators and methods. + However it is possible to create a `!Composed` directly specifying a + sequence of objects as arguments: if they are not `!Composable` they will + be wrapped in a `Literal`. + + Example:: + + >>> comp = sql.Composed( + ... [sql.SQL("INSERT INTO "), sql.Identifier("table")]) + >>> print(comp.as_string(conn)) + INSERT INTO "table" + + `!Composed` objects are iterable (so they can be used in `SQL.join` for + instance). + """ + + _obj: List[Composable] + + def __init__(self, seq: Sequence[Any]): + seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq] + super().__init__(seq) + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + return b"".join(obj.as_bytes(context) for obj in self._obj) + + def __iter__(self) -> Iterator[Composable]: + return iter(self._obj) + + def __add__(self, other: Composable) -> "Composed": + if isinstance(other, Composed): + return Composed(self._obj + other._obj) + if isinstance(other, Composable): + return Composed(self._obj + [other]) + else: + return NotImplemented + + def join(self, joiner: Union["SQL", LiteralString]) -> "Composed": + """ + Return a new `!Composed` interposing the `!joiner` with the `!Composed` items. + + The `!joiner` must be a `SQL` or a string which will be interpreted as + an `SQL`. 
+ + Example:: + + >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed + >>> print(fields.join(', ').as_string(conn)) + "foo", "bar" + + """ + if isinstance(joiner, str): + joiner = SQL(joiner) + elif not isinstance(joiner, SQL): + raise TypeError( + "Composed.join() argument must be strings or SQL," + f" got {joiner!r} instead" + ) + + return joiner.join(self._obj) + + +class SQL(Composable): + """ + A `Composable` representing a snippet of SQL statement. + + `!SQL` exposes `join()` and `format()` methods useful to create a template + where to merge variable parts of a query (for instance field or table + names). + + The `!obj` string doesn't undergo any form of escaping, so it is not + suitable to represent variable identifiers or values: you should only use + it to pass constant strings representing templates or snippets of SQL + statements; use other objects such as `Identifier` or `Literal` to + represent variable parts. + + Example:: + + >>> query = sql.SQL("SELECT {0} FROM {1}").format( + ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]), + ... sql.Identifier('table')) + >>> print(query.as_string(conn)) + SELECT "foo", "bar" FROM "table" + """ + + _obj: LiteralString + _formatter = string.Formatter() + + def __init__(self, obj: LiteralString): + super().__init__(obj) + if not isinstance(obj, str): + raise TypeError(f"SQL values must be strings, got {obj!r} instead") + + def as_string(self, context: Optional[AdaptContext]) -> str: + return self._obj + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + enc = "utf-8" + if context: + enc = conn_encoding(context.connection) + return self._obj.encode(enc) + + def format(self, *args: Any, **kwargs: Any) -> Composed: + """ + Merge `Composable` objects into a template. + + :param args: parameters to replace to numbered (``{0}``, ``{1}``) or + auto-numbered (``{}``) placeholders + :param kwargs: parameters to replace to named (``{name}``) placeholders + :return: the union of the `!SQL` string with placeholders replaced + :rtype: `Composed` + + The method is similar to the Python `str.format()` method: the string + template supports auto-numbered (``{}``), numbered (``{0}``, + ``{1}``...), and named placeholders (``{name}``), with positional + arguments replacing the numbered placeholders and keywords replacing + the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``) + are not supported. + + If a `!Composable` objects is passed to the template it will be merged + according to its `as_string()` method. If any other Python object is + passed, it will be wrapped in a `Literal` object and so escaped + according to SQL rules. + + Example:: + + >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s") + ... .format(sql.Identifier('people'), sql.Identifier('id')) + ... .as_string(conn)) + SELECT * FROM "people" WHERE "id" = %s + + >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}") + ... .format(tbl=sql.Identifier('people'), name="O'Rourke")) + ... .as_string(conn)) + SELECT * FROM "people" WHERE name = 'O''Rourke' + + """ + rv: List[Composable] = [] + autonum: Optional[int] = 0 + # TODO: this is probably not the right way to whitelist pre + # pyre complains. Will wait for mypy to complain too to fix. 
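+        # (the annotation lets SQL(pre) below type-check, since SQL() is
+        # declared to take a LiteralString argument)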
+ pre: LiteralString + for pre, name, spec, conv in self._formatter.parse(self._obj): + if spec: + raise ValueError("no format specification supported by SQL") + if conv: + raise ValueError("no format conversion supported by SQL") + if pre: + rv.append(SQL(pre)) + + if name is None: + continue + + if name.isdigit(): + if autonum: + raise ValueError( + "cannot switch from automatic field numbering to manual" + ) + rv.append(args[int(name)]) + autonum = None + + elif not name: + if autonum is None: + raise ValueError( + "cannot switch from manual field numbering to automatic" + ) + rv.append(args[autonum]) + autonum += 1 + + else: + rv.append(kwargs[name]) + + return Composed(rv) + + def join(self, seq: Iterable[Composable]) -> Composed: + """ + Join a sequence of `Composable`. + + :param seq: the elements to join. + :type seq: iterable of `!Composable` + + Use the `!SQL` object's string to separate the elements in `!seq`. + Note that `Composed` objects are iterable too, so they can be used as + argument for this method. + + Example:: + + >>> snip = sql.SQL(', ').join( + ... sql.Identifier(n) for n in ['foo', 'bar', 'baz']) + >>> print(snip.as_string(conn)) + "foo", "bar", "baz" + """ + rv = [] + it = iter(seq) + try: + rv.append(next(it)) + except StopIteration: + pass + else: + for i in it: + rv.append(self) + rv.append(i) + + return Composed(rv) + + +class Identifier(Composable): + """ + A `Composable` representing an SQL identifier or a dot-separated sequence. + + Identifiers usually represent names of database objects, such as tables or + fields. PostgreSQL identifiers follow `different rules`__ than SQL string + literals for escaping (e.g. they use double quotes instead of single). + + .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \ + SQL-SYNTAX-IDENTIFIERS + + Example:: + + >>> t1 = sql.Identifier("foo") + >>> t2 = sql.Identifier("ba'r") + >>> t3 = sql.Identifier('ba"z') + >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn)) + "foo", "ba'r", "ba""z" + + Multiple strings can be passed to the object to represent a qualified name, + i.e. a dot-separated sequence of identifiers. + + Example:: + + >>> query = sql.SQL("SELECT {} FROM {}").format( + ... sql.Identifier("table", "field"), + ... sql.Identifier("schema", "table")) + >>> print(query.as_string(conn)) + SELECT "table"."field" FROM "schema"."table" + + """ + + _obj: Sequence[str] + + def __init__(self, *strings: str): + # init super() now to make the __repr__ not explode in case of error + super().__init__(strings) + + if not strings: + raise TypeError("Identifier cannot be empty") + + for s in strings: + if not isinstance(s, str): + raise TypeError( + f"SQL identifier parts must be strings, got {s!r} instead" + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})" + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + conn = context.connection if context else None + if not conn: + raise ValueError("a connection is necessary for Identifier") + esc = Escaping(conn.pgconn) + enc = conn_encoding(conn) + escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] + return b".".join(escs) + + +class Literal(Composable): + """ + A `Composable` representing an SQL value to include in a query. + + Usually you will want to include placeholders in the query and pass values + as `~cursor.execute()` arguments. If however you really really need to + include a literal value in the query you can use this object. 
+ + The string returned by `!as_string()` follows the normal :ref:`adaptation + rules <types-adaptation>` for Python objects. + + Example:: + + >>> s1 = sql.Literal("fo'o") + >>> s2 = sql.Literal(42) + >>> s3 = sql.Literal(date(2000, 1, 1)) + >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn)) + 'fo''o', 42, '2000-01-01'::date + + """ + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + tx = Transformer.from_context(context) + return tx.as_literal(self._obj) + + +class Placeholder(Composable): + """A `Composable` representing a placeholder for query parameters. + + If the name is specified, generate a named placeholder (e.g. ``%(name)s``, + ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``, + ``%b``). + + The object is useful to generate SQL queries with a variable number of + arguments. + + Examples:: + + >>> names = ['foo', 'bar', 'baz'] + + >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(sql.Placeholder() * len(names))) + >>> print(q1.as_string(conn)) + INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s) + + >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(map(sql.Placeholder, names))) + >>> print(q2.as_string(conn)) + INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s) + + """ + + def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO): + super().__init__(name) + if not isinstance(name, str): + raise TypeError(f"expected string as name, got {name!r}") + + if ")" in name: + raise ValueError(f"invalid name: {name!r}") + + if type(format) is str: + format = PyFormat(format) + if not isinstance(format, PyFormat): + raise TypeError( + f"expected PyFormat as format, got {type(format).__name__!r}" + ) + + self._format: PyFormat = format + + def __repr__(self) -> str: + parts = [] + if self._obj: + parts.append(repr(self._obj)) + if self._format is not PyFormat.AUTO: + parts.append(f"format={self._format.name}") + + return f"{self.__class__.__name__}({', '.join(parts)})" + + def as_string(self, context: Optional[AdaptContext]) -> str: + code = self._format.value + return f"%({self._obj}){code}" if self._obj else f"%{code}" + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + conn = context.connection if context else None + enc = conn_encoding(conn) + return self.as_string(context).encode(enc) + + +# Literals +NULL = SQL("NULL") +DEFAULT = SQL("DEFAULT") diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py new file mode 100644 index 0000000..e13486e --- /dev/null +++ b/psycopg/psycopg/transaction.py @@ -0,0 +1,290 @@ +""" +Transaction context managers returned by Connection.transaction() +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging + +from types import TracebackType +from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING + +from . import pq +from . import sql +from . import errors as e +from .abc import ConnectionType, PQGen + +if TYPE_CHECKING: + from typing import Any + from .connection import Connection + from .connection_async import AsyncConnection + +IDLE = pq.TransactionStatus.IDLE + +OK = pq.ConnStatus.OK + +logger = logging.getLogger(__name__) + + +class Rollback(Exception): + """ + Exit the current `Transaction` context immediately and rollback any changes + made within this context. 
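+
+    For example (a minimal sketch, assuming an open connection `!conn` and a
+    table ``t``; both are illustrative)::
+
+        >>> with conn.transaction():
+        ...     conn.execute("INSERT INTO t VALUES (1)")
+        ...     raise Rollback()  # exit the block; the INSERT is discarded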
+ + If a transaction context is specified in the constructor, rollback + enclosing transactions contexts up to and including the one specified. + """ + + __module__ = "psycopg" + + def __init__( + self, + transaction: Union["Transaction", "AsyncTransaction", None] = None, + ): + self.transaction = transaction + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}({self.transaction!r})" + + +class OutOfOrderTransactionNesting(e.ProgrammingError): + """Out-of-order transaction nesting detected""" + + +class BaseTransaction(Generic[ConnectionType]): + def __init__( + self, + connection: ConnectionType, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ): + self._conn = connection + self.pgconn = self._conn.pgconn + self._savepoint_name = savepoint_name or "" + self.force_rollback = force_rollback + self._entered = self._exited = False + self._outer_transaction = False + self._stack_index = -1 + + @property + def savepoint_name(self) -> Optional[str]: + """ + The name of the savepoint; `!None` if handling the main transaction. + """ + # Yes, it may change on __enter__. No, I don't care, because the + # un-entered state is outside the public interface. + return self._savepoint_name + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self.pgconn) + if not self._entered: + status = "inactive" + elif not self._exited: + status = "active" + else: + status = "terminated" + + sp = f"{self.savepoint_name!r} " if self.savepoint_name else "" + return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>" + + def _enter_gen(self) -> PQGen[None]: + if self._entered: + raise TypeError("transaction blocks can be used only once") + self._entered = True + + self._push_savepoint() + for command in self._get_enter_commands(): + yield from self._conn._exec_command(command) + + def _exit_gen( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> PQGen[bool]: + if not exc_val and not self.force_rollback: + yield from self._commit_gen() + return False + else: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + return (yield from self._rollback_gen(exc_val)) + except OutOfOrderTransactionNesting: + # Clobber an exception happened in the block with the exception + # caused by out-of-order transaction detected, so make the + # behaviour consistent with _commit_gen and to make sure the + # user fixes this condition, which is unrelated from + # operational error that might arise in the block. 
+ raise + except Exception as exc2: + logger.warning("error ignored in rollback of %s: %s", self, exc2) + return False + + def _commit_gen(self) -> PQGen[None]: + ex = self._pop_savepoint("commit") + self._exited = True + if ex: + raise ex + + for command in self._get_commit_commands(): + yield from self._conn._exec_command(command) + + def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]: + if isinstance(exc_val, Rollback): + logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True) + + ex = self._pop_savepoint("rollback") + self._exited = True + if ex: + raise ex + + for command in self._get_rollback_commands(): + yield from self._conn._exec_command(command) + + if isinstance(exc_val, Rollback): + if not exc_val.transaction or exc_val.transaction is self: + return True # Swallow the exception + + return False + + def _get_enter_commands(self) -> Iterator[bytes]: + if self._outer_transaction: + yield self._conn._get_tx_start_command() + + if self._savepoint_name: + yield ( + sql.SQL("SAVEPOINT {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + def _get_commit_commands(self) -> Iterator[bytes]: + if self._savepoint_name and not self._outer_transaction: + yield ( + sql.SQL("RELEASE {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + if self._outer_transaction: + assert not self._conn._num_transactions + yield b"COMMIT" + + def _get_rollback_commands(self) -> Iterator[bytes]: + if self._savepoint_name and not self._outer_transaction: + yield ( + sql.SQL("ROLLBACK TO {n}") + .format(n=sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + yield ( + sql.SQL("RELEASE {n}") + .format(n=sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + if self._outer_transaction: + assert not self._conn._num_transactions + yield b"ROLLBACK" + + # Also clear the prepared statements cache. + if self._conn._prepared.clear(): + yield from self._conn._prepared.get_maintenance_commands() + + def _push_savepoint(self) -> None: + """ + Push the transaction on the connection transactions stack. + + Also set the internal state of the object and verify consistency. + """ + self._outer_transaction = self.pgconn.transaction_status == IDLE + if self._outer_transaction: + # outer transaction: if no name it's only a begin, else + # there will be an additional savepoint + assert not self._conn._num_transactions + else: + # inner transaction: it always has a name + if not self._savepoint_name: + self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}" + + self._stack_index = self._conn._num_transactions + self._conn._num_transactions += 1 + + def _pop_savepoint(self, action: str) -> Optional[Exception]: + """ + Pop the transaction from the connection transactions stack. + + Also verify the state consistency. + """ + self._conn._num_transactions -= 1 + if self._conn._num_transactions == self._stack_index: + return None + + return OutOfOrderTransactionNesting( + f"transaction {action} at the wrong nesting level: {self}" + ) + + +class Transaction(BaseTransaction["Connection[Any]"]): + """ + Returned by `Connection.transaction()` to handle a transaction block. 
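+
+    A minimal sketch of nested blocks (assuming an open connection `!conn`
+    and a table ``t``; the inner block is implemented with a savepoint)::
+
+        >>> with conn.transaction():
+        ...     conn.execute("INSERT INTO t VALUES (1)")
+        ...     with conn.transaction():
+        ...         conn.execute("INSERT INTO t VALUES (2)")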
+ """ + + __module__ = "psycopg" + + _Self = TypeVar("_Self", bound="Transaction") + + @property + def connection(self) -> "Connection[Any]": + """The connection the object is managing.""" + return self._conn + + def __enter__(self: _Self) -> _Self: + with self._conn.lock: + self._conn.wait(self._enter_gen()) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + if self.pgconn.status == OK: + with self._conn.lock: + return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) + else: + return False + + +class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]): + """ + Returned by `AsyncConnection.transaction()` to handle a transaction block. + """ + + __module__ = "psycopg" + + _Self = TypeVar("_Self", bound="AsyncTransaction") + + @property + def connection(self) -> "AsyncConnection[Any]": + return self._conn + + async def __aenter__(self: _Self) -> _Self: + async with self._conn.lock: + await self._conn.wait(self._enter_gen()) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + if self.pgconn.status == OK: + async with self._conn.lock: + return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) + else: + return False diff --git a/psycopg/psycopg/types/__init__.py b/psycopg/psycopg/types/__init__.py new file mode 100644 index 0000000..bdddf05 --- /dev/null +++ b/psycopg/psycopg/types/__init__.py @@ -0,0 +1,11 @@ +""" +psycopg types package +""" + +# Copyright (C) 2020 The Psycopg Team + +from .. import _typeinfo + +# Exposed here +TypeInfo = _typeinfo.TypeInfo +TypesRegistry = _typeinfo.TypesRegistry diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py new file mode 100644 index 0000000..e35c5e7 --- /dev/null +++ b/psycopg/psycopg/types/array.py @@ -0,0 +1,464 @@ +""" +Adapters for arrays +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type + +from .. import pq +from .. import errors as e +from .. 
import postgres +from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._compat import cache, prod +from .._struct import pack_len, unpack_len +from .._cmodule import _psycopg +from ..postgres import TEXT_OID, INVALID_OID +from .._typeinfo import TypeInfo + +_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid +_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack) +_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from) +_struct_dim = struct.Struct("!II") # dim, lower bound +_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack) +_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from) + +TEXT_ARRAY_OID = postgres.types["text"].array_oid + +PY_TEXT = PyFormat.TEXT +PQ_BINARY = pq.Format.BINARY + + +class BaseListDumper(RecursiveDumper): + element_oid = 0 + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + if cls is NoneType: + cls = list + + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + if self.element_oid and context: + sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format) + self.sub_dumper = sdclass(NoneType, context) + + def _find_list_element(self, L: List[Any], format: PyFormat) -> Any: + """ + Find the first non-null element of an eventually nested list + """ + items = list(self._flatiter(L, set())) + types = {type(item): item for item in items} + if not types: + return None + + if len(types) == 1: + t, v = types.popitem() + else: + # More than one type in the list. It might be still good, as long + # as they dump with the same oid (e.g. IPv4Network, IPv6Network). + dumpers = [self._tx.get_dumper(item, format) for item in types.values()] + oids = set(d.oid for d in dumpers) + if len(oids) == 1: + t, v = types.popitem() + else: + raise e.DataError( + "cannot dump lists of mixed types;" + f" got: {', '.join(sorted(t.__name__ for t in types))}" + ) + + # Checking for precise type. If the type is a subclass (e.g. Int4) + # we assume the user knows what type they are passing. + if t is not int: + return v + + # If we got an int, let's see what is the biggest one in order to + # choose the smallest OID and allow Postgres to do the right cast. + imax: int = max(items) + imin: int = min(items) + if imin >= 0: + return imax + else: + return max(imax, -imin - 1) + + def _flatiter(self, L: List[Any], seen: Set[int]) -> Any: + if id(L) in seen: + raise e.DataError("cannot dump a recursive list") + + seen.add(id(L)) + + for item in L: + if type(item) is list: + yield from self._flatiter(item, seen) + elif item is not None: + yield item + + return None + + def _get_base_type_info(self, base_oid: int) -> TypeInfo: + """ + Return info about the base type. + + Return text info as fallback. 
+ """ + if base_oid: + info = self._tx.adapters.types.get(base_oid) + if info: + return info + + return self._tx.adapters.types["text"] + + +class ListDumper(BaseListDumper): + + delimiter = b"," + + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + if self.oid: + return self.cls + + item = self._find_list_element(obj, format) + if item is None: + return self.cls + + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.get_key(item, format)) + + def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": + # If we have an oid we don't need to upgrade + if self.oid: + return self + + item = self._find_list_element(obj, format) + if item is None: + # Empty lists can only be dumped as text if the type is unknown. + return self + + sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format)) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + + # We consider an array of unknowns as unknown, so we can dump empty + # lists or lists containing only None elements. + if sd.oid != INVALID_OID: + info = self._get_base_type_info(sd.oid) + dumper.oid = info.array_oid or TEXT_ARRAY_OID + dumper.delimiter = info.delimiter.encode() + else: + dumper.oid = INVALID_OID + + return dumper + + # Double quotes and backslashes embedded in element values will be + # backslash-escaped. + _re_esc = re.compile(rb'(["\\])') + + def dump(self, obj: List[Any]) -> bytes: + tokens: List[Buffer] = [] + needs_quotes = _get_needs_quotes_regexp(self.delimiter).search + + def dump_list(obj: List[Any]) -> None: + if not obj: + tokens.append(b"{}") + return + + tokens.append(b"{") + for item in obj: + if isinstance(item, list): + dump_list(item) + elif item is not None: + ad = self._dump_item(item) + if needs_quotes(ad): + if not isinstance(ad, bytes): + ad = bytes(ad) + ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"' + tokens.append(ad) + else: + tokens.append(b"NULL") + + tokens.append(self.delimiter) + + tokens[-1] = b"}" + + dump_list(obj) + + return b"".join(tokens) + + def _dump_item(self, item: Any) -> Buffer: + if self.sub_dumper: + return self.sub_dumper.dump(item) + else: + return self._tx.get_dumper(item, PY_TEXT).dump(item) + + +@cache +def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]: + """Return a regexp to recognise when a value needs quotes + + from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO + + The array output routine will put double quotes around element values if + they are empty strings, contain curly braces, delimiter characters, + double quotes, backslashes, or white space, or match the word NULL. 
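+
+    For example, ``foo`` can be dumped bare, while an empty string, ``fo,o``
+    (with the default ``,`` delimiter), ``fo"o`` or ``NULL`` must be quoted.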
+ """ + return re.compile( + rb"""(?xi) + ^$ # the empty string + | ["{}%s\\\s] # or a char to escape + | ^null$ # or the word NULL + """ + % delimiter + ) + + +class ListBinaryDumper(BaseListDumper): + + format = pq.Format.BINARY + + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + if self.oid: + return self.cls + + item = self._find_list_element(obj, format) + if item is None: + return (self.cls,) + + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.get_key(item, format)) + + def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": + # If we have an oid we don't need to upgrade + if self.oid: + return self + + item = self._find_list_element(obj, format) + if item is None: + return ListDumper(self.cls, self._tx) + + sd = self._tx.get_dumper(item, format.from_pq(self.format)) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + info = self._get_base_type_info(sd.oid) + dumper.oid = info.array_oid or TEXT_ARRAY_OID + + return dumper + + def dump(self, obj: List[Any]) -> bytes: + # Postgres won't take unknown for element oid: fall back on text + sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID + + if not obj: + return _pack_head(0, 0, sub_oid) + + data: List[Buffer] = [b"", b""] # placeholders to avoid a resize + dims: List[int] = [] + hasnull = 0 + + def calc_dims(L: List[Any]) -> None: + if isinstance(L, self.cls): + if not L: + raise e.DataError("lists cannot contain empty lists") + dims.append(len(L)) + calc_dims(L[0]) + + calc_dims(obj) + + def dump_list(L: List[Any], dim: int) -> None: + nonlocal hasnull + if len(L) != dims[dim]: + raise e.DataError("nested lists have inconsistent lengths") + + if dim == len(dims) - 1: + for item in L: + if item is not None: + # If we get here, the sub_dumper must have been set + ad = self.sub_dumper.dump(item) # type: ignore[union-attr] + data.append(pack_len(len(ad))) + data.append(ad) + else: + hasnull = 1 + data.append(b"\xff\xff\xff\xff") + else: + for item in L: + if not isinstance(item, self.cls): + raise e.DataError("nested lists have inconsistent depths") + dump_list(item, dim + 1) # type: ignore + + dump_list(obj, 0) + + data[0] = _pack_head(len(dims), hasnull, sub_oid) + data[1] = b"".join(_pack_dim(dim, 1) for dim in dims) + return b"".join(data) + + +class ArrayLoader(RecursiveLoader): + + delimiter = b"," + base_oid: int + + def load(self, data: Buffer) -> List[Any]: + loader = self._tx.get_loader(self.base_oid, self.format) + return _load_text(data, loader, self.delimiter) + + +class ArrayBinaryLoader(RecursiveLoader): + + format = pq.Format.BINARY + + def load(self, data: Buffer) -> List[Any]: + return _load_binary(data, self._tx) + + +def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + if not info.array_oid: + raise ValueError(f"the type info {info} doesn't describe an array") + + base: Type[Any] + adapters = context.adapters if context else postgres.adapters + + base = getattr(_psycopg, "ArrayLoader", ArrayLoader) + name = f"{info.name.title()}{base.__name__}" + attribs = { + "base_oid": info.oid, + "delimiter": info.delimiter.encode(), + } + loader = type(name, (base,), attribs) + adapters.register_loader(info.array_oid, loader) + + loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader) + adapters.register_loader(info.array_oid, loader) + + base = ListDumper + name = f"{info.name.title()}{base.__name__}" + attribs = { + "oid": info.array_oid, + "element_oid": info.oid, + "delimiter": info.delimiter.encode(), + } 
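+    # Build a text list dumper configured for this array type; it is
+    # registered with None rather than with a specific Python class.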
+ dumper = type(name, (base,), attribs) + adapters.register_dumper(None, dumper) + + base = ListBinaryDumper + name = f"{info.name.title()}{base.__name__}" + attribs = { + "oid": info.array_oid, + "element_oid": info.oid, + } + dumper = type(name, (base,), attribs) + adapters.register_dumper(None, dumper) + + +def register_default_adapters(context: AdaptContext) -> None: + # The text dumper is more flexible as it can handle lists of mixed type, + # so register it later. + context.adapters.register_dumper(list, ListBinaryDumper) + context.adapters.register_dumper(list, ListDumper) + + +def register_all_arrays(context: AdaptContext) -> None: + """ + Associate the array oid of all the types in Loader.globals. + + This function is designed to be called once at import time, after having + registered all the base loaders. + """ + for t in context.adapters.types: + if t.array_oid: + t.register(context) + + +def _load_text( + data: Buffer, + loader: Loader, + delimiter: bytes = b",", + __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"), +) -> List[Any]: + rv = None + stack: List[Any] = [] + a: List[Any] = [] + rv = a + load = loader.load + + # Remove the dimensions information prefix (``[...]=``) + if data and data[0] == b"["[0]: + if isinstance(data, memoryview): + data = bytes(data) + idx = data.find(b"=") + if idx == -1: + raise e.DataError("malformed array: no '=' after dimension information") + data = data[idx + 1 :] + + re_parse = _get_array_parse_regexp(delimiter) + for m in re_parse.finditer(data): + t = m.group(1) + if t == b"{": + if stack: + stack[-1].append(a) + stack.append(a) + a = [] + + elif t == b"}": + if not stack: + raise e.DataError("malformed array: unexpected '}'") + rv = stack.pop() + + else: + if not stack: + wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else "" + raise e.DataError(f"malformed array: unexpected '{wat}'") + if t == b"NULL": + v = None + else: + if t.startswith(b'"'): + t = __re_unescape.sub(rb"\1", t[1:-1]) + v = load(t) + + stack[-1].append(v) + + assert rv is not None + return rv + + +@cache +def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]: + """ + Return a regexp to tokenize an array representation into item and brackets + """ + return re.compile( + rb"""(?xi) + ( [{}] # open or closed bracket + | " (?: [^"\\] | \\. )* " # or a quoted string + | [^"{}%s\\]+ # or an unquoted non-empty string + ) ,? + """ + % delimiter + ) + + +def _load_binary(data: Buffer, tx: Transformer) -> List[Any]: + ndims, hasnull, oid = _unpack_head(data) + load = tx.get_loader(oid, PQ_BINARY).load + + if not ndims: + return [] + + p = 12 + 8 * ndims + dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)] + nelems = prod(dims) + + out: List[Any] = [None] * nelems + for i in range(nelems): + size = unpack_len(data, p)[0] + p += 4 + if size == -1: + continue + out[i] = load(data[p : p + size]) + p += size + + # fon ndims > 1 we have to aggregate the array into sub-arrays + for dim in dims[-1:0:-1]: + out = [out[i : i + dim] for i in range(0, len(out), dim)] + + return out diff --git a/psycopg/psycopg/types/bool.py b/psycopg/psycopg/types/bool.py new file mode 100644 index 0000000..db7e181 --- /dev/null +++ b/psycopg/psycopg/types/bool.py @@ -0,0 +1,51 @@ +""" +Adapters for booleans. +""" + +# Copyright (C) 2020 The Psycopg Team + +from .. 
import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + + +class BoolDumper(Dumper): + + oid = postgres.types["bool"].oid + + def dump(self, obj: bool) -> bytes: + return b"t" if obj else b"f" + + def quote(self, obj: bool) -> bytes: + return b"true" if obj else b"false" + + +class BoolBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["bool"].oid + + def dump(self, obj: bool) -> bytes: + return b"\x01" if obj else b"\x00" + + +class BoolLoader(Loader): + def load(self, data: Buffer) -> bool: + return data == b"t" + + +class BoolBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> bool: + return data != b"\x00" + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(bool, BoolDumper) + adapters.register_dumper(bool, BoolBinaryDumper) + adapters.register_loader("bool", BoolLoader) + adapters.register_loader("bool", BoolBinaryLoader) diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py new file mode 100644 index 0000000..1c609c3 --- /dev/null +++ b/psycopg/psycopg/types/composite.py @@ -0,0 +1,290 @@ +""" +Support for composite types adaptation. +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from collections import namedtuple +from typing import Any, Callable, cast, Iterator, List, Optional +from typing import Sequence, Tuple, Type + +from .. import pq +from .. import postgres +from ..abc import AdaptContext, Buffer +from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader +from .._struct import pack_len, unpack_len +from ..postgres import TEXT_OID +from .._typeinfo import CompositeInfo as CompositeInfo # exported here +from .._encodings import _as_python_identifier + +_struct_oidlen = struct.Struct("!Ii") +_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack) +_unpack_oidlen = cast( + Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from +) + + +class SequenceDumper(RecursiveDumper): + def _dump_sequence( + self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes + ) -> bytes: + if not obj: + return start + end + + parts: List[Buffer] = [start] + + for item in obj: + if item is None: + parts.append(sep) + continue + + dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format)) + ad = dumper.dump(item) + if not ad: + ad = b'""' + elif self._re_needs_quotes.search(ad): + ad = b'"' + self._re_esc.sub(rb"\1\1", ad) + b'"' + + parts.append(ad) + parts.append(sep) + + parts[-1] = end + + return b"".join(parts) + + _re_needs_quotes = re.compile(rb'[",\\\s()]') + _re_esc = re.compile(rb"([\\\"])") + + +class TupleDumper(SequenceDumper): + + # Should be this, but it doesn't work + # oid = postgres_types["record"].oid + + def dump(self, obj: Tuple[Any, ...]) -> bytes: + return self._dump_sequence(obj, b"(", b")", b",") + + +class TupleBinaryDumper(RecursiveDumper): + + format = pq.Format.BINARY + + # Subclasses must set an info + info: CompositeInfo + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + nfields = len(self.info.field_types) + self._tx.set_dumper_types(self.info.field_types, self.format) + self._formats = (PyFormat.from_pq(self.format),) * nfields + + def dump(self, obj: Tuple[Any, ...]) -> bytearray: + out = bytearray(pack_len(len(obj))) + adapted = self._tx.dump_sequence(obj, self._formats) + for i in range(len(obj)): + b = adapted[i] + 
oid = self.info.field_types[i] + if b is not None: + out += _pack_oidlen(oid, len(b)) + out += b + else: + out += _pack_oidlen(oid, -1) + + return out + + +class BaseCompositeLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer(context) + + def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]: + """ + Split a non-empty representation of a composite type into components. + + Terminators shouldn't be used in `!data` (so that both record and range + representations can be parsed). + """ + for m in self._re_tokenize.finditer(data): + if m.group(1): + yield None + elif m.group(2) is not None: + yield self._re_undouble.sub(rb"\1", m.group(2)) + else: + yield m.group(3) + + # If the final group ended in `,` there is a final NULL in the record + # that the regexp couldn't parse. + if m and m.group().endswith(b","): + yield None + + _re_tokenize = re.compile( + rb"""(?x) + (,) # an empty token, representing NULL + | " ((?: [^"] | "")*) " ,? # or a quoted string + | ([^",)]+) ,? # or an unquoted string + """ + ) + + _re_undouble = re.compile(rb'(["\\])\1') + + +class RecordLoader(BaseCompositeLoader): + def load(self, data: Buffer) -> Tuple[Any, ...]: + if data == b"()": + return () + + cast = self._tx.get_loader(TEXT_OID, self.format).load + return tuple( + cast(token) if token is not None else None + for token in self._parse_record(data[1:-1]) + ) + + +class RecordBinaryLoader(Loader): + format = pq.Format.BINARY + _types_set = False + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer(context) + + def load(self, data: Buffer) -> Tuple[Any, ...]: + if not self._types_set: + self._config_types(data) + self._types_set = True + + return self._tx.load_sequence( + tuple( + data[offset : offset + length] if length != -1 else None + for _, offset, length in self._walk_record(data) + ) + ) + + def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]: + """ + Yield a sequence of (oid, offset, length) for the content of the record + """ + nfields = unpack_len(data, 0)[0] + i = 4 + for _ in range(nfields): + oid, length = _unpack_oidlen(data, i) + yield oid, i + 8, length + i += (8 + length) if length > 0 else 8 + + def _config_types(self, data: Buffer) -> None: + oids = [r[0] for r in self._walk_record(data)] + self._tx.set_loader_types(oids, self.format) + + +class CompositeLoader(RecordLoader): + + factory: Callable[..., Any] + fields_types: List[int] + _types_set = False + + def load(self, data: Buffer) -> Any: + if not self._types_set: + self._config_types(data) + self._types_set = True + + if data == b"()": + return type(self).factory() + + return type(self).factory( + *self._tx.load_sequence(tuple(self._parse_record(data[1:-1]))) + ) + + def _config_types(self, data: Buffer) -> None: + self._tx.set_loader_types(self.fields_types, self.format) + + +class CompositeBinaryLoader(RecordBinaryLoader): + + format = pq.Format.BINARY + factory: Callable[..., Any] + + def load(self, data: Buffer) -> Any: + r = super().load(data) + return type(self).factory(*r) + + +def register_composite( + info: CompositeInfo, + context: Optional[AdaptContext] = None, + factory: Optional[Callable[..., Any]] = None, +) -> None: + """Register the adapters to load and dump a composite type. + + :param info: The object with the information about the composite to register. + :param context: The context where to register the adapters. 
If `!None`, + register it globally. + :param factory: Callable to convert the sequence of attributes read from + the composite into a Python object. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the requested composite available?") + + # Register arrays and type info + info.register(context) + + if not factory: + factory = namedtuple( # type: ignore + _as_python_identifier(info.name), + [_as_python_identifier(n) for n in info.field_names], + ) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[BaseCompositeLoader] = type( + f"{info.name.title()}Loader", + (CompositeLoader,), + { + "factory": factory, + "fields_types": info.field_types, + }, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + loader = type( + f"{info.name.title()}BinaryLoader", + (CompositeBinaryLoader,), + {"factory": factory}, + ) + adapters.register_loader(info.oid, loader) + + # If the factory is a type, create and register dumpers for it + if isinstance(factory, type): + dumper = type( + f"{info.name.title()}BinaryDumper", + (TupleBinaryDumper,), + {"oid": info.oid, "info": info}, + ) + adapters.register_dumper(factory, dumper) + + # Default to the text dumper because it is more flexible + dumper = type(f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid}) + adapters.register_dumper(factory, dumper) + + info.python_type = factory + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(tuple, TupleDumper) + adapters.register_loader("record", RecordLoader) + adapters.register_loader("record", RecordBinaryLoader) diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py new file mode 100644 index 0000000..f0dfe83 --- /dev/null +++ b/psycopg/psycopg/types/datetime.py @@ -0,0 +1,754 @@ +""" +Adapters for date/time types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from datetime import date, datetime, time, timedelta, timezone +from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING + +from .. 
import postgres +from ..pq import Format +from .._tz import get_tzinfo +from ..abc import AdaptContext, DumperKey +from ..adapt import Buffer, Dumper, Loader, PyFormat +from ..errors import InterfaceError, DataError +from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8 + +if TYPE_CHECKING: + from ..connection import BaseConnection + +_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset +_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack) +_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack) + +_struct_interval = struct.Struct("!qii") # microseconds, days, months +_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack) +_unpack_interval = cast( + Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack +) + +utc = timezone.utc +_pg_date_epoch_days = date(2000, 1, 1).toordinal() +_pg_datetime_epoch = datetime(2000, 1, 1) +_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc) +_py_date_min_days = date.min.toordinal() + + +class DateDumper(Dumper): + + oid = postgres.types["date"].oid + + def dump(self, obj: date) -> bytes: + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + return str(obj).encode() + + +class DateBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["date"].oid + + def dump(self, obj: date) -> bytes: + days = obj.toordinal() - _pg_date_epoch_days + return pack_int4(days) + + +class _BaseTimeDumper(Dumper): + def get_key(self, obj: time, format: PyFormat) -> DumperKey: + # Use (cls,) to report the need to upgrade to a dumper for timetz (the + # Frankenstein of the data types). + if not obj.tzinfo: + return self.cls + else: + return (self.cls,) + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + raise NotImplementedError + + +class _BaseTimeTextDumper(_BaseTimeDumper): + def dump(self, obj: time) -> bytes: + return str(obj).encode() + + +class TimeDumper(_BaseTimeTextDumper): + + oid = postgres.types["time"].oid + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + if not obj.tzinfo: + return self + else: + return TimeTzDumper(self.cls) + + +class TimeTzDumper(_BaseTimeTextDumper): + + oid = postgres.types["timetz"].oid + + +class TimeBinaryDumper(_BaseTimeDumper): + + format = Format.BINARY + oid = postgres.types["time"].oid + + def dump(self, obj: time) -> bytes: + us = obj.microsecond + 1_000_000 * ( + obj.second + 60 * (obj.minute + 60 * obj.hour) + ) + return pack_int8(us) + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + if not obj.tzinfo: + return self + else: + return TimeTzBinaryDumper(self.cls) + + +class TimeTzBinaryDumper(_BaseTimeDumper): + + format = Format.BINARY + oid = postgres.types["timetz"].oid + + def dump(self, obj: time) -> bytes: + us = obj.microsecond + 1_000_000 * ( + obj.second + 60 * (obj.minute + 60 * obj.hour) + ) + off = obj.utcoffset() + assert off is not None + return _pack_timetz(us, -int(off.total_seconds())) + + +class _BaseDatetimeDumper(Dumper): + def get_key(self, obj: datetime, format: PyFormat) -> DumperKey: + # Use (cls,) to report the need to upgrade (downgrade, actually) to a + # dumper for naive timestamp. 
+ if obj.tzinfo: + return self.cls + else: + return (self.cls,) + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + raise NotImplementedError + + +class _BaseDatetimeTextDumper(_BaseDatetimeDumper): + def dump(self, obj: datetime) -> bytes: + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + return str(obj).encode() + + +class DatetimeDumper(_BaseDatetimeTextDumper): + + oid = postgres.types["timestamptz"].oid + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + if obj.tzinfo: + return self + else: + return DatetimeNoTzDumper(self.cls) + + +class DatetimeNoTzDumper(_BaseDatetimeTextDumper): + + oid = postgres.types["timestamp"].oid + + +class DatetimeBinaryDumper(_BaseDatetimeDumper): + + format = Format.BINARY + oid = postgres.types["timestamptz"].oid + + def dump(self, obj: datetime) -> bytes: + delta = obj - _pg_datetimetz_epoch + micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds) + return pack_int8(micros) + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + if obj.tzinfo: + return self + else: + return DatetimeNoTzBinaryDumper(self.cls) + + +class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper): + + format = Format.BINARY + oid = postgres.types["timestamp"].oid + + def dump(self, obj: datetime) -> bytes: + delta = obj - _pg_datetime_epoch + micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds) + return pack_int8(micros) + + +class TimedeltaDumper(Dumper): + + oid = postgres.types["interval"].oid + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + if self.connection: + if ( + self.connection.pgconn.parameter_status(b"IntervalStyle") + == b"sql_standard" + ): + setattr(self, "dump", self._dump_sql) + + def dump(self, obj: timedelta) -> bytes: + # The comma is parsed ok by PostgreSQL but it's not documented + # and it seems brittle to rely on it. CRDB doesn't consume it well. 
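+ # For example, str(timedelta(days=1, minutes=2)) is "1 day, 0:02:00"; the
+ # comma-less "1 day 0:02:00" is what actually gets sent to the server.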
+ return str(obj).encode().replace(b",", b"") + + def _dump_sql(self, obj: timedelta) -> bytes: + # sql_standard format needs explicit signs + # otherwise -1 day 1 sec will mean -1 sec + return b"%+d day %+d second %+d microsecond" % ( + obj.days, + obj.seconds, + obj.microseconds, + ) + + +class TimedeltaBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["interval"].oid + + def dump(self, obj: timedelta) -> bytes: + micros = 1_000_000 * obj.seconds + obj.microseconds + return _pack_interval(micros, obj.days, 0) + + +class DateLoader(Loader): + + _ORDER_YMD = 0 + _ORDER_DMY = 1 + _ORDER_MDY = 2 + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + ds = _get_datestyle(self.connection) + if ds.startswith(b"I"): # ISO + self._order = self._ORDER_YMD + elif ds.startswith(b"G"): # German + self._order = self._ORDER_DMY + elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres + self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY + else: + raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + def load(self, data: Buffer) -> date: + if self._order == self._ORDER_YMD: + ye = data[:4] + mo = data[5:7] + da = data[8:] + elif self._order == self._ORDER_DMY: + da = data[:2] + mo = data[3:5] + ye = data[6:] + else: + mo = data[:2] + da = data[3:5] + ye = data[6:] + + try: + return date(int(ye), int(mo), int(da)) + except ValueError as ex: + s = bytes(data).decode("utf8", "replace") + if s == "infinity" or (s and len(s.split()[0]) > 10): + raise DataError(f"date too large (after year 10K): {s!r}") from None + elif s == "-infinity" or "BC" in s: + raise DataError(f"date too small (before year 1): {s!r}") from None + else: + raise DataError(f"can't parse date {s!r}: {ex}") from None + + +class DateBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> date: + days = unpack_int4(data)[0] + _pg_date_epoch_days + try: + return date.fromordinal(days) + except (ValueError, OverflowError): + if days < _py_date_min_days: + raise DataError("date too small (before year 1)") from None + else: + raise DataError("date too large (after year 10K)") from None + + +class TimeLoader(Loader): + + _re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?") + + def load(self, data: Buffer) -> time: + m = self._re_format.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse time {s!r}") + + ho, mi, se, fr = m.groups() + + # Pad the fraction of second to get micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + try: + return time(int(ho), int(mi), int(se), us) + except ValueError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse time {s!r}: {e}") from None + + +class TimeBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> time: + val = unpack_int8(data)[0] + val, us = divmod(val, 1_000_000) + val, s = divmod(val, 60) + h, m = divmod(val, 60) + try: + return time(h, m, s, us) + except ValueError: + raise DataError(f"time not supported by Python: hour={h}") from None + + +class TimetzLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) : (\d+) : (\d+) (?: \. (\d+) )? # Time and micros + ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? 
# Timezone + $ + """ + ) + + def load(self, data: Buffer) -> time: + m = self._re_format.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse timetz {s!r}") + + ho, mi, se, fr, sgn, oh, om, os = m.groups() + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + # Calculate timezone + off = 60 * 60 * int(oh) + if om: + off += 60 * int(om) + if os: + off += int(os) + tz = timezone(timedelta(0, off if sgn == b"+" else -off)) + + try: + return time(int(ho), int(mi), int(se), us, tz) + except ValueError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse timetz {s!r}: {e}") from None + + +class TimetzBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> time: + val, off = _unpack_timetz(data) + + val, us = divmod(val, 1_000_000) + val, s = divmod(val, 60) + h, m = divmod(val, 60) + + try: + return time(h, m, s, us, timezone(timedelta(seconds=-off))) + except ValueError: + raise DataError(f"time not supported by Python: hour={h}") from None + + +class TimestampLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date + (?: T | [^a-z0-9] ) # Separator, including T + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + $ + """ + ) + _re_format_pg = re.compile( + rb"""(?ix) + ^ + [a-z]+ [^a-z0-9] # DoW, separator + (\d+|[a-z]+) [^a-z0-9] # Month or day + (\d+|[a-z]+) [^a-z0-9] # Month or day + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + [^a-z0-9] (\d+) # Year + $ + """ + ) + + _ORDER_YMD = 0 + _ORDER_DMY = 1 + _ORDER_MDY = 2 + _ORDER_PGDM = 3 + _ORDER_PGMD = 4 + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + + ds = _get_datestyle(self.connection) + if ds.startswith(b"I"): # ISO + self._order = self._ORDER_YMD + elif ds.startswith(b"G"): # German + self._order = self._ORDER_DMY + elif ds.startswith(b"S"): # SQL + self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY + elif ds.startswith(b"P"): # Postgres + self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD + self._re_format = self._re_format_pg + else: + raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + def load(self, data: Buffer) -> datetime: + m = self._re_format.match(data) + if not m: + raise _get_timestamp_load_error(self.connection, data) from None + + if self._order == self._ORDER_YMD: + ye, mo, da, ho, mi, se, fr = m.groups() + imo = int(mo) + elif self._order == self._ORDER_DMY: + da, mo, ye, ho, mi, se, fr = m.groups() + imo = int(mo) + elif self._order == self._ORDER_MDY: + mo, da, ye, ho, mi, se, fr = m.groups() + imo = int(mo) + else: + if self._order == self._ORDER_PGDM: + da, mo, ho, mi, se, fr, ye = m.groups() + else: + mo, da, ho, mi, se, fr, ye = m.groups() + try: + imo = _month_abbr[mo] + except KeyError: + s = mo.decode("utf8", "replace") + raise DataError(f"can't parse month: {s!r}") from None + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + try: + return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us) + except ValueError as ex: + raise _get_timestamp_load_error(self.connection, data, ex) from None + + +class TimestampBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> datetime: + 
micros = unpack_int8(data)[0] + try: + return _pg_datetime_epoch + timedelta(microseconds=micros) + except OverflowError: + if micros <= 0: + raise DataError("timestamp too small (before year 1)") from None + else: + raise DataError("timestamp too large (after year 10K)") from None + + +class TimestamptzLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date + (?: T | [^a-z0-9] ) # Separator, including T + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone + $ + """ + ) + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None) + + ds = _get_datestyle(self.connection) + if not ds.startswith(b"I"): # not ISO + setattr(self, "load", self._load_notimpl) + + def load(self, data: Buffer) -> datetime: + m = self._re_format.match(data) + if not m: + raise _get_timestamp_load_error(self.connection, data) from None + + ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups() + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + # Calculate timezone offset + soff = 60 * 60 * int(oh) + if om: + soff += 60 * int(om) + if os: + soff += int(os) + tzoff = timedelta(0, soff if sgn == b"+" else -soff) + + # The return value is a datetime with the timezone of the connection + # (in order to be consistent with the binary loader, which is the only + # thing it can return). So create a temporary datetime object, in utc, + # shift it by the offset parsed from the timestamp, and then move it to + # the connection timezone. + dt = None + ex: Exception + try: + dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc) + return (dt - tzoff).astimezone(self._timezone) + except OverflowError as e: + # If we have created the temporary 'dt' it means that we have a + # datetime close to max, the shift pushed it past max, overflowing. + # In this case return the datetime in a fixed offset timezone. + if dt is not None: + return dt.replace(tzinfo=timezone(tzoff)) + else: + ex = e + except ValueError as e: + ex = e + + raise _get_timestamp_load_error(self.connection, data, ex) from None + + def _load_notimpl(self, data: Buffer) -> datetime: + s = bytes(data).decode("utf8", "replace") + ds = _get_datestyle(self.connection).decode("ascii") + raise NotImplementedError( + f"can't parse timestamptz with DateStyle {ds!r}: {s!r}" + ) + + +class TimestamptzBinaryLoader(Loader): + + format = Format.BINARY + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None) + + def load(self, data: Buffer) -> datetime: + micros = unpack_int8(data)[0] + try: + ts = _pg_datetimetz_epoch + timedelta(microseconds=micros) + return ts.astimezone(self._timezone) + except OverflowError: + # If we were asked about a timestamp which would overflow in UTC, + # but not in the desired timezone (e.g. datetime.max at Chicago + # timezone) we can still save the day by shifting the value by the + # timezone offset and then replacing the timezone. 
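+ # Roughly: the connection timezone's utcoffset is added to the microsecond
+ # value, the result is rebuilt on the naive epoch (i.e. as local wall-clock
+ # time) and then tagged with the session timezone; if even that overflows,
+ # fall through to the DataError below.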
+ if self._timezone: + utcoff = self._timezone.utcoffset( + datetime.min if micros < 0 else datetime.max + ) + if utcoff: + usoff = 1_000_000 * int(utcoff.total_seconds()) + try: + ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff) + except OverflowError: + pass # will raise downstream + else: + return ts.replace(tzinfo=self._timezone) + + if micros <= 0: + raise DataError("timestamp too small (before year 1)") from None + else: + raise DataError("timestamp too large (after year 10K)") from None + + +class IntervalLoader(Loader): + + _re_interval = re.compile( + rb""" + (?: ([-+]?\d+) \s+ years? \s* )? # Years + (?: ([-+]?\d+) \s+ mons? \s* )? # Months + (?: ([-+]?\d+) \s+ days? \s* )? # Days + (?: ([-+])? (\d+) : (\d+) : (\d+ (?:\.\d+)?) # Time + )? + """, + re.VERBOSE, + ) + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + if self.connection: + ints = self.connection.pgconn.parameter_status(b"IntervalStyle") + if ints != b"postgres": + setattr(self, "load", self._load_notimpl) + + def load(self, data: Buffer) -> timedelta: + m = self._re_interval.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse interval {s!r}") + + ye, mo, da, sgn, ho, mi, se = m.groups() + days = 0 + seconds = 0.0 + + if ye: + days += 365 * int(ye) + if mo: + days += 30 * int(mo) + if da: + days += int(da) + + if ho: + seconds = 3600 * int(ho) + 60 * int(mi) + float(se) + if sgn == b"-": + seconds = -seconds + + try: + return timedelta(days=days, seconds=seconds) + except OverflowError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse interval {s!r}: {e}") from None + + def _load_notimpl(self, data: Buffer) -> timedelta: + s = bytes(data).decode("utf8", "replace") + ints = ( + self.connection + and self.connection.pgconn.parameter_status(b"IntervalStyle") + or b"unknown" + ).decode("utf8", "replace") + raise NotImplementedError( + f"can't parse interval with IntervalStyle {ints}: {s!r}" + ) + + +class IntervalBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> timedelta: + micros, days, months = _unpack_interval(data) + if months > 0: + years, months = divmod(months, 12) + days = days + 30 * months + 365 * years + elif months < 0: + years, months = divmod(-months, 12) + days = days - 30 * months - 365 * years + + try: + return timedelta(days=days, microseconds=micros) + except OverflowError as e: + raise DataError(f"can't parse interval: {e}") from None + + +def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes: + if conn: + ds = conn.pgconn.parameter_status(b"DateStyle") + if ds: + return ds + + return b"ISO, DMY" + + +def _get_timestamp_load_error( + conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None +) -> Exception: + s = bytes(data).decode("utf8", "replace") + + def is_overflow(s: str) -> bool: + if not s: + return False + + ds = _get_datestyle(conn) + if not ds.startswith(b"P"): # Postgres + return len(s.split()[0]) > 10 # date is first token + else: + return len(s.split()[-1]) > 4 # year is last token + + if s == "-infinity" or s.endswith("BC"): + return DataError("timestamp too small (before year 1): {s!r}") + elif s == "infinity" or is_overflow(s): + return DataError(f"timestamp too large (after year 10K): {s!r}") + else: + return DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}") + + +_month_abbr = { + n: i + for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep 
Oct Nov Dec".split(), 1) +} + +# Pad to get microseconds from a fraction of seconds +_uspad = [0, 100_000, 10_000, 1_000, 100, 10, 1] + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("datetime.date", DateDumper) + adapters.register_dumper("datetime.date", DateBinaryDumper) + + # first register dumpers for 'timetz' oid, then the proper ones on time type. + adapters.register_dumper("datetime.time", TimeTzDumper) + adapters.register_dumper("datetime.time", TimeTzBinaryDumper) + adapters.register_dumper("datetime.time", TimeDumper) + adapters.register_dumper("datetime.time", TimeBinaryDumper) + + # first register dumpers for 'timestamp' oid, then the proper ones + # on the datetime type. + adapters.register_dumper("datetime.datetime", DatetimeNoTzDumper) + adapters.register_dumper("datetime.datetime", DatetimeNoTzBinaryDumper) + adapters.register_dumper("datetime.datetime", DatetimeDumper) + adapters.register_dumper("datetime.datetime", DatetimeBinaryDumper) + + adapters.register_dumper("datetime.timedelta", TimedeltaDumper) + adapters.register_dumper("datetime.timedelta", TimedeltaBinaryDumper) + + adapters.register_loader("date", DateLoader) + adapters.register_loader("date", DateBinaryLoader) + adapters.register_loader("time", TimeLoader) + adapters.register_loader("time", TimeBinaryLoader) + adapters.register_loader("timetz", TimetzLoader) + adapters.register_loader("timetz", TimetzBinaryLoader) + adapters.register_loader("timestamp", TimestampLoader) + adapters.register_loader("timestamp", TimestampBinaryLoader) + adapters.register_loader("timestamptz", TimestamptzLoader) + adapters.register_loader("timestamptz", TimestamptzBinaryLoader) + adapters.register_loader("interval", IntervalLoader) + adapters.register_loader("interval", IntervalBinaryLoader) diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py new file mode 100644 index 0000000..d3c7387 --- /dev/null +++ b/psycopg/psycopg/types/enum.py @@ -0,0 +1,177 @@ +""" +Adapters for the enum type. +""" +from enum import Enum +from typing import Any, Dict, Generic, Optional, Mapping, Sequence +from typing import Tuple, Type, TypeVar, Union, cast +from typing_extensions import TypeAlias + +from .. import postgres +from .. 
import errors as e +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader +from .._encodings import conn_encoding +from .._typeinfo import EnumInfo as EnumInfo # exported here + +E = TypeVar("E", bound=Enum) + +EnumDumpMap: TypeAlias = Dict[E, bytes] +EnumLoadMap: TypeAlias = Dict[bytes, E] +EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None] + + +class _BaseEnumLoader(Loader, Generic[E]): + """ + Loader for a specific Enum class + """ + + enum: Type[E] + _load_map: EnumLoadMap[E] + + def load(self, data: Buffer) -> E: + if not isinstance(data, bytes): + data = bytes(data) + + try: + return self._load_map[data] + except KeyError: + enc = conn_encoding(self.connection) + label = data.decode(enc, "replace") + raise e.DataError( + f"bad member for enum {self.enum.__qualname__}: {label!r}" + ) + + +class _BaseEnumDumper(Dumper, Generic[E]): + """ + Dumper for a specific Enum class + """ + + enum: Type[E] + _dump_map: EnumDumpMap[E] + + def dump(self, value: E) -> Buffer: + return self._dump_map[value] + + +class EnumDumper(Dumper): + """ + Dumper for a generic Enum class + """ + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._encoding = conn_encoding(self.connection) + + def dump(self, value: E) -> Buffer: + return value.name.encode(self._encoding) + + +class EnumBinaryDumper(EnumDumper): + format = Format.BINARY + + +def register_enum( + info: EnumInfo, + context: Optional[AdaptContext] = None, + enum: Optional[Type[E]] = None, + *, + mapping: EnumMapping[E] = None, +) -> None: + """Register the adapters to load and dump a enum type. + + :param info: The object with the information about the enum to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + :param enum: Python enum type matching to the PostgreSQL one. If `!None`, + a new enum will be generated and exposed as `EnumInfo.enum`. + :param mapping: Override the mapping between `!enum` members and `!info` + labels. + """ + + if not info: + raise TypeError("no info passed. 
Is the requested enum available?") + + if enum is None: + enum = cast(Type[E], Enum(info.name.title(), info.labels, module=__name__)) + + info.enum = enum + adapters = context.adapters if context else postgres.adapters + info.register(context) + + load_map = _make_load_map(info, enum, mapping, context) + attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map} + + name = f"{info.name.title()}Loader" + loader = type(name, (_BaseEnumLoader,), attribs) + adapters.register_loader(info.oid, loader) + + name = f"{info.name.title()}BinaryLoader" + loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY}) + adapters.register_loader(info.oid, loader) + + dump_map = _make_dump_map(info, enum, mapping, context) + attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map} + + name = f"{enum.__name__}Dumper" + dumper = type(name, (_BaseEnumDumper,), attribs) + adapters.register_dumper(info.enum, dumper) + + name = f"{enum.__name__}BinaryDumper" + dumper = type(name, (_BaseEnumDumper,), {**attribs, "format": Format.BINARY}) + adapters.register_dumper(info.enum, dumper) + + +def _make_load_map( + info: EnumInfo, + enum: Type[E], + mapping: EnumMapping[E], + context: Optional[AdaptContext], +) -> EnumLoadMap[E]: + enc = conn_encoding(context.connection if context else None) + rv: EnumLoadMap[E] = {} + for label in info.labels: + try: + member = enum[label] + except KeyError: + # tolerate a missing enum, assuming it won't be used. If it is we + # will get a DataError on fetch. + pass + else: + rv[label.encode(enc)] = member + + if mapping: + if isinstance(mapping, Mapping): + mapping = list(mapping.items()) + + for member, label in mapping: + rv[label.encode(enc)] = member + + return rv + + +def _make_dump_map( + info: EnumInfo, + enum: Type[E], + mapping: EnumMapping[E], + context: Optional[AdaptContext], +) -> EnumDumpMap[E]: + enc = conn_encoding(context.connection if context else None) + rv: EnumDumpMap[E] = {} + for member in enum: + rv[member] = member.name.encode(enc) + + if mapping: + if isinstance(mapping, Mapping): + mapping = list(mapping.items()) + + for member, label in mapping: + rv[member] = label.encode(enc) + + return rv + + +def register_default_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(Enum, EnumBinaryDumper) + context.adapters.register_dumper(Enum, EnumDumper) diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py new file mode 100644 index 0000000..e1ab1d5 --- /dev/null +++ b/psycopg/psycopg/types/hstore.py @@ -0,0 +1,131 @@ +""" +Dict to hstore adaptation +""" + +# Copyright (C) 2021 The Psycopg Team + +import re +from typing import Dict, List, Optional +from typing_extensions import TypeAlias + +from .. import errors as e +from .. import postgres +from ..abc import Buffer, AdaptContext +from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader +from ..postgres import TEXT_OID +from .._typeinfo import TypeInfo + +_re_escape = re.compile(r'(["\\])') +_re_unescape = re.compile(r"\\(.)") + +_re_hstore = re.compile( + r""" + # hstore key: + # a string of normal or escaped chars + "((?: [^"\\] | \\. )*)" + \s*=>\s* # hstore value + (?: + NULL # the value can be null - not caught + # or a quoted string like the key + | "((?: [^"\\] | \\. )*)" + ) + (?:\s*,\s*|$) # pairs separated by comma or end of string. 
+""", + re.VERBOSE, +) + + +Hstore: TypeAlias = Dict[str, Optional[str]] + + +class BaseHstoreDumper(RecursiveDumper): + def dump(self, obj: Hstore) -> Buffer: + if not obj: + return b"" + + tokens: List[str] = [] + + def add_token(s: str) -> None: + tokens.append('"') + tokens.append(_re_escape.sub(r"\\\1", s)) + tokens.append('"') + + for k, v in obj.items(): + + if not isinstance(k, str): + raise e.DataError("hstore keys can only be strings") + add_token(k) + + tokens.append("=>") + + if v is None: + tokens.append("NULL") + elif not isinstance(v, str): + raise e.DataError("hstore keys can only be strings") + else: + add_token(v) + + tokens.append(",") + + del tokens[-1] + data = "".join(tokens) + dumper = self._tx.get_dumper(data, PyFormat.TEXT) + return dumper.dump(data) + + +class HstoreLoader(RecursiveLoader): + def load(self, data: Buffer) -> Hstore: + loader = self._tx.get_loader(TEXT_OID, self.format) + s: str = loader.load(data) + + rv: Hstore = {} + start = 0 + for m in _re_hstore.finditer(s): + if m is None or m.start() != start: + raise e.DataError(f"error parsing hstore pair at char {start}") + k = _re_unescape.sub(r"\1", m.group(1)) + v = m.group(2) + if v is not None: + v = _re_unescape.sub(r"\1", v) + + rv[k] = v + start = m.end() + + if start < len(s): + raise e.DataError(f"error parsing hstore: unparsed data after char {start}") + + return rv + + +def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + """Register the adapters to load and dump hstore. + + :param info: The object with the information about the hstore type. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the 'hstore' extension loaded?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # Generate and register a customized text dumper + class HstoreDumper(BaseHstoreDumper): + oid = info.oid + + adapters.register_dumper(dict, HstoreDumper) + + # register the text loader on the oid + adapters.register_loader(info.oid, HstoreLoader) diff --git a/psycopg/psycopg/types/json.py b/psycopg/psycopg/types/json.py new file mode 100644 index 0000000..a80e0e4 --- /dev/null +++ b/psycopg/psycopg/types/json.py @@ -0,0 +1,232 @@ +""" +Adapers for JSON types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import json +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + +from .. import abc +from .. import errors as e +from .. import postgres +from ..pq import Format +from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap +from ..errors import DataError + +JsonDumpsFunction = Callable[[Any], str] +JsonLoadsFunction = Callable[[Union[str, bytes]], Any] + + +def set_json_dumps( + dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None +) -> None: + """ + Set the JSON serialisation function to store JSON objects in the database. + + :param dumps: The dump function to use. + :type dumps: `!Callable[[Any], str]` + :param context: Where to use the `!dumps` function. If not specified, use it + globally. 
+ :type context: `~psycopg.Connection` or `~psycopg.Cursor` + + By default dumping JSON uses the builtin `json.dumps`. You can override + it to use a different JSON library or to use customised arguments. + + If the `Json` wrapper specified a `!dumps` function, use it in precedence + of the one set by this function. + """ + if context is None: + # If changing load function globally, just change the default on the + # global class + _JsonDumper._dumps = dumps + else: + adapters = context.adapters + + # If the scope is smaller than global, create subclassess and register + # them in the appropriate scope. + grid = [ + (Json, PyFormat.BINARY), + (Json, PyFormat.TEXT), + (Jsonb, PyFormat.BINARY), + (Jsonb, PyFormat.TEXT), + ] + dumper: Type[_JsonDumper] + for wrapper, format in grid: + base = _get_current_dumper(adapters, wrapper, format) + name = base.__name__ + if not base.__name__.startswith("Custom"): + name = f"Custom{name}" + dumper = type(name, (base,), {"_dumps": dumps}) + adapters.register_dumper(wrapper, dumper) + + +def set_json_loads( + loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None +) -> None: + """ + Set the JSON parsing function to fetch JSON objects from the database. + + :param loads: The load function to use. + :type loads: `!Callable[[bytes], Any]` + :param context: Where to use the `!loads` function. If not specified, use + it globally. + :type context: `~psycopg.Connection` or `~psycopg.Cursor` + + By default loading JSON uses the builtin `json.loads`. You can override + it to use a different JSON library or to use customised arguments. + """ + if context is None: + # If changing load function globally, just change the default on the + # global class + _JsonLoader._loads = loads + else: + # If the scope is smaller than global, create subclassess and register + # them in the appropriate scope. + grid = [ + ("json", JsonLoader), + ("json", JsonBinaryLoader), + ("jsonb", JsonbLoader), + ("jsonb", JsonbBinaryLoader), + ] + loader: Type[_JsonLoader] + for tname, base in grid: + loader = type(f"Custom{base.__name__}", (base,), {"_loads": loads}) + context.adapters.register_loader(tname, loader) + + +class _JsonWrapper: + __slots__ = ("obj", "dumps") + + def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None): + self.obj = obj + self.dumps = dumps + + def __repr__(self) -> str: + sobj = repr(self.obj) + if len(sobj) > 40: + sobj = f"{sobj[:35]} ... ({len(sobj)} chars)" + return f"{self.__class__.__name__}({sobj})" + + +class Json(_JsonWrapper): + __slots__ = () + + +class Jsonb(_JsonWrapper): + __slots__ = () + + +class _JsonDumper(Dumper): + + # The globally used JSON dumps() function. It can be changed globally (by + # set_json_dumps) or by a subclass. 
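+ # e.g. set_json_dumps(functools.partial(json.dumps, ensure_ascii=False))
+ # replaces this default for every connection, while passing a context
+ # registers Custom* subclasses only on that connection or cursor.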
+ _dumps: JsonDumpsFunction = json.dumps + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + super().__init__(cls, context) + self.dumps = self.__class__._dumps + + def dump(self, obj: _JsonWrapper) -> bytes: + dumps = obj.dumps or self.dumps + return dumps(obj.obj).encode() + + +class JsonDumper(_JsonDumper): + + oid = postgres.types["json"].oid + + +class JsonBinaryDumper(_JsonDumper): + + format = Format.BINARY + oid = postgres.types["json"].oid + + +class JsonbDumper(_JsonDumper): + + oid = postgres.types["jsonb"].oid + + +class JsonbBinaryDumper(_JsonDumper): + + format = Format.BINARY + oid = postgres.types["jsonb"].oid + + def dump(self, obj: _JsonWrapper) -> bytes: + dumps = obj.dumps or self.dumps + return b"\x01" + dumps(obj.obj).encode() + + +class _JsonLoader(Loader): + + # The globally used JSON loads() function. It can be changed globally (by + # set_json_loads) or by a subclass. + _loads: JsonLoadsFunction = json.loads + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + super().__init__(oid, context) + self.loads = self.__class__._loads + + def load(self, data: Buffer) -> Any: + # json.loads() cannot work on memoryview. + if not isinstance(data, bytes): + data = bytes(data) + return self.loads(data) + + +class JsonLoader(_JsonLoader): + pass + + +class JsonbLoader(_JsonLoader): + pass + + +class JsonBinaryLoader(_JsonLoader): + format = Format.BINARY + + +class JsonbBinaryLoader(_JsonLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Any: + if data and data[0] != 1: + raise DataError("unknown jsonb binary format: {data[0]}") + data = data[1:] + if not isinstance(data, bytes): + data = bytes(data) + return self.loads(data) + + +def _get_current_dumper( + adapters: AdaptersMap, cls: type, format: PyFormat +) -> Type[abc.Dumper]: + try: + return adapters.get_dumper(cls, format) + except e.ProgrammingError: + return _default_dumpers[cls, format] + + +_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = { + (Json, PyFormat.BINARY): JsonBinaryDumper, + (Json, PyFormat.TEXT): JsonDumper, + (Jsonb, PyFormat.BINARY): JsonbBinaryDumper, + (Jsonb, PyFormat.TEXT): JsonDumper, +} + + +def register_default_adapters(context: abc.AdaptContext) -> None: + adapters = context.adapters + + # Currently json binary format is nothing different than text, maybe with + # an extra memcopy we can avoid. + adapters.register_dumper(Json, JsonBinaryDumper) + adapters.register_dumper(Json, JsonDumper) + adapters.register_dumper(Jsonb, JsonbBinaryDumper) + adapters.register_dumper(Jsonb, JsonbDumper) + adapters.register_loader("json", JsonLoader) + adapters.register_loader("jsonb", JsonbLoader) + adapters.register_loader("json", JsonBinaryLoader) + adapters.register_loader("jsonb", JsonbBinaryLoader) diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py new file mode 100644 index 0000000..3eaa7f1 --- /dev/null +++ b/psycopg/psycopg/types/multirange.py @@ -0,0 +1,514 @@ +""" +Support for multirange types adaptation. +""" + +# Copyright (C) 2021 The Psycopg Team + +from decimal import Decimal +from typing import Any, Generic, List, Iterable +from typing import MutableSequence, Optional, Type, Union, overload +from datetime import date, datetime + +from .. import errors as e +from .. 
import postgres +from ..pq import Format +from ..abc import AdaptContext, Buffer, Dumper, DumperKey +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._struct import pack_len, unpack_len +from ..postgres import INVALID_OID, TEXT_OID +from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here + +from .range import Range, T, load_range_text, load_range_binary +from .range import dump_range_text, dump_range_binary, fail_dump + + +class Multirange(MutableSequence[Range[T]]): + """Python representation for a PostgreSQL multirange type. + + :param items: Sequence of ranges to initialise the object. + """ + + def __init__(self, items: Iterable[Range[T]] = ()): + self._ranges: List[Range[T]] = list(map(self._check_type, items)) + + def _check_type(self, item: Any) -> Range[Any]: + if not isinstance(item, Range): + raise TypeError( + f"Multirange is a sequence of Range, got {type(item).__name__}" + ) + return item + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._ranges!r})" + + def __str__(self) -> str: + return f"{{{', '.join(map(str, self._ranges))}}}" + + @overload + def __getitem__(self, index: int) -> Range[T]: + ... + + @overload + def __getitem__(self, index: slice) -> "Multirange[T]": + ... + + def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]": + if isinstance(index, int): + return self._ranges[index] + else: + return Multirange(self._ranges[index]) + + def __len__(self) -> int: + return len(self._ranges) + + @overload + def __setitem__(self, index: int, value: Range[T]) -> None: + ... + + @overload + def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: + ... + + def __setitem__( + self, + index: Union[int, slice], + value: Union[Range[T], Iterable[Range[T]]], + ) -> None: + if isinstance(index, int): + self._check_type(value) + self._ranges[index] = self._check_type(value) + elif not isinstance(value, Iterable): + raise TypeError("can only assign an iterable") + else: + value = map(self._check_type, value) + self._ranges[index] = value + + def __delitem__(self, index: Union[int, slice]) -> None: + del self._ranges[index] + + def insert(self, index: int, value: Range[T]) -> None: + self._ranges.insert(index, self._check_type(value)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return False + return self._ranges == other._ranges + + # Order is arbitrary but consistent + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return NotImplemented + return self._ranges < other._ranges + + def __le__(self, other: Any) -> bool: + return self == other or self < other # type: ignore + + def __gt__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return NotImplemented + return self._ranges > other._ranges + + def __ge__(self, other: Any) -> bool: + return self == other or self > other # type: ignore + + +# Subclasses to specify a specific subtype. 
Usually not needed + + +class Int4Multirange(Multirange[int]): + pass + + +class Int8Multirange(Multirange[int]): + pass + + +class NumericMultirange(Multirange[Decimal]): + pass + + +class DateMultirange(Multirange[date]): + pass + + +class TimestampMultirange(Multirange[datetime]): + pass + + +class TimestamptzMultirange(Multirange[datetime]): + pass + + +class BaseMultirangeDumper(RecursiveDumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + self._adapt_format = PyFormat.from_pq(self.format) + + def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey: + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Multirange: + return self.cls + + item = self._get_item(obj) + if item is not None: + sd = self._tx.get_dumper(item, self._adapt_format) + return (self.cls, sd.get_key(item, format)) + else: + return (self.cls,) + + def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper": + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Multirange: + return self + + item = self._get_item(obj) + if item is None: + return MultirangeDumper(self.cls) + + dumper: BaseMultirangeDumper + if type(item) is int: + # postgres won't cast int4range -> int8range so we must use + # text format and unknown oid here + sd = self._tx.get_dumper(item, PyFormat.TEXT) + dumper = MultirangeDumper(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = INVALID_OID + return dumper + + sd = self._tx.get_dumper(item, format) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + if sd.oid == INVALID_OID and isinstance(item, str): + # Work around the normal mapping where text is dumped as unknown + dumper.oid = self._get_multirange_oid(TEXT_OID) + else: + dumper.oid = self._get_multirange_oid(sd.oid) + + return dumper + + def _get_item(self, obj: Multirange[Any]) -> Any: + """ + Return a member representative of the multirange + """ + for r in obj: + if r.lower is not None: + return r.lower + if r.upper is not None: + return r.upper + return None + + def _get_multirange_oid(self, sub_oid: int) -> int: + """ + Return the oid of the range from the oid of its elements. + """ + info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid) + return info.oid if info else INVALID_OID + + +class MultirangeDumper(BaseMultirangeDumper): + """ + Dumper for multirange types. + + The dumper can upgrade to one specific for a different range type. 
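+
+ For instance, a multirange whose items are Range[int] is dumped in text
+ format with an unknown oid, leaving the choice between int4multirange and
+ int8multirange to the server; other subtypes upgrade to a dumper bound to
+ the matching multirange oid.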
+ """ + + def dump(self, obj: Multirange[Any]) -> Buffer: + if not obj: + return b"{}" + + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + out: List[Buffer] = [b"{"] + for r in obj: + out.append(dump_range_text(r, dump)) + out.append(b",") + out[-1] = b"}" + return b"".join(out) + + +class MultirangeBinaryDumper(BaseMultirangeDumper): + + format = Format.BINARY + + def dump(self, obj: Multirange[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + out: List[Buffer] = [pack_len(len(obj))] + for r in obj: + data = dump_range_binary(r, dump) + out.append(pack_len(len(data))) + out.append(data) + return b"".join(out) + + +class BaseMultirangeLoader(RecursiveLoader, Generic[T]): + + subtype_oid: int + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load + + +class MultirangeLoader(BaseMultirangeLoader[T]): + def load(self, data: Buffer) -> Multirange[T]: + if not data or data[0] != _START_INT: + raise e.DataError( + "malformed multirange starting with" + f" {bytes(data[:1]).decode('utf8', 'replace')}" + ) + + out = Multirange[T]() + if data == b"{}": + return out + + pos = 1 + data = data[pos:] + try: + while True: + r, pos = load_range_text(data, self._load) + out.append(r) + + sep = data[pos] # can raise IndexError + if sep == _SEP_INT: + data = data[pos + 1 :] + continue + elif sep == _END_INT: + if len(data) == pos + 1: + return out + else: + raise e.DataError( + "malformed multirange: data after closing brace" + ) + else: + raise e.DataError( + f"malformed multirange: found unexpected {chr(sep)}" + ) + + except IndexError: + raise e.DataError("malformed multirange: separator missing") + + return out + + +_SEP_INT = ord(",") +_START_INT = ord("{") +_END_INT = ord("}") + + +class MultirangeBinaryLoader(BaseMultirangeLoader[T]): + + format = Format.BINARY + + def load(self, data: Buffer) -> Multirange[T]: + nelems = unpack_len(data, 0)[0] + pos = 4 + out = Multirange[T]() + for i in range(nelems): + length = unpack_len(data, pos)[0] + pos += 4 + out.append(load_range_binary(data[pos : pos + length], self._load)) + pos += length + + if pos != len(data): + raise e.DataError("unexpected trailing data in multirange") + + return out + + +def register_multirange( + info: MultirangeInfo, context: Optional[AdaptContext] = None +) -> None: + """Register the adapters to load and dump a multirange type. + + :param info: The object with the information about the range to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + Register loaders so that loading data of this type will result in a `Range` + with bounds parsed as the right subtype. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. 
Is the requested multirange available?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[MultirangeLoader[Any]] = type( + f"{info.name.title()}Loader", + (MultirangeLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + bloader: Type[MultirangeBinaryLoader[Any]] = type( + f"{info.name.title()}BinaryLoader", + (MultirangeBinaryLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, bloader) + + +# Text dumpers for builtin multirange types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4MultirangeDumper(MultirangeDumper): + oid = postgres.types["int4multirange"].oid + + +class Int8MultirangeDumper(MultirangeDumper): + oid = postgres.types["int8multirange"].oid + + +class NumericMultirangeDumper(MultirangeDumper): + oid = postgres.types["nummultirange"].oid + + +class DateMultirangeDumper(MultirangeDumper): + oid = postgres.types["datemultirange"].oid + + +class TimestampMultirangeDumper(MultirangeDumper): + oid = postgres.types["tsmultirange"].oid + + +class TimestamptzMultirangeDumper(MultirangeDumper): + oid = postgres.types["tstzmultirange"].oid + + +# Binary dumpers for builtin multirange types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4MultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["int4multirange"].oid + + +class Int8MultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["int8multirange"].oid + + +class NumericMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["nummultirange"].oid + + +class DateMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["datemultirange"].oid + + +class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["tsmultirange"].oid + + +class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["tstzmultirange"].oid + + +# Text loaders for builtin multirange types + + +class Int4MultirangeLoader(MultirangeLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8MultirangeLoader(MultirangeLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericMultirangeLoader(MultirangeLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateMultirangeLoader(MultirangeLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampMultirangeLoader(MultirangeLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZMultirangeLoader(MultirangeLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +# Binary loaders for builtin multirange types + + +class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + 
+class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Multirange, MultirangeBinaryDumper) + adapters.register_dumper(Multirange, MultirangeDumper) + adapters.register_dumper(Int4Multirange, Int4MultirangeDumper) + adapters.register_dumper(Int8Multirange, Int8MultirangeDumper) + adapters.register_dumper(NumericMultirange, NumericMultirangeDumper) + adapters.register_dumper(DateMultirange, DateMultirangeDumper) + adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper) + adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper) + adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper) + adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper) + adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper) + adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper) + adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper) + adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper) + adapters.register_loader("int4multirange", Int4MultirangeLoader) + adapters.register_loader("int8multirange", Int8MultirangeLoader) + adapters.register_loader("nummultirange", NumericMultirangeLoader) + adapters.register_loader("datemultirange", DateMultirangeLoader) + adapters.register_loader("tsmultirange", TimestampMultirangeLoader) + adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader) + adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader) + adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader) + adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader) + adapters.register_loader("datemultirange", DateMultirangeBinaryLoader) + adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader) + adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader) diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py new file mode 100644 index 0000000..2f2c05b --- /dev/null +++ b/psycopg/psycopg/types/net.py @@ -0,0 +1,206 @@ +""" +Adapters for network types. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Callable, Optional, Type, Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from .. 
import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + +if TYPE_CHECKING: + import ipaddress + +Address: TypeAlias = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"] +Interface: TypeAlias = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"] +Network: TypeAlias = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"] + +# These objects will be imported lazily +ip_address: Callable[[str], Address] = None # type: ignore[assignment] +ip_interface: Callable[[str], Interface] = None # type: ignore[assignment] +ip_network: Callable[[str], Network] = None # type: ignore[assignment] +IPv4Address: "Type[ipaddress.IPv4Address]" = None # type: ignore[assignment] +IPv6Address: "Type[ipaddress.IPv6Address]" = None # type: ignore[assignment] +IPv4Interface: "Type[ipaddress.IPv4Interface]" = None # type: ignore[assignment] +IPv6Interface: "Type[ipaddress.IPv6Interface]" = None # type: ignore[assignment] +IPv4Network: "Type[ipaddress.IPv4Network]" = None # type: ignore[assignment] +IPv6Network: "Type[ipaddress.IPv6Network]" = None # type: ignore[assignment] + +PGSQL_AF_INET = 2 +PGSQL_AF_INET6 = 3 +IPV4_PREFIXLEN = 32 +IPV6_PREFIXLEN = 128 + + +class _LazyIpaddress: + def _ensure_module(self) -> None: + global ip_address, ip_interface, ip_network + global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface + global IPv4Network, IPv6Network + + if ip_address is None: + from ipaddress import ip_address, ip_interface, ip_network + from ipaddress import IPv4Address, IPv6Address + from ipaddress import IPv4Interface, IPv6Interface + from ipaddress import IPv4Network, IPv6Network + + +class InterfaceDumper(Dumper): + + oid = postgres.types["inet"].oid + + def dump(self, obj: Interface) -> bytes: + return str(obj).encode() + + +class NetworkDumper(Dumper): + + oid = postgres.types["cidr"].oid + + def dump(self, obj: Network) -> bytes: + return str(obj).encode() + + +class _AIBinaryDumper(Dumper): + format = Format.BINARY + oid = postgres.types["inet"].oid + + +class AddressBinaryDumper(_AIBinaryDumper): + def dump(self, obj: Address) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.max_prefixlen, 0, len(packed))) + return head + packed + + +class InterfaceBinaryDumper(_AIBinaryDumper): + def dump(self, obj: Interface) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.network.prefixlen, 0, len(packed))) + return head + packed + + +class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress): + """Either an address or an interface to inet + + Used when looking up by oid. 
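+ Addresses are dumped with their maximum prefix length (32 for IPv4, 128 for
+ IPv6), while interfaces keep the prefix length of their network.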
+ """ + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._ensure_module() + + def dump(self, obj: Union[Address, Interface]) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + if isinstance(obj, (IPv4Interface, IPv6Interface)): + prefixlen = obj.network.prefixlen + else: + prefixlen = obj.max_prefixlen + + head = bytes((family, prefixlen, 0, len(packed))) + return head + packed + + +class NetworkBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["cidr"].oid + + def dump(self, obj: Network) -> bytes: + packed = obj.network_address.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.prefixlen, 1, len(packed))) + return head + packed + + +class _LazyIpaddressLoader(Loader, _LazyIpaddress): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._ensure_module() + + +class InetLoader(_LazyIpaddressLoader): + def load(self, data: Buffer) -> Union[Address, Interface]: + if isinstance(data, memoryview): + data = bytes(data) + + if b"/" in data: + return ip_interface(data.decode()) + else: + return ip_address(data.decode()) + + +class InetBinaryLoader(_LazyIpaddressLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Union[Address, Interface]: + if isinstance(data, memoryview): + data = bytes(data) + + prefix = data[1] + packed = data[4:] + if data[0] == PGSQL_AF_INET: + if prefix == IPV4_PREFIXLEN: + return IPv4Address(packed) + else: + return IPv4Interface((packed, prefix)) + else: + if prefix == IPV6_PREFIXLEN: + return IPv6Address(packed) + else: + return IPv6Interface((packed, prefix)) + + +class CidrLoader(_LazyIpaddressLoader): + def load(self, data: Buffer) -> Network: + if isinstance(data, memoryview): + data = bytes(data) + + return ip_network(data.decode()) + + +class CidrBinaryLoader(_LazyIpaddressLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Network: + if isinstance(data, memoryview): + data = bytes(data) + + prefix = data[1] + packed = data[4:] + if data[0] == PGSQL_AF_INET: + return IPv4Network((packed, prefix)) + else: + return IPv6Network((packed, prefix)) + + return ip_network(data.decode()) + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("ipaddress.IPv4Address", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Address", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Interface", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper) + adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper) + adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper) + adapters.register_dumper(None, InetBinaryDumper) + adapters.register_loader("inet", InetLoader) + adapters.register_loader("inet", InetBinaryLoader) + adapters.register_loader("cidr", CidrLoader) + adapters.register_loader("cidr", CidrBinaryLoader) diff --git 
a/psycopg/psycopg/types/none.py b/psycopg/psycopg/types/none.py new file mode 100644 index 0000000..2ab857c --- /dev/null +++ b/psycopg/psycopg/types/none.py @@ -0,0 +1,25 @@ +""" +Adapters for None. +""" + +# Copyright (C) 2020 The Psycopg Team + +from ..abc import AdaptContext, NoneType +from ..adapt import Dumper + + +class NoneDumper(Dumper): + """ + Not a complete dumper as it doesn't implement dump(), but it implements + quote(), so it can be used in sql composition. + """ + + def dump(self, obj: None) -> bytes: + raise NotImplementedError("NULL is passed to Postgres in other ways") + + def quote(self, obj: None) -> bytes: + return b"NULL" + + +def register_default_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(NoneType, NoneDumper) diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py new file mode 100644 index 0000000..1bd9329 --- /dev/null +++ b/psycopg/psycopg/types/numeric.py @@ -0,0 +1,515 @@ +""" +Adapers for numeric types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import struct +from math import log +from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast +from decimal import Decimal, DefaultContext, Context + +from .. import postgres +from .. import errors as e +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader, PyFormat +from .._struct import pack_int2, pack_uint2, unpack_int2 +from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4 +from .._struct import pack_int8, unpack_int8 +from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8 + +# Exposed here +from .._wrappers import ( + Int2 as Int2, + Int4 as Int4, + Int8 as Int8, + IntNumeric as IntNumeric, + Oid as Oid, + Float4 as Float4, + Float8 as Float8, +) + + +class _IntDumper(Dumper): + def dump(self, obj: Any) -> Buffer: + t = type(obj) + if t is not int: + # Convert to int in order to dump IntEnum correctly + if issubclass(t, int): + obj = int(obj) + else: + raise e.DataError(f"integer expected, got {type(obj).__name__!r}") + + return str(obj).encode() + + def quote(self, obj: Any) -> Buffer: + value = self.dump(obj) + return value if obj >= 0 else b" " + value + + +class _SpecialValuesDumper(Dumper): + + _special: Dict[bytes, bytes] = {} + + def dump(self, obj: Any) -> bytes: + return str(obj).encode() + + def quote(self, obj: Any) -> bytes: + value = self.dump(obj) + + if value in self._special: + return self._special[value] + + return value if obj >= 0 else b" " + value + + +class FloatDumper(_SpecialValuesDumper): + + oid = postgres.types["float8"].oid + + _special = { + b"inf": b"'Infinity'::float8", + b"-inf": b"'-Infinity'::float8", + b"nan": b"'NaN'::float8", + } + + +class Float4Dumper(FloatDumper): + oid = postgres.types["float4"].oid + + +class FloatBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["float8"].oid + + def dump(self, obj: float) -> bytes: + return pack_float8(obj) + + +class Float4BinaryDumper(FloatBinaryDumper): + + oid = postgres.types["float4"].oid + + def dump(self, obj: float) -> bytes: + return pack_float4(obj) + + +class DecimalDumper(_SpecialValuesDumper): + + oid = postgres.types["numeric"].oid + + def dump(self, obj: Decimal) -> bytes: + if obj.is_nan(): + # cover NaN and sNaN + return b"NaN" + else: + return str(obj).encode() + + _special = { + b"Infinity": b"'Infinity'::numeric", + b"-Infinity": b"'-Infinity'::numeric", + b"NaN": b"'NaN'::numeric", + } + + +class Int2Dumper(_IntDumper): + oid = 
postgres.types["int2"].oid + + +class Int4Dumper(_IntDumper): + oid = postgres.types["int4"].oid + + +class Int8Dumper(_IntDumper): + oid = postgres.types["int8"].oid + + +class IntNumericDumper(_IntDumper): + oid = postgres.types["numeric"].oid + + +class OidDumper(_IntDumper): + oid = postgres.types["oid"].oid + + +class IntDumper(Dumper): + def dump(self, obj: Any) -> bytes: + raise TypeError( + f"{type(self).__name__} is a dispatcher to other dumpers:" + " dump() is not supposed to be called" + ) + + def get_key(self, obj: int, format: PyFormat) -> type: + return self.upgrade(obj, format).cls + + _int2_dumper = Int2Dumper(Int2) + _int4_dumper = Int4Dumper(Int4) + _int8_dumper = Int8Dumper(Int8) + _int_numeric_dumper = IntNumericDumper(IntNumeric) + + def upgrade(self, obj: int, format: PyFormat) -> Dumper: + if -(2**31) <= obj < 2**31: + if -(2**15) <= obj < 2**15: + return self._int2_dumper + else: + return self._int4_dumper + else: + if -(2**63) <= obj < 2**63: + return self._int8_dumper + else: + return self._int_numeric_dumper + + +class Int2BinaryDumper(Int2Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int2(obj) + + +class Int4BinaryDumper(Int4Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int4(obj) + + +class Int8BinaryDumper(Int8Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int8(obj) + + +# Ratio between number of bits required to store a number and number of pg +# decimal digits required. +BIT_PER_PGDIGIT = log(2) / log(10_000) + + +class IntNumericBinaryDumper(IntNumericDumper): + + format = Format.BINARY + + def dump(self, obj: int) -> Buffer: + return dump_int_to_numeric_binary(obj) + + +class OidBinaryDumper(OidDumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_uint4(obj) + + +class IntBinaryDumper(IntDumper): + + format = Format.BINARY + + _int2_dumper = Int2BinaryDumper(Int2) + _int4_dumper = Int4BinaryDumper(Int4) + _int8_dumper = Int8BinaryDumper(Int8) + _int_numeric_dumper = IntNumericBinaryDumper(IntNumeric) + + +class IntLoader(Loader): + def load(self, data: Buffer) -> int: + # it supports bytes directly + return int(data) + + +class Int2BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int2(data)[0] + + +class Int4BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int4(data)[0] + + +class Int8BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int8(data)[0] + + +class OidBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_uint4(data)[0] + + +class FloatLoader(Loader): + def load(self, data: Buffer) -> float: + # it supports bytes directly + return float(data) + + +class Float4BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> float: + return unpack_float4(data)[0] + + +class Float8BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> float: + return unpack_float8(data)[0] + + +class NumericLoader(Loader): + def load(self, data: Buffer) -> Decimal: + if isinstance(data, memoryview): + data = bytes(data) + return Decimal(data.decode()) + + +DEC_DIGITS = 4 # decimal digits per Postgres "digit" +NUMERIC_POS = 0x0000 +NUMERIC_NEG = 0x4000 +NUMERIC_NAN = 0xC000 +NUMERIC_PINF = 0xD000 +NUMERIC_NINF = 0xF000 + +_decimal_special = { 
+ NUMERIC_NAN: Decimal("NaN"), + NUMERIC_PINF: Decimal("Infinity"), + NUMERIC_NINF: Decimal("-Infinity"), +} + + +class _ContextMap(DefaultDict[int, Context]): + """ + Cache for decimal contexts to use when the precision requires it. + + Note: if the default context is used (prec=28) you can get an invalid + operation or a rounding to 0: + + - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000') + - Decimal(1000).shift(25) = Decimal('0') + - Decimal(1000).shift(30) raises InvalidOperation + """ + + def __missing__(self, key: int) -> Context: + val = Context(prec=key) + self[key] = val + return val + + +_contexts = _ContextMap() +for i in range(DefaultContext.prec): + _contexts[i] = DefaultContext + +_unpack_numeric_head = cast( + Callable[[Buffer], Tuple[int, int, int, int]], + struct.Struct("!HhHH").unpack_from, +) +_pack_numeric_head = cast( + Callable[[int, int, int, int], bytes], + struct.Struct("!HhHH").pack, +) + + +class NumericBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Decimal: + ndigits, weight, sign, dscale = _unpack_numeric_head(data) + if sign == NUMERIC_POS or sign == NUMERIC_NEG: + val = 0 + for i in range(8, len(data), 2): + val = val * 10_000 + data[i] * 0x100 + data[i + 1] + + shift = dscale - (ndigits - weight - 1) * DEC_DIGITS + ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale] + return ( + Decimal(val if sign == NUMERIC_POS else -val) + .scaleb(-dscale, ctx) + .shift(shift, ctx) + ) + else: + try: + return _decimal_special[sign] + except KeyError: + raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None + + +NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0) +NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0) +NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0) + + +class DecimalBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["numeric"].oid + + def dump(self, obj: Decimal) -> Buffer: + return dump_decimal_to_numeric_binary(obj) + + +class NumericDumper(DecimalDumper): + def dump(self, obj: Union[Decimal, int]) -> bytes: + if isinstance(obj, int): + return str(obj).encode() + else: + return super().dump(obj) + + +class NumericBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["numeric"].oid + + def dump(self, obj: Union[Decimal, int]) -> Buffer: + if isinstance(obj, int): + return dump_int_to_numeric_binary(obj) + else: + return dump_decimal_to_numeric_binary(obj) + + +def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]: + sign, digits, exp = obj.as_tuple() + if exp == "n" or exp == "N": # type: ignore[comparison-overlap] + return NUMERIC_NAN_BIN + elif exp == "F": # type: ignore[comparison-overlap] + return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN + + # Weights of py digits into a pg digit according to their positions. + # Starting with an index wi != 0 is equivalent to prepending 0's to + # the digits tuple, but without really changing it. + weights = (1000, 100, 10, 1) + wi = 0 + + ndigits = nzdigits = len(digits) + + # Find the last nonzero digit + while nzdigits > 0 and digits[nzdigits - 1] == 0: + nzdigits -= 1 + + if exp <= 0: + dscale = -exp + else: + dscale = 0 + # align the py digits to the pg digits if there's some py exponent + ndigits += exp % DEC_DIGITS + + if not nzdigits: + return _pack_numeric_head(0, 0, NUMERIC_POS, dscale) + + # Equivalent of 0-padding left to align the py digits to the pg digits + # but without changing the digits tuple. 
+ mod = (ndigits - dscale) % DEC_DIGITS + if mod: + wi = DEC_DIGITS - mod + ndigits += wi + + tmp = nzdigits + wi + out = bytearray( + _pack_numeric_head( + tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1), # ndigits + (ndigits + exp) // DEC_DIGITS - 1, # weight + NUMERIC_NEG if sign else NUMERIC_POS, # sign + dscale, + ) + ) + + pgdigit = 0 + for i in range(nzdigits): + pgdigit += weights[wi] * digits[i] + wi += 1 + if wi >= DEC_DIGITS: + out += pack_uint2(pgdigit) + pgdigit = wi = 0 + + if pgdigit: + out += pack_uint2(pgdigit) + + return out + + +def dump_int_to_numeric_binary(obj: int) -> bytearray: + ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1 + out = bytearray(b"\x00\x00" * (ndigits + 4)) + if obj < 0: + sign = NUMERIC_NEG + obj = -obj + else: + sign = NUMERIC_POS + + out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0) + i = 8 + (ndigits - 1) * 2 + while obj: + rem = obj % 10_000 + obj //= 10_000 + out[i : i + 2] = pack_uint2(rem) + i -= 2 + + return out + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(int, IntDumper) + adapters.register_dumper(int, IntBinaryDumper) + adapters.register_dumper(float, FloatDumper) + adapters.register_dumper(float, FloatBinaryDumper) + adapters.register_dumper(Int2, Int2Dumper) + adapters.register_dumper(Int4, Int4Dumper) + adapters.register_dumper(Int8, Int8Dumper) + adapters.register_dumper(IntNumeric, IntNumericDumper) + adapters.register_dumper(Oid, OidDumper) + + # The binary dumper is currently some 30% slower, so default to text + # (see tests/scripts/testdec.py for a rough benchmark) + # Also, must be after IntNumericDumper + adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper) + adapters.register_dumper("decimal.Decimal", DecimalDumper) + + # Used only by oid, can take both int and Decimal as input + adapters.register_dumper(None, NumericBinaryDumper) + adapters.register_dumper(None, NumericDumper) + + adapters.register_dumper(Float4, Float4Dumper) + adapters.register_dumper(Float8, FloatDumper) + adapters.register_dumper(Int2, Int2BinaryDumper) + adapters.register_dumper(Int4, Int4BinaryDumper) + adapters.register_dumper(Int8, Int8BinaryDumper) + adapters.register_dumper(Oid, OidBinaryDumper) + adapters.register_dumper(Float4, Float4BinaryDumper) + adapters.register_dumper(Float8, FloatBinaryDumper) + adapters.register_loader("int2", IntLoader) + adapters.register_loader("int4", IntLoader) + adapters.register_loader("int8", IntLoader) + adapters.register_loader("oid", IntLoader) + adapters.register_loader("int2", Int2BinaryLoader) + adapters.register_loader("int4", Int4BinaryLoader) + adapters.register_loader("int8", Int8BinaryLoader) + adapters.register_loader("oid", OidBinaryLoader) + adapters.register_loader("float4", FloatLoader) + adapters.register_loader("float8", FloatLoader) + adapters.register_loader("float4", Float4BinaryLoader) + adapters.register_loader("float8", Float8BinaryLoader) + adapters.register_loader("numeric", NumericLoader) + adapters.register_loader("numeric", NumericBinaryLoader) diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py new file mode 100644 index 0000000..c418480 --- /dev/null +++ b/psycopg/psycopg/types/range.py @@ -0,0 +1,700 @@ +""" +Support for range types adaptation. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +import re +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple +from typing import cast +from decimal import Decimal +from datetime import date, datetime + +from .. import errors as e +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext, Buffer, Dumper, DumperKey +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._struct import pack_len, unpack_len +from ..postgres import INVALID_OID, TEXT_OID +from .._typeinfo import RangeInfo as RangeInfo # exported here + +RANGE_EMPTY = 0x01 # range is empty +RANGE_LB_INC = 0x02 # lower bound is inclusive +RANGE_UB_INC = 0x04 # upper bound is inclusive +RANGE_LB_INF = 0x08 # lower bound is -infinity +RANGE_UB_INF = 0x10 # upper bound is +infinity + +_EMPTY_HEAD = bytes([RANGE_EMPTY]) + +T = TypeVar("T") + + +class Range(Generic[T]): + """Python representation for a PostgreSQL range type. + + :param lower: lower bound for the range. `!None` means unbound + :param upper: upper bound for the range. `!None` means unbound + :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``, + representing whether the lower or upper bounds are included + :param empty: if `!True`, the range is empty + + """ + + __slots__ = ("_lower", "_upper", "_bounds") + + def __init__( + self, + lower: Optional[T] = None, + upper: Optional[T] = None, + bounds: str = "[)", + empty: bool = False, + ): + if not empty: + if bounds not in ("[)", "(]", "()", "[]"): + raise ValueError("bound flags not valid: %r" % bounds) + + self._lower = lower + self._upper = upper + + # Make bounds consistent with infs + if lower is None and bounds[0] == "[": + bounds = "(" + bounds[1] + if upper is None and bounds[1] == "]": + bounds = bounds[0] + ")" + + self._bounds = bounds + else: + self._lower = self._upper = None + self._bounds = "" + + def __repr__(self) -> str: + if self._bounds: + args = f"{self._lower!r}, {self._upper!r}, {self._bounds!r}" + else: + args = "empty=True" + + return f"{self.__class__.__name__}({args})" + + def __str__(self) -> str: + if not self._bounds: + return "empty" + + items = [ + self._bounds[0], + str(self._lower), + ", ", + str(self._upper), + self._bounds[1], + ] + return "".join(items) + + @property + def lower(self) -> Optional[T]: + """The lower bound of the range. `!None` if empty or unbound.""" + return self._lower + + @property + def upper(self) -> Optional[T]: + """The upper bound of the range. 
`!None` if empty or unbound.""" + return self._upper + + @property + def bounds(self) -> str: + """The bounds string (two characters from '[', '(', ']', ')').""" + return self._bounds + + @property + def isempty(self) -> bool: + """`!True` if the range is empty.""" + return not self._bounds + + @property + def lower_inf(self) -> bool: + """`!True` if the range doesn't have a lower bound.""" + if not self._bounds: + return False + return self._lower is None + + @property + def upper_inf(self) -> bool: + """`!True` if the range doesn't have an upper bound.""" + if not self._bounds: + return False + return self._upper is None + + @property + def lower_inc(self) -> bool: + """`!True` if the lower bound is included in the range.""" + if not self._bounds or self._lower is None: + return False + return self._bounds[0] == "[" + + @property + def upper_inc(self) -> bool: + """`!True` if the upper bound is included in the range.""" + if not self._bounds or self._upper is None: + return False + return self._bounds[1] == "]" + + def __contains__(self, x: T) -> bool: + if not self._bounds: + return False + + if self._lower is not None: + if self._bounds[0] == "[": + # It doesn't seem that Python has an ABC for ordered types. + if x < self._lower: # type: ignore[operator] + return False + else: + if x <= self._lower: # type: ignore[operator] + return False + + if self._upper is not None: + if self._bounds[1] == "]": + if x > self._upper: # type: ignore[operator] + return False + else: + if x >= self._upper: # type: ignore[operator] + return False + + return True + + def __bool__(self) -> bool: + return bool(self._bounds) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Range): + return False + return ( + self._lower == other._lower + and self._upper == other._upper + and self._bounds == other._bounds + ) + + def __hash__(self) -> int: + return hash((self._lower, self._upper, self._bounds)) + + # as the postgres docs describe for the server-side stuff, + # ordering is rather arbitrary, but will remain stable + # and consistent. + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Range): + return NotImplemented + for attr in ("_lower", "_upper", "_bounds"): + self_value = getattr(self, attr) + other_value = getattr(other, attr) + if self_value == other_value: + pass + elif self_value is None: + return True + elif other_value is None: + return False + else: + return cast(bool, self_value < other_value) + return False + + def __le__(self, other: Any) -> bool: + return self == other or self < other # type: ignore + + def __gt__(self, other: Any) -> bool: + if isinstance(other, Range): + return other < self + else: + return NotImplemented + + def __ge__(self, other: Any) -> bool: + return self == other or self > other # type: ignore + + def __getstate__(self) -> Dict[str, Any]: + return { + slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot) + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + for slot, value in state.items(): + setattr(self, slot, value) + + +# Subclasses to specify a specific subtype. Usually not needed: only needed +# in binary copy, where switching to text is not an option. 
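[Illustrative note, not part of the patch] The ``Range`` wrapper defined above can be exercised on its own; a minimal sketch using only the constructor, ``in`` operator and properties shown in this file::

    # Illustrative sketch only: demonstrates the Range class defined above.
    from psycopg.types.range import Range

    r = Range(10, 20)               # default bounds "[)": lower included, upper excluded
    assert 10 in r and 20 not in r
    assert r.lower_inc and not r.upper_inc

    unbound = Range(10, None)       # upper bound left unbound
    assert unbound.upper_inf and 1_000_000 in unbound

    empty = Range(empty=True)
    assert empty.isempty and 10 not in empty and not bool(empty)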
+ + +class Int4Range(Range[int]): + pass + + +class Int8Range(Range[int]): + pass + + +class NumericRange(Range[Decimal]): + pass + + +class DateRange(Range[date]): + pass + + +class TimestampRange(Range[datetime]): + pass + + +class TimestamptzRange(Range[datetime]): + pass + + +class BaseRangeDumper(RecursiveDumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + self._adapt_format = PyFormat.from_pq(self.format) + + def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey: + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Range: + return self.cls + + item = self._get_item(obj) + if item is not None: + sd = self._tx.get_dumper(item, self._adapt_format) + return (self.cls, sd.get_key(item, format)) + else: + return (self.cls,) + + def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper": + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Range: + return self + + item = self._get_item(obj) + if item is None: + return RangeDumper(self.cls) + + dumper: BaseRangeDumper + if type(item) is int: + # postgres won't cast int4range -> int8range so we must use + # text format and unknown oid here + sd = self._tx.get_dumper(item, PyFormat.TEXT) + dumper = RangeDumper(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = INVALID_OID + return dumper + + sd = self._tx.get_dumper(item, format) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + if sd.oid == INVALID_OID and isinstance(item, str): + # Work around the normal mapping where text is dumped as unknown + dumper.oid = self._get_range_oid(TEXT_OID) + else: + dumper.oid = self._get_range_oid(sd.oid) + + return dumper + + def _get_item(self, obj: Range[Any]) -> Any: + """ + Return a member representative of the range + """ + rv = obj.lower + return rv if rv is not None else obj.upper + + def _get_range_oid(self, sub_oid: int) -> int: + """ + Return the oid of the range from the oid of its elements. + """ + info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid) + return info.oid if info else INVALID_OID + + +class RangeDumper(BaseRangeDumper): + """ + Dumper for range types. + + The dumper can upgrade to one specific for a different range type. 
+ """ + + def dump(self, obj: Range[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + return dump_range_text(obj, dump) + + +def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer: + if obj.isempty: + return b"empty" + + parts: List[Buffer] = [b"[" if obj.lower_inc else b"("] + + def dump_item(item: Any) -> Buffer: + ad = dump(item) + if not ad: + return b'""' + elif _re_needs_quotes.search(ad): + return b'"' + _re_esc.sub(rb"\1\1", ad) + b'"' + else: + return ad + + if obj.lower is not None: + parts.append(dump_item(obj.lower)) + + parts.append(b",") + + if obj.upper is not None: + parts.append(dump_item(obj.upper)) + + parts.append(b"]" if obj.upper_inc else b")") + + return b"".join(parts) + + +_re_needs_quotes = re.compile(rb'[",\\\s()\[\]]') +_re_esc = re.compile(rb"([\\\"])") + + +class RangeBinaryDumper(BaseRangeDumper): + + format = Format.BINARY + + def dump(self, obj: Range[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + return dump_range_binary(obj, dump) + + +def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer: + if not obj: + return _EMPTY_HEAD + + out = bytearray([0]) # will replace the head later + + head = 0 + if obj.lower_inc: + head |= RANGE_LB_INC + if obj.upper_inc: + head |= RANGE_UB_INC + + if obj.lower is not None: + data = dump(obj.lower) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_LB_INF + + if obj.upper is not None: + data = dump(obj.upper) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_UB_INF + + out[0] = head + return out + + +def fail_dump(obj: Any) -> Buffer: + raise e.InternalError("trying to dump a range element without information") + + +class BaseRangeLoader(RecursiveLoader, Generic[T]): + """Generic loader for a range. + + Subclasses must specify the oid of the subtype and the class to load. + """ + + subtype_oid: int + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load + + +class RangeLoader(BaseRangeLoader[T]): + def load(self, data: Buffer) -> Range[T]: + return load_range_text(data, self._load)[0] + + +def load_range_text( + data: Buffer, load: Callable[[Buffer], Any] +) -> Tuple[Range[Any], int]: + if data == b"empty": + return Range(empty=True), 5 + + m = _re_range.match(data) + if m is None: + raise e.DataError( + f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'" + ) + + lower = None + item = m.group(3) + if item is None: + item = m.group(2) + if item is not None: + lower = load(_re_undouble.sub(rb"\1", item)) + else: + lower = load(item) + + upper = None + item = m.group(5) + if item is None: + item = m.group(4) + if item is not None: + upper = load(_re_undouble.sub(rb"\1", item)) + else: + upper = load(item) + + bounds = (m.group(1) + m.group(6)).decode() + + return Range(lower, upper, bounds), m.end() + + +_re_range = re.compile( + rb""" + ( \(|\[ ) # lower bound flag + (?: # lower bound: + " ( (?: [^"] | "")* ) " # - a quoted string + | ( [^",]+ ) # - or an unquoted string + )? # - or empty (not caught) + , + (?: # upper bound: + " ( (?: [^"] | "")* ) " # - a quoted string + | ( [^"\)\]]+ ) # - or an unquoted string + )? 
# - or empty (not caught) + ( \)|\] ) # upper bound flag + """, + re.VERBOSE, +) + +_re_undouble = re.compile(rb'(["\\])\1') + + +class RangeBinaryLoader(BaseRangeLoader[T]): + + format = Format.BINARY + + def load(self, data: Buffer) -> Range[T]: + return load_range_binary(data, self._load) + + +def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]: + head = data[0] + if head & RANGE_EMPTY: + return Range(empty=True) + + lb = "[" if head & RANGE_LB_INC else "(" + ub = "]" if head & RANGE_UB_INC else ")" + + pos = 1 # after the head + if head & RANGE_LB_INF: + min = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + min = load(data[pos : pos + length]) + pos += length + + if head & RANGE_UB_INF: + max = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + max = load(data[pos : pos + length]) + pos += length + + return Range(min, max, lb + ub) + + +def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None: + """Register the adapters to load and dump a range type. + + :param info: The object with the information about the range to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + Register loaders so that loading data of this type will result in a `Range` + with bounds parsed as the right subtype. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the requested range available?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[RangeLoader[Any]] = type( + f"{info.name.title()}Loader", + (RangeLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + bloader: Type[RangeBinaryLoader[Any]] = type( + f"{info.name.title()}BinaryLoader", + (RangeBinaryLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, bloader) + + +# Text dumpers for builtin range types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4RangeDumper(RangeDumper): + oid = postgres.types["int4range"].oid + + +class Int8RangeDumper(RangeDumper): + oid = postgres.types["int8range"].oid + + +class NumericRangeDumper(RangeDumper): + oid = postgres.types["numrange"].oid + + +class DateRangeDumper(RangeDumper): + oid = postgres.types["daterange"].oid + + +class TimestampRangeDumper(RangeDumper): + oid = postgres.types["tsrange"].oid + + +class TimestamptzRangeDumper(RangeDumper): + oid = postgres.types["tstzrange"].oid + + +# Binary dumpers for builtin range types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. 
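[Illustrative note, not part of the patch] ``register_range()`` above is the entry point for adapting custom range types; a minimal sketch of typical usage, assuming a server-side type such as ``floatrange`` has already been created (the type name and connection string below are assumptions, not part of this patch)::

    # Illustrative sketch only. Assumes a type created on the server, e.g.:
    #   CREATE TYPE floatrange AS RANGE (subtype = float8);
    import psycopg
    from psycopg.types.range import Range, RangeInfo, register_range

    with psycopg.connect("dbname=test") as conn:    # connection string is an assumption
        info = RangeInfo.fetch(conn, "floatrange")  # look the type up in the catalog
        register_range(info, conn)                  # adapters scoped to this connection

        cur = conn.execute("SELECT %s::floatrange", [Range(1.5, 2.5)])
        r = cur.fetchone()[0]
        assert isinstance(r, Range) and r.lower == 1.5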
+ + +class Int4RangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["int4range"].oid + + +class Int8RangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["int8range"].oid + + +class NumericRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["numrange"].oid + + +class DateRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["daterange"].oid + + +class TimestampRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["tsrange"].oid + + +class TimestamptzRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["tstzrange"].oid + + +# Text loaders for builtin range types + + +class Int4RangeLoader(RangeLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8RangeLoader(RangeLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericRangeLoader(RangeLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateRangeLoader(RangeLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampRangeLoader(RangeLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZRangeLoader(RangeLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +# Binary loaders for builtin range types + + +class Int4RangeBinaryLoader(RangeBinaryLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8RangeBinaryLoader(RangeBinaryLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateRangeBinaryLoader(RangeBinaryLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Range, RangeBinaryDumper) + adapters.register_dumper(Range, RangeDumper) + adapters.register_dumper(Int4Range, Int4RangeDumper) + adapters.register_dumper(Int8Range, Int8RangeDumper) + adapters.register_dumper(NumericRange, NumericRangeDumper) + adapters.register_dumper(DateRange, DateRangeDumper) + adapters.register_dumper(TimestampRange, TimestampRangeDumper) + adapters.register_dumper(TimestamptzRange, TimestamptzRangeDumper) + adapters.register_dumper(Int4Range, Int4RangeBinaryDumper) + adapters.register_dumper(Int8Range, Int8RangeBinaryDumper) + adapters.register_dumper(NumericRange, NumericRangeBinaryDumper) + adapters.register_dumper(DateRange, DateRangeBinaryDumper) + adapters.register_dumper(TimestampRange, TimestampRangeBinaryDumper) + adapters.register_dumper(TimestamptzRange, TimestamptzRangeBinaryDumper) + adapters.register_loader("int4range", Int4RangeLoader) + adapters.register_loader("int8range", Int8RangeLoader) + adapters.register_loader("numrange", NumericRangeLoader) + adapters.register_loader("daterange", DateRangeLoader) + adapters.register_loader("tsrange", TimestampRangeLoader) + adapters.register_loader("tstzrange", TimestampTZRangeLoader) + adapters.register_loader("int4range", Int4RangeBinaryLoader) + adapters.register_loader("int8range", Int8RangeBinaryLoader) + adapters.register_loader("numrange", NumericRangeBinaryLoader) + adapters.register_loader("daterange", DateRangeBinaryLoader) + adapters.register_loader("tsrange", TimestampRangeBinaryLoader) + 
adapters.register_loader("tstzrange", TimestampTZRangeBinaryLoader) diff --git a/psycopg/psycopg/types/shapely.py b/psycopg/psycopg/types/shapely.py new file mode 100644 index 0000000..e99f256 --- /dev/null +++ b/psycopg/psycopg/types/shapely.py @@ -0,0 +1,75 @@ +""" +Adapters for PostGIS geometries +""" + +from typing import Optional + +from .. import postgres +from ..abc import AdaptContext, Buffer +from ..adapt import Dumper, Loader +from ..pq import Format +from .._typeinfo import TypeInfo + + +try: + from shapely.wkb import loads, dumps + from shapely.geometry.base import BaseGeometry + +except ImportError: + raise ImportError( + "The module psycopg.types.shapely requires the package 'Shapely'" + " to be installed" + ) + + +class GeometryBinaryLoader(Loader): + format = Format.BINARY + + def load(self, data: Buffer) -> "BaseGeometry": + if not isinstance(data, bytes): + data = bytes(data) + return loads(data) + + +class GeometryLoader(Loader): + def load(self, data: Buffer) -> "BaseGeometry": + # it's a hex string in binary + if isinstance(data, memoryview): + data = bytes(data) + return loads(data.decode(), hex=True) + + +class BaseGeometryBinaryDumper(Dumper): + format = Format.BINARY + + def dump(self, obj: "BaseGeometry") -> bytes: + return dumps(obj) # type: ignore + + +class BaseGeometryDumper(Dumper): + def dump(self, obj: "BaseGeometry") -> bytes: + return dumps(obj, hex=True).encode() # type: ignore + + +def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + """Register Shapely dumper and loaders.""" + + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the 'postgis' extension loaded?") + + info.register(context) + adapters = context.adapters if context else postgres.adapters + + class GeometryDumper(BaseGeometryDumper): + oid = info.oid + + class GeometryBinaryDumper(BaseGeometryBinaryDumper): + oid = info.oid + + adapters.register_loader(info.oid, GeometryBinaryLoader) + adapters.register_loader(info.oid, GeometryLoader) + # Default binary dump + adapters.register_dumper(BaseGeometry, GeometryDumper) + adapters.register_dumper(BaseGeometry, GeometryBinaryDumper) diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py new file mode 100644 index 0000000..cd5360d --- /dev/null +++ b/psycopg/psycopg/types/string.py @@ -0,0 +1,239 @@ +""" +Adapters for textual types. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Optional, Union, TYPE_CHECKING + +from .. import postgres +from ..pq import Format, Escaping +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader +from ..errors import DataError +from .._encodings import conn_encoding + +if TYPE_CHECKING: + from ..pq.abc import Escaping as EscapingProto + + +class _BaseStrDumper(Dumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + enc = conn_encoding(self.connection) + self._encoding = enc if enc != "ascii" else "utf-8" + + +class _StrBinaryDumper(_BaseStrDumper): + """ + Base class to dump a Python strings to a Postgres text type, in binary format. + + Subclasses shall specify the oids of real types (text, varchar, name...). 
+ """ + + format = Format.BINARY + + def dump(self, obj: str) -> bytes: + # the server will raise DataError subclass if the string contains 0x00 + return obj.encode(self._encoding) + + +class _StrDumper(_BaseStrDumper): + """ + Base class to dump a Python strings to a Postgres text type, in text format. + + Subclasses shall specify the oids of real types (text, varchar, name...). + """ + + def dump(self, obj: str) -> bytes: + if "\x00" in obj: + raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes") + else: + return obj.encode(self._encoding) + + +# The next are concrete dumpers, each one specifying the oid they dump to. + + +class StrBinaryDumper(_StrBinaryDumper): + + oid = postgres.types["text"].oid + + +class StrBinaryDumperVarchar(_StrBinaryDumper): + + oid = postgres.types["varchar"].oid + + +class StrBinaryDumperName(_StrBinaryDumper): + + oid = postgres.types["name"].oid + + +class StrDumper(_StrDumper): + """ + Dumper for strings in text format to the text oid. + + Note that this dumper is not used by default because the type is too strict + and PostgreSQL would require an explicit casts to everything that is not a + text field. However it is useful where the unknown oid is ambiguous and the + text oid is required, for instance with variadic functions. + """ + + oid = postgres.types["text"].oid + + +class StrDumperVarchar(_StrDumper): + + oid = postgres.types["varchar"].oid + + +class StrDumperName(_StrDumper): + + oid = postgres.types["name"].oid + + +class StrDumperUnknown(_StrDumper): + """ + Dumper for strings in text format to the unknown oid. + + This dumper is the default dumper for strings and allows to use Python + strings to represent almost every data type. In a few places, however, the + unknown oid is not accepted (for instance in variadic functions such as + 'concat()'). In that case either a cast on the placeholder ('%s::text') or + the StrTextDumper should be used. + """ + + pass + + +class TextLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + enc = conn_encoding(self.connection) + self._encoding = enc if enc != "ascii" else "" + + def load(self, data: Buffer) -> Union[bytes, str]: + if self._encoding: + if isinstance(data, memoryview): + data = bytes(data) + return data.decode(self._encoding) + else: + # return bytes for SQL_ASCII db + if not isinstance(data, bytes): + data = bytes(data) + return data + + +class TextBinaryLoader(TextLoader): + + format = Format.BINARY + + +class BytesDumper(Dumper): + + oid = postgres.types["bytea"].oid + _qprefix = b"" + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._esc = Escaping(self.connection.pgconn if self.connection else None) + + def dump(self, obj: Buffer) -> Buffer: + return self._esc.escape_bytea(obj) + + def quote(self, obj: Buffer) -> bytes: + escaped = self.dump(obj) + + # We cannot use the base quoting because escape_bytea already returns + # the quotes content. if scs is off it will escape the backslashes in + # the format, otherwise it won't, but it doesn't tell us what quotes to + # use. + if self.connection: + if not self._qprefix: + scs = self.connection.pgconn.parameter_status( + b"standard_conforming_strings" + ) + self._qprefix = b"'" if scs == b"on" else b" E'" + + return self._qprefix + escaped + b"'" + + # We don't have a connection, so someone is using us to generate a file + # to use off-line or something like that. 
PQescapeBytea, like its + # string counterpart, is not predictable whether it will escape + # backslashes. + rv: bytes = b" E'" + escaped + b"'" + if self._esc.escape_bytea(b"\x00") == b"\\000": + rv = rv.replace(b"\\", b"\\\\") + return rv + + +class BytesBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["bytea"].oid + + def dump(self, obj: Buffer) -> Buffer: + return obj + + +class ByteaLoader(Loader): + + _escaping: "EscapingProto" + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + if not hasattr(self.__class__, "_escaping"): + self.__class__._escaping = Escaping() + + def load(self, data: Buffer) -> bytes: + return self._escaping.unescape_bytea(data) + + +class ByteaBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Buffer: + return data + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + + # NOTE: the order the dumpers are registered is relevant. The last one + # registered becomes the default for each type. Usually, binary is the + # default dumper. For text we use the text dumper as default because it + # plays the role of unknown, and it can be cast automatically to other + # types. However, before that, we register dumper with 'text', 'varchar', + # 'name' oids, which will be used when a text dumper is looked up by oid. + adapters.register_dumper(str, StrBinaryDumperName) + adapters.register_dumper(str, StrBinaryDumperVarchar) + adapters.register_dumper(str, StrBinaryDumper) + adapters.register_dumper(str, StrDumperName) + adapters.register_dumper(str, StrDumperVarchar) + adapters.register_dumper(str, StrDumper) + adapters.register_dumper(str, StrDumperUnknown) + + adapters.register_loader(postgres.INVALID_OID, TextLoader) + adapters.register_loader("bpchar", TextLoader) + adapters.register_loader("name", TextLoader) + adapters.register_loader("text", TextLoader) + adapters.register_loader("varchar", TextLoader) + adapters.register_loader('"char"', TextLoader) + adapters.register_loader("bpchar", TextBinaryLoader) + adapters.register_loader("name", TextBinaryLoader) + adapters.register_loader("text", TextBinaryLoader) + adapters.register_loader("varchar", TextBinaryLoader) + adapters.register_loader('"char"', TextBinaryLoader) + + adapters.register_dumper(bytes, BytesDumper) + adapters.register_dumper(bytearray, BytesDumper) + adapters.register_dumper(memoryview, BytesDumper) + adapters.register_dumper(bytes, BytesBinaryDumper) + adapters.register_dumper(bytearray, BytesBinaryDumper) + adapters.register_dumper(memoryview, BytesBinaryDumper) + + adapters.register_loader("bytea", ByteaLoader) + adapters.register_loader(postgres.INVALID_OID, ByteaBinaryLoader) + adapters.register_loader("bytea", ByteaBinaryLoader) diff --git a/psycopg/psycopg/types/uuid.py b/psycopg/psycopg/types/uuid.py new file mode 100644 index 0000000..f92354c --- /dev/null +++ b/psycopg/psycopg/types/uuid.py @@ -0,0 +1,65 @@ +""" +Adapters for the UUID type. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Callable, Optional, TYPE_CHECKING + +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + +if TYPE_CHECKING: + import uuid + +# Importing the uuid module is slow, so import it only on request. 
+UUID: Callable[..., "uuid.UUID"] = None # type: ignore[assignment] + + +class UUIDDumper(Dumper): + + oid = postgres.types["uuid"].oid + + def dump(self, obj: "uuid.UUID") -> bytes: + return obj.hex.encode() + + +class UUIDBinaryDumper(UUIDDumper): + + format = Format.BINARY + + def dump(self, obj: "uuid.UUID") -> bytes: + return obj.bytes + + +class UUIDLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + global UUID + if UUID is None: + from uuid import UUID + + def load(self, data: Buffer) -> "uuid.UUID": + if isinstance(data, memoryview): + data = bytes(data) + return UUID(data.decode()) + + +class UUIDBinaryLoader(UUIDLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> "uuid.UUID": + if isinstance(data, memoryview): + data = bytes(data) + return UUID(bytes=data) + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("uuid.UUID", UUIDDumper) + adapters.register_dumper("uuid.UUID", UUIDBinaryDumper) + adapters.register_loader("uuid", UUIDLoader) + adapters.register_loader("uuid", UUIDBinaryLoader) diff --git a/psycopg/psycopg/version.py b/psycopg/psycopg/version.py new file mode 100644 index 0000000..a98bc35 --- /dev/null +++ b/psycopg/psycopg/version.py @@ -0,0 +1,14 @@ +""" +psycopg distribution version file. +""" + +# Copyright (C) 2020 The Psycopg Team + +# Use a versioning scheme as defined in +# https://www.python.org/dev/peps/pep-0440/ + +# STOP AND READ! if you change: +__version__ = "3.1.7" +# also change: +# - `docs/news.rst` to declare this as the current version or an unreleased one +# - `psycopg_c/psycopg_c/version.py` to the same version. diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py new file mode 100644 index 0000000..7abfc58 --- /dev/null +++ b/psycopg/psycopg/waiting.py @@ -0,0 +1,331 @@ +""" +Code concerned with waiting in different contexts (blocking, async, etc). + +These functions are designed to consume the generators returned by the +`generators` module function and to return their final value. + +""" + +# Copyright (C) 2020 The Psycopg Team + + +import os +import select +import selectors +from typing import Dict, Optional +from asyncio import get_event_loop, wait_for, Event, TimeoutError +from selectors import DefaultSelector + +from . import errors as e +from .abc import RV, PQGen, PQGenConn, WaitFunc +from ._enums import Wait as Wait, Ready as Ready # re-exported +from ._cmodule import _psycopg + +WAIT_R = Wait.R +WAIT_W = Wait.W +WAIT_RW = Wait.RW +READY_R = Ready.R +READY_W = Ready.W +READY_RW = Ready.RW + + +def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using the best strategy available. + + :param gen: a generator performing database operations and yielding + `Ready` values when it would block. + :param fileno: the file descriptor to wait on. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. + :type timeout: float + :return: whatever `!gen` returns on completion. + + Consume `!gen`, scheduling `fileno` for completion when it is reported to + block. Once ready again send the ready state back to `!gen`. 
+ """ + try: + s = next(gen) + with DefaultSelector() as sel: + while True: + sel.register(fileno, s) + rlist = None + while not rlist: + rlist = sel.select(timeout=timeout) + sel.unregister(fileno) + # note: this line should require a cast, but mypy doesn't complain + ready: Ready = rlist[0][1] + assert s & ready + s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: + """ + Wait for a connection generator using the best strategy available. + + :param gen: a generator performing database operations and yielding + (fd, `Ready`) pairs when it would block. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. If zero or None, wait indefinitely. + :type timeout: float + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but take the fileno to wait from the generator + itself, which might change during processing. + """ + try: + fileno, s = next(gen) + if not timeout: + timeout = None + with DefaultSelector() as sel: + while True: + sel.register(fileno, s) + rlist = sel.select(timeout=timeout) + sel.unregister(fileno) + if not rlist: + raise e.ConnectionTimeout("connection timeout expired") + ready: Ready = rlist[0][1] # type: ignore[assignment] + fileno, s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +async def wait_async(gen: PQGen[RV], fileno: int) -> RV: + """ + Coroutine waiting for a generator to complete. + + :param gen: a generator performing database operations and yielding + `Ready` values when it would block. + :param fileno: the file descriptor to wait on. + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but exposing an `asyncio` interface. + """ + # Use an event to block and restart after the fd state changes. + # Not sure this is the best implementation but it's a start. + ev = Event() + loop = get_event_loop() + ready: Ready + s: Wait + + def wakeup(state: Ready) -> None: + nonlocal ready + ready |= state # type: ignore[assignment] + ev.set() + + try: + s = next(gen) + while True: + reader = s & WAIT_R + writer = s & WAIT_W + if not reader and not writer: + raise e.InternalError(f"bad poll status: {s}") + ev.clear() + ready = 0 # type: ignore[assignment] + if reader: + loop.add_reader(fileno, wakeup, READY_R) + if writer: + loop.add_writer(fileno, wakeup, READY_W) + try: + await ev.wait() + finally: + if reader: + loop.remove_reader(fileno) + if writer: + loop.remove_writer(fileno) + s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: + """ + Coroutine waiting for a connection generator to complete. + + :param gen: a generator performing database operations and yielding + (fd, `Ready`) pairs when it would block. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. If zero or None, wait indefinitely. + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but take the fileno to wait from the generator + itself, which might change during processing. + """ + # Use an event to block and restart after the fd state changes. + # Not sure this is the best implementation but it's a start. 
+ ev = Event() + loop = get_event_loop() + ready: Ready + s: Wait + + def wakeup(state: Ready) -> None: + nonlocal ready + ready = state + ev.set() + + try: + fileno, s = next(gen) + if not timeout: + timeout = None + while True: + reader = s & WAIT_R + writer = s & WAIT_W + if not reader and not writer: + raise e.InternalError(f"bad poll status: {s}") + ev.clear() + ready = 0 # type: ignore[assignment] + if reader: + loop.add_reader(fileno, wakeup, READY_R) + if writer: + loop.add_writer(fileno, wakeup, READY_W) + try: + await wait_for(ev.wait(), timeout) + finally: + if reader: + loop.remove_reader(fileno) + if writer: + loop.remove_writer(fileno) + fileno, s = gen.send(ready) + + except TimeoutError: + raise e.ConnectionTimeout("connection timeout expired") + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +# Specialised implementation of wait functions. + + +def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using select where supported. + """ + try: + s = next(gen) + + empty = () + fnlist = (fileno,) + while True: + rl, wl, xl = select.select( + fnlist if s & WAIT_R else empty, + fnlist if s & WAIT_W else empty, + fnlist, + timeout, + ) + ready = 0 + if rl: + ready = READY_R + if wl: + ready |= READY_W + if not ready: + continue + # assert s & ready + s = gen.send(ready) # type: ignore + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +poll_evmasks: Dict[Wait, int] + +if hasattr(selectors, "EpollSelector"): + poll_evmasks = { + WAIT_R: select.EPOLLONESHOT | select.EPOLLIN, + WAIT_W: select.EPOLLONESHOT | select.EPOLLOUT, + WAIT_RW: select.EPOLLONESHOT | select.EPOLLIN | select.EPOLLOUT, + } +else: + poll_evmasks = {} + + +def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using epoll where supported. + + Parameters are like for `wait()`. If it is detected that the best selector + strategy is `epoll` then this function will be used instead of `wait`. + + See also: https://linux.die.net/man/2/epoll_ctl + """ + try: + s = next(gen) + + if timeout is None or timeout < 0: + timeout = 0 + else: + timeout = int(timeout * 1000.0) + + with select.epoll() as epoll: + evmask = poll_evmasks[s] + epoll.register(fileno, evmask) + while True: + fileevs = None + while not fileevs: + fileevs = epoll.poll(timeout) + ev = fileevs[0][1] + ready = 0 + if ev & ~select.EPOLLOUT: + ready = READY_R + if ev & ~select.EPOLLIN: + ready |= READY_W + # assert s & ready + s = gen.send(ready) + evmask = poll_evmasks[s] + epoll.modify(fileno, evmask) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +if _psycopg: + wait_c = _psycopg.wait_c + + +# Choose the best wait strategy for the platform. +# +# the selectors objects have a generic interface but come with some overhead, +# so we also offer more finely tuned implementations. + +wait: WaitFunc + +# Allow the user to choose a specific function for testing +if "PSYCOPG_WAIT_FUNC" in os.environ: + fname = os.environ["PSYCOPG_WAIT_FUNC"] + if not fname.startswith("wait_") or fname not in globals(): + raise ImportError( + "PSYCOPG_WAIT_FUNC should be the name of an available wait function;" + f" got {fname!r}" + ) + wait = globals()[fname] + +elif _psycopg: + wait = wait_c + +elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None): + # On Windows, SelectSelector should be the default. 
+ wait = wait_select + +elif selectors.DefaultSelector is getattr(selectors, "EpollSelector", None): + # NOTE: select seems more performing than epoll. It is admittedly unlikely + # that a platform has epoll but not select, so maybe we could kill + # wait_epoll altogether(). More testing to do. + wait = wait_select if hasattr(selectors, "SelectSelector") else wait_epoll + +elif selectors.DefaultSelector is getattr(selectors, "KqueueSelector", None): + # wait_select is faster than wait_selector, probably because of less overhead + wait = wait_select if hasattr(selectors, "SelectSelector") else wait_selector + +else: + wait = wait_selector diff --git a/psycopg/pyproject.toml b/psycopg/pyproject.toml new file mode 100644 index 0000000..21e410c --- /dev/null +++ b/psycopg/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=49.2.0", "wheel>=0.37"] +build-backend = "setuptools.build_meta" diff --git a/psycopg/setup.cfg b/psycopg/setup.cfg new file mode 100644 index 0000000..fdcb612 --- /dev/null +++ b/psycopg/setup.cfg @@ -0,0 +1,47 @@ +[metadata] +name = psycopg +description = PostgreSQL database adapter for Python +url = https://psycopg.org/psycopg3/ +author = Daniele Varrazzo +author_email = daniele.varrazzo@gmail.com +license = GNU Lesser General Public License v3 (LGPLv3) + +project_urls = + Homepage = https://psycopg.org/ + Code = https://github.com/psycopg/psycopg + Issue Tracker = https://github.com/psycopg/psycopg/issues + Download = https://pypi.org/project/psycopg/ + +classifiers = + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) + Operating System :: MacOS :: MacOS X + Operating System :: Microsoft :: Windows + Operating System :: POSIX + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Topic :: Database + Topic :: Database :: Front-Ends + Topic :: Software Development + Topic :: Software Development :: Libraries :: Python Modules + +long_description = file: README.rst +long_description_content_type = text/x-rst +license_files = LICENSE.txt + +[options] +python_requires = >= 3.7 +packages = find: +zip_safe = False +install_requires = + backports.zoneinfo >= 0.2.0; python_version < "3.9" + typing-extensions >= 4.1 + tzdata; sys_platform == "win32" + +[options.package_data] +psycopg = py.typed diff --git a/psycopg/setup.py b/psycopg/setup.py new file mode 100644 index 0000000..90d4380 --- /dev/null +++ b/psycopg/setup.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" +PostgreSQL database adapter for Python - pure Python package +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +from setuptools import setup + +# Move to the directory of setup.py: executing this file from another location +# (e.g. from the project root) will fail +here = os.path.abspath(os.path.dirname(__file__)) +if os.path.abspath(os.getcwd()) != here: + os.chdir(here) + +# Only for release 3.1.7. Not building binary packages because Scaleway +# has no runner available, but psycopg-binary 3.1.6 should work as well +# as the only change is in rows.py. 
+version = "3.1.7" +ext_versions = ">= 3.1.6, <= 3.1.7" + +extras_require = { + # Install the C extension module (requires dev tools) + "c": [ + f"psycopg-c {ext_versions}", + ], + # Install the stand-alone C extension module + "binary": [ + f"psycopg-binary {ext_versions}", + ], + # Install the connection pool + "pool": [ + "psycopg-pool", + ], + # Requirements to run the test suite + "test": [ + "mypy >= 0.990", + "pproxy >= 2.7", + "pytest >= 6.2.5", + "pytest-asyncio >= 0.17", + "pytest-cov >= 3.0", + "pytest-randomly >= 3.10", + ], + # Requirements needed for development + "dev": [ + "black >= 22.3.0", + "dnspython >= 2.1", + "flake8 >= 4.0", + "mypy >= 0.990", + "types-setuptools >= 57.4", + "wheel >= 0.37", + ], + # Requirements needed to build the documentation + "docs": [ + "Sphinx >= 5.0", + "furo == 2022.6.21", + "sphinx-autobuild >= 2021.3.14", + "sphinx-autodoc-typehints >= 1.12", + ], +} + +setup( + version=version, + extras_require=extras_require, +) diff --git a/psycopg_c/.flake8 b/psycopg_c/.flake8 new file mode 100644 index 0000000..2ae629c --- /dev/null +++ b/psycopg_c/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +ignore = W503, E203 diff --git a/psycopg_c/LICENSE.txt b/psycopg_c/LICENSE.txt new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/psycopg_c/LICENSE.txt @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. 
+ + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. 
If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/psycopg_c/README-binary.rst b/psycopg_c/README-binary.rst new file mode 100644 index 0000000..9318d57 --- /dev/null +++ b/psycopg_c/README-binary.rst @@ -0,0 +1,29 @@ +Psycopg 3: PostgreSQL database adapter for Python - binary package +================================================================== + +This distribution contains the precompiled optimization package +``psycopg_binary``. + +You shouldn't install this package directly: use instead :: + + pip install "psycopg[binary]" + +to install a version of the optimization package matching the ``psycopg`` +version installed. + +Installing this package requires pip >= 20.3 or newer installed. + +This package is not available for every platform: check out `Binary +installation`__ in the documentation. + +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + #binary-installation + +Please read `the project readme`__ and `the installation documentation`__ for +more details. + +.. __: https://github.com/psycopg/psycopg#readme +.. 
__: https://www.psycopg.org/psycopg3/docs/basic/install.html + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg_c/README.rst b/psycopg_c/README.rst new file mode 100644 index 0000000..de9ba93 --- /dev/null +++ b/psycopg_c/README.rst @@ -0,0 +1,33 @@ +Psycopg 3: PostgreSQL database adapter for Python - optimisation package +======================================================================== + +This distribution contains the optional optimization package ``psycopg_c``. + +You shouldn't install this package directly: use instead :: + + pip install "psycopg[c]" + +to install a version of the optimization package matching the ``psycopg`` +version installed. + +Installing this package requires some prerequisites: check `Local +installation`__ in the documentation. Without a C compiler and some library +headers install *will fail*: this is not a bug. + +If you are unable to meet the prerequisite needed you might want to install +``psycopg[binary]`` instead: look for `Binary installation`__ in the +documentation. + +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + #local-installation +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + #binary-installation + +Please read `the project readme`__ and `the installation documentation`__ for +more details. + +.. __: https://github.com/psycopg/psycopg#readme +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg_c/psycopg_c/.gitignore b/psycopg_c/psycopg_c/.gitignore new file mode 100644 index 0000000..36edb64 --- /dev/null +++ b/psycopg_c/psycopg_c/.gitignore @@ -0,0 +1,4 @@ +/*.so +_psycopg.c +pq.c +*.html diff --git a/psycopg_c/psycopg_c/__init__.py b/psycopg_c/psycopg_c/__init__.py new file mode 100644 index 0000000..14db92b --- /dev/null +++ b/psycopg_c/psycopg_c/__init__.py @@ -0,0 +1,14 @@ +""" +psycopg -- PostgreSQL database adapter for Python -- C optimization package +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys + +# This package shouldn't be imported before psycopg itself, or weird things +# will happen +if "psycopg" not in sys.modules: + raise ImportError("the psycopg package should be imported before psycopg_c") + +from .version import __version__ as __version__ # noqa diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi new file mode 100644 index 0000000..bd7c63d --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -0,0 +1,84 @@ +""" +Stub representaton of the public objects exposed by the _psycopg module. + +TODO: this should be generated by mypy's stubgen but it crashes with no +information. Will submit a bug. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Iterable, List, Optional, Sequence, Tuple + +from psycopg import pq +from psycopg import abc +from psycopg.rows import Row, RowMaker +from psycopg.adapt import AdaptersMap, PyFormat +from psycopg.pq.abc import PGconn, PGresult +from psycopg.connection import BaseConnection +from psycopg._compat import Deque + +class Transformer(abc.AdaptContext): + types: Optional[Tuple[int, ...]] + formats: Optional[List[pq.Format]] + def __init__(self, context: Optional[abc.AdaptContext] = None): ... + @classmethod + def from_context(cls, context: Optional[abc.AdaptContext]) -> "Transformer": ... + @property + def connection(self) -> Optional[BaseConnection[Any]]: ... + @property + def encoding(self) -> str: ... + @property + def adapters(self) -> AdaptersMap: ... 
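+    # Editor's note (illustrative only, not part of the stub): the C
+    # Transformer mirrors psycopg's pure-Python Transformer, so typical use
+    # goes through an adaptation context, e.g.:
+    #
+    #   tx = Transformer.from_context(conn)            # conn: an AdaptContext
+    #   dumper = tx.get_dumper("hello", PyFormat.TEXT)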
+ @property + def pgresult(self) -> Optional[PGresult]: ... + def set_pgresult( + self, + result: Optional["PGresult"], + *, + set_loaders: bool = True, + format: Optional[pq.Format] = None, + ) -> None: ... + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ... + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ... + def dump_sequence( + self, params: Sequence[Any], formats: Sequence[PyFormat] + ) -> Sequence[Optional[abc.Buffer]]: ... + def as_literal(self, obj: Any) -> bytes: ... + def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ... + def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ... + def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ... + def load_sequence( + self, record: Sequence[Optional[abc.Buffer]] + ) -> Tuple[Any, ...]: ... + def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ... + +# Generators +def connect(conninfo: str) -> abc.PQGenConn[PGconn]: ... +def execute(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ... +def send(pgconn: PGconn) -> abc.PQGen[None]: ... +def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ... +def fetch(pgconn: PGconn) -> abc.PQGen[Optional[PGresult]]: ... +def pipeline_communicate( + pgconn: PGconn, commands: Deque[abc.PipelineCommand] +) -> abc.PQGen[List[List[PGresult]]]: ... +def wait_c( + gen: abc.PQGen[abc.RV], fileno: int, timeout: Optional[float] = None +) -> abc.RV: ... + +# Copy support +def format_row_text( + row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None +) -> bytearray: ... +def format_row_binary( + row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None +) -> bytearray: ... +def parse_row_text(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ... +def parse_row_binary(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ... + +# Arrays optimization +def array_load_text( + data: abc.Buffer, loader: abc.Loader, delimiter: bytes = b"," +) -> List[Any]: ... +def array_load_binary(data: abc.Buffer, tx: abc.Transformer) -> List[Any]: ... + +# vim: set syntax=python: diff --git a/psycopg_c/psycopg_c/_psycopg.pyx b/psycopg_c/psycopg_c/_psycopg.pyx new file mode 100644 index 0000000..9d2b8ba --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg.pyx @@ -0,0 +1,48 @@ +""" +psycopg_c._psycopg optimization module. + +The module contains optimized C code used in preference to Python code +if a compiler is available. 
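+
+Editor's note (a minimal sketch, not part of the original docstring): psycopg
+decides at import time whether these optimizations are available; the
+selection is essentially of this shape::
+
+    try:
+        from psycopg_c import _psycopg   # compiled fast path
+    except ImportError:
+        _psycopg = None                  # pure Python fallback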
+""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg_c cimport pq +from psycopg_c.pq cimport libpq +from psycopg_c._psycopg cimport oids + +import logging + +from psycopg.pq import Format as _pq_Format +from psycopg._enums import PyFormat as _py_Format + +logger = logging.getLogger("psycopg") + +PQ_TEXT = _pq_Format.TEXT +PQ_BINARY = _pq_Format.BINARY + +PG_AUTO = _py_Format.AUTO +PG_TEXT = _py_Format.TEXT +PG_BINARY = _py_Format.BINARY + + +cdef extern from *: + """ +#ifndef ARRAYSIZE +#define ARRAYSIZE(a) ((sizeof(a) / sizeof(*(a)))) +#endif + """ + int ARRAYSIZE(void *array) + + +include "_psycopg/adapt.pyx" +include "_psycopg/copy.pyx" +include "_psycopg/generators.pyx" +include "_psycopg/transform.pyx" +include "_psycopg/waiting.pyx" + +include "types/array.pyx" +include "types/datetime.pyx" +include "types/numeric.pyx" +include "types/bool.pyx" +include "types/string.pyx" diff --git a/psycopg_c/psycopg_c/_psycopg/__init__.pxd b/psycopg_c/psycopg_c/_psycopg/__init__.pxd new file mode 100644 index 0000000..db22deb --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/__init__.pxd @@ -0,0 +1,9 @@ +""" +psycopg_c._psycopg cython module. + +This file is necessary to allow c-importing pxd files from this directory. +""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg_c._psycopg cimport oids diff --git a/psycopg_c/psycopg_c/_psycopg/adapt.pyx b/psycopg_c/psycopg_c/_psycopg/adapt.pyx new file mode 100644 index 0000000..a6d8e6a --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/adapt.pyx @@ -0,0 +1,171 @@ +""" +C implementation of the adaptation system. + +This module maps each Python adaptation function to a C adaptation function. +Notice that C adaptation functions have a different signature because they can +avoid making a memory copy, however this makes impossible to expose them to +Python. + +This module exposes facilities to map the builtin adapters in python to +equivalent C implementations. + +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any + +cimport cython + +from libc.string cimport memcpy, memchr +from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize +from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_AS_STRING + +from psycopg_c.pq cimport _buffer_as_string_and_size, Escaping + +from psycopg import errors as e +from psycopg.pq.misc import error_message + + +@cython.freelist(8) +cdef class CDumper: + + cdef readonly object cls + cdef pq.PGconn _pgconn + + oid = oids.INVALID_OID + + def __cinit__(self, cls, context: Optional[AdaptContext] = None): + self.cls = cls + conn = context.connection if context is not None else None + self._pgconn = conn.pgconn if conn is not None else None + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + """Store the Postgres representation *obj* into *rv* at *offset* + + Return the number of bytes written to rv or -1 on Python exception. + + Subclasses must implement this method. The `dump()` implementation + transforms the result of this method to a bytearray so that it can be + returned to Python. + + The function interface allows C code to use this method automatically + to create larger buffers, e.g. for copy, composite objects, etc. + + Implementation note: as you will always need to make sure that rv + has enough space to include what you want to dump, `ensure_size()` + might probably come handy. 
+ """ + raise NotImplementedError() + + def dump(self, obj): + """Return the Postgres representation of *obj* as Python array of bytes""" + cdef rv = PyByteArray_FromStringAndSize("", 0) + cdef Py_ssize_t length = self.cdump(obj, rv, 0) + PyByteArray_Resize(rv, length) + return rv + + def quote(self, obj): + cdef char *ptr + cdef char *ptr_out + cdef Py_ssize_t length + + value = self.dump(obj) + + if self._pgconn is not None: + esc = Escaping(self._pgconn) + # escaping and quoting + return esc.escape_literal(value) + + # This path is taken when quote is asked without a connection, + # usually it means by psycopg.sql.quote() or by + # 'Composible.as_string(None)'. Most often than not this is done by + # someone generating a SQL file to consume elsewhere. + + rv = PyByteArray_FromStringAndSize("", 0) + + # No quoting, only quote escaping, random bs escaping. See further. + esc = Escaping() + out = esc.escape_string(value) + + _buffer_as_string_and_size(out, &ptr, &length) + + if not memchr(ptr, b'\\', length): + # If the string has no backslash, the result is correct and we + # don't need to bother with standard_conforming_strings. + PyByteArray_Resize(rv, length + 2) # Must include the quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b"'" + memcpy(ptr_out + 1, ptr, length) + ptr_out[length + 1] = b"'" + return rv + + # The libpq has a crazy behaviour: PQescapeString uses the last + # standard_conforming_strings setting seen on a connection. This + # means that backslashes might be escaped or might not. + # + # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH, + # if scs is off, '\\' raises a warning and '\' is an error. + # + # Check what the libpq does, and if it doesn't escape the backslash + # let's do it on our own. Never mind the race condition. + PyByteArray_Resize(rv, length + 4) # Must include " E'...'" quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + 3, ptr, length) + ptr_out[length + 3] = b"'" + + if esc.escape_string(b"\\") == b"\\": + rv = bytes(rv).replace(b"\\", b"\\\\") + return rv + + cpdef object get_key(self, object obj, object format): + return self.cls + + cpdef object upgrade(self, object obj, object format): + return self + + @staticmethod + cdef char *ensure_size(bytearray ba, Py_ssize_t offset, Py_ssize_t size) except NULL: + """ + Grow *ba*, if necessary, to contains at least *size* bytes after *offset* + + Return the pointer in the bytearray at *offset*, i.e. the place where + you want to write *size* bytes. 
+ """ + cdef Py_ssize_t curr_size = PyByteArray_GET_SIZE(ba) + cdef Py_ssize_t new_size = offset + size + if curr_size < new_size: + PyByteArray_Resize(ba, new_size) + + return PyByteArray_AS_STRING(ba) + offset + + +@cython.freelist(8) +cdef class CLoader: + cdef public libpq.Oid oid + cdef pq.PGconn _pgconn + + def __cinit__(self, int oid, context: Optional[AdaptContext] = None): + self.oid = oid + conn = context.connection if context is not None else None + self._pgconn = conn.pgconn if conn is not None else None + + cdef object cload(self, const char *data, size_t length): + raise NotImplementedError() + + def load(self, object data) -> Any: + cdef char *ptr + cdef Py_ssize_t length + _buffer_as_string_and_size(data, &ptr, &length) + return self.cload(ptr, length) + + +cdef class _CRecursiveLoader(CLoader): + + cdef Transformer _tx + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + self._tx = Transformer.from_context(context) diff --git a/psycopg_c/psycopg_c/_psycopg/copy.pyx b/psycopg_c/psycopg_c/_psycopg/copy.pyx new file mode 100644 index 0000000..b943095 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/copy.pyx @@ -0,0 +1,340 @@ +""" +C optimised functions for the copy system. + +""" + +# Copyright (C) 2020 The Psycopg Team + +from libc.string cimport memcpy +from libc.stdint cimport uint16_t, uint32_t, int32_t +from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize +from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE +from cpython.memoryview cimport PyMemoryView_FromObject + +from psycopg_c._psycopg cimport endian +from psycopg_c.pq cimport ViewBuffer + +from psycopg import errors as e + +cdef int32_t _binary_null = -1 + + +def format_row_binary( + row: Sequence[Any], tx: Transformer, out: bytearray = None +) -> bytearray: + """Convert a row of adapted data to the data to send for binary copy""" + cdef Py_ssize_t rowlen = len(row) + cdef uint16_t berowlen = endian.htobe16(<int16_t>rowlen) + + cdef Py_ssize_t pos # offset in 'out' where to write + if out is None: + out = PyByteArray_FromStringAndSize("", 0) + pos = 0 + else: + pos = PyByteArray_GET_SIZE(out) + + # let's start from a nice chunk + # (larger than most fixed size; for variable ones, oh well, we'll resize it) + cdef char *target = CDumper.ensure_size( + out, pos, sizeof(berowlen) + 20 * rowlen) + + # Write the number of fields as network-order 16 bits + memcpy(target, <void *>&berowlen, sizeof(berowlen)) + pos += sizeof(berowlen) + + cdef Py_ssize_t size + cdef uint32_t besize + cdef char *buf + cdef int i + cdef PyObject *fmt = <PyObject *>PG_BINARY + cdef PyObject *row_dumper + + if not tx._row_dumpers: + tx._row_dumpers = PyList_New(rowlen) + + dumpers = tx._row_dumpers + + for i in range(rowlen): + item = row[i] + if item is None: + target = CDumper.ensure_size(out, pos, sizeof(_binary_null)) + memcpy(target, <void *>&_binary_null, sizeof(_binary_null)) + pos += sizeof(_binary_null) + continue + + row_dumper = PyList_GET_ITEM(dumpers, i) + if not row_dumper: + row_dumper = tx.get_row_dumper(<PyObject *>item, fmt) + Py_INCREF(<object>row_dumper) + PyList_SET_ITEM(dumpers, i, <object>row_dumper) + + if (<RowDumper>row_dumper).cdumper is not None: + # A cdumper can resize if necessary and copy in place + size = (<RowDumper>row_dumper).cdumper.cdump( + item, out, pos + sizeof(besize)) + # Also add the size of the item, before the item + besize = endian.htobe32(<int32_t>size) + target = PyByteArray_AS_STRING(out) # might have been moved by cdump + 
memcpy(target + pos, <void *>&besize, sizeof(besize)) + else: + # A Python dumper, gotta call it and extract its juices + b = PyObject_CallFunctionObjArgs( + (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL) + _buffer_as_string_and_size(b, &buf, &size) + target = CDumper.ensure_size(out, pos, size + sizeof(besize)) + besize = endian.htobe32(<int32_t>size) + memcpy(target, <void *>&besize, sizeof(besize)) + memcpy(target + sizeof(besize), buf, size) + + pos += size + sizeof(besize) + + # Resize to the final size + PyByteArray_Resize(out, pos) + return out + + +def format_row_text( + row: Sequence[Any], tx: Transformer, out: bytearray = None +) -> bytearray: + cdef Py_ssize_t pos # offset in 'out' where to write + if out is None: + out = PyByteArray_FromStringAndSize("", 0) + pos = 0 + else: + pos = PyByteArray_GET_SIZE(out) + + cdef Py_ssize_t rowlen = len(row) + + if rowlen == 0: + PyByteArray_Resize(out, pos + 1) + out[pos] = b"\n" + return out + + cdef Py_ssize_t size, tmpsize + cdef char *buf + cdef int i, j + cdef unsigned char *target + cdef int nesc = 0 + cdef int with_tab + cdef PyObject *fmt = <PyObject *>PG_TEXT + cdef PyObject *row_dumper + + for i in range(rowlen): + # Include the tab before the data, so it gets included in the resizes + with_tab = i > 0 + + item = row[i] + if item is None: + if with_tab: + target = <unsigned char *>CDumper.ensure_size(out, pos, 3) + memcpy(target, b"\t\\N", 3) + pos += 3 + else: + target = <unsigned char *>CDumper.ensure_size(out, pos, 2) + memcpy(target, b"\\N", 2) + pos += 2 + continue + + row_dumper = tx.get_row_dumper(<PyObject *>item, fmt) + if (<RowDumper>row_dumper).cdumper is not None: + # A cdumper can resize if necessary and copy in place + size = (<RowDumper>row_dumper).cdumper.cdump( + item, out, pos + with_tab) + target = <unsigned char *>PyByteArray_AS_STRING(out) + pos + else: + # A Python dumper, gotta call it and extract its juices + b = PyObject_CallFunctionObjArgs( + (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL) + _buffer_as_string_and_size(b, &buf, &size) + target = <unsigned char *>CDumper.ensure_size(out, pos, size + with_tab) + memcpy(target + with_tab, buf, size) + + # Prepend a tab to the data just written + if with_tab: + target[0] = b"\t" + target += 1 + pos += 1 + + # Now from pos to pos + size there is a textual representation: it may + # contain chars to escape. Scan to find how many such chars there are. + for j in range(size): + if copy_escape_lut[target[j]]: + nesc += 1 + + # If there is any char to escape, walk backwards pushing the chars + # forward and interspersing backslashes. 
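+        # (Editor's note) The bytes needing escape in text COPY are the
+        # backslash and the control characters \b, \t, \n, \v, \f, \r;
+        # copy_escape_lut maps each to the byte written after the backslash,
+        # so e.g. a literal tab is emitted as a backslash followed by "t".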
+ if nesc > 0: + tmpsize = size + nesc + target = <unsigned char *>CDumper.ensure_size(out, pos, tmpsize) + for j in range(<int>size - 1, -1, -1): + if copy_escape_lut[target[j]]: + target[j + nesc] = copy_escape_lut[target[j]] + nesc -= 1 + target[j + nesc] = b"\\" + if nesc <= 0: + break + else: + target[j + nesc] = target[j] + pos += tmpsize + else: + pos += size + + # Resize to the final size, add the newline + PyByteArray_Resize(out, pos + 1) + out[pos] = b"\n" + return out + + +def parse_row_binary(data, tx: Transformer) -> Tuple[Any, ...]: + cdef unsigned char *ptr + cdef Py_ssize_t bufsize + _buffer_as_string_and_size(data, <char **>&ptr, &bufsize) + cdef unsigned char *bufend = ptr + bufsize + + cdef uint16_t benfields = (<uint16_t *>ptr)[0] + cdef int nfields = endian.be16toh(benfields) + ptr += sizeof(benfields) + cdef list row = PyList_New(nfields) + + cdef int col + cdef int32_t belength + cdef Py_ssize_t length + + for col in range(nfields): + memcpy(&belength, ptr, sizeof(belength)) + ptr += sizeof(belength) + if belength == _binary_null: + field = None + else: + length = endian.be32toh(belength) + if ptr + length > bufend: + raise e.DataError("bad copy data: length exceeding data") + field = PyMemoryView_FromObject( + ViewBuffer._from_buffer(data, ptr, length)) + ptr += length + + Py_INCREF(field) + PyList_SET_ITEM(row, col, field) + + return tx.load_sequence(row) + + +def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]: + cdef unsigned char *fstart + cdef Py_ssize_t size + _buffer_as_string_and_size(data, <char **>&fstart, &size) + + # politely assume that the number of fields will be what in the result + cdef int nfields = tx._nfields + cdef list row = PyList_New(nfields) + + cdef unsigned char *fend + cdef unsigned char *rowend = fstart + size + cdef unsigned char *src + cdef unsigned char *tgt + cdef int col + cdef int num_bs + + for col in range(nfields): + fend = fstart + num_bs = 0 + # Scan to the end of the field, remember if you see any backslash + while fend[0] != b'\t' and fend[0] != b'\n' and fend < rowend: + if fend[0] == b'\\': + num_bs += 1 + # skip the next char to avoid counting escaped backslashes twice + fend += 1 + fend += 1 + + # Check if we stopped for the right reason + if fend >= rowend: + raise e.DataError("bad copy data: field delimiter not found") + elif fend[0] == b'\t' and col == nfields - 1: + raise e.DataError("bad copy data: got a tab at the end of the row") + elif fend[0] == b'\n' and col != nfields - 1: + raise e.DataError( + "bad copy format: got a newline before the end of the row") + + # Is this a NULL? + if fend - fstart == 2 and fstart[0] == b'\\' and fstart[1] == b'N': + field = None + + # Is this a field with no backslash? 
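+        # (Editor's note: such a field is exposed as a zero-copy memoryview
+        # over the input buffer; only fields containing escape sequences need
+        # the unescaping copy made in the branch below.)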
+ elif num_bs == 0: + # Nothing to unescape: we don't need a copy + field = PyMemoryView_FromObject( + ViewBuffer._from_buffer(data, fstart, fend - fstart)) + + # This is a field containing backslashes + else: + # We need a copy of the buffer to unescape + field = PyByteArray_FromStringAndSize("", 0) + PyByteArray_Resize(field, fend - fstart - num_bs) + tgt = <unsigned char *>PyByteArray_AS_STRING(field) + src = fstart + while (src < fend): + if src[0] != b'\\': + tgt[0] = src[0] + else: + src += 1 + tgt[0] = copy_unescape_lut[src[0]] + src += 1 + tgt += 1 + + Py_INCREF(field) + PyList_SET_ITEM(row, col, field) + + # Start of the field + fstart = fend + 1 + + # Convert the array of buffers into Python objects + return tx.load_sequence(row) + + +cdef extern from *: + """ +/* handle chars to (un)escape in text copy representation */ +/* '\b', '\t', '\n', '\v', '\f', '\r', '\\' */ + +/* Escaping chars */ +static const char copy_escape_lut[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 98, 116, 110, 118, 102, 114, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 92, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +/* Conversion of escaped to unescaped chars */ +static const char copy_unescape_lut[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 8, 99, 100, 101, 12, 103, 104, 105, 106, 107, 108, 109, 10, 111, +112, 113, 13, 115, 9, 117, 11, 119, 120, 121, 122, 123, 124, 125, 126, 127, +128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, +144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, +160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, +176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, +192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, +208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, +224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, +240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, +}; + """ + const char[256] copy_escape_lut + const char[256] copy_unescape_lut diff --git a/psycopg_c/psycopg_c/_psycopg/endian.pxd b/psycopg_c/psycopg_c/_psycopg/endian.pxd new file mode 100644 index 0000000..44e7305 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/endian.pxd @@ -0,0 +1,155 @@ +""" +Access to endian conversion function +""" + +# Copyright (C) 2020 The Psycopg Team + +from libc.stdint cimport uint16_t, uint32_t, uint64_t + +cdef extern from * nogil: + # from 
https://gist.github.com/panzi/6856583 + # Improved in: + # https://github.com/linux-sunxi/sunxi-tools/blob/master/include/portable_endian.h + """ +// "License": Public Domain +// I, Mathias Panzenböck, place this file hereby into the public domain. Use it at your own risk for whatever you like. +// In case there are jurisdictions that don't support putting things in the public domain you can also consider it to +// be "dual licensed" under the BSD, MIT and Apache licenses, if you want to. This code is trivial anyway. Consider it +// an example on how to get the endian conversion functions on different platforms. + +#ifndef PORTABLE_ENDIAN_H__ +#define PORTABLE_ENDIAN_H__ + +#if (defined(_WIN16) || defined(_WIN32) || defined(_WIN64)) && !defined(__WINDOWS__) + +# define __WINDOWS__ + +#endif + +#if defined(__linux__) || defined(__CYGWIN__) + +# include <endian.h> + +#elif defined(__APPLE__) + +# include <libkern/OSByteOrder.h> + +# define htobe16(x) OSSwapHostToBigInt16(x) +# define htole16(x) OSSwapHostToLittleInt16(x) +# define be16toh(x) OSSwapBigToHostInt16(x) +# define le16toh(x) OSSwapLittleToHostInt16(x) + +# define htobe32(x) OSSwapHostToBigInt32(x) +# define htole32(x) OSSwapHostToLittleInt32(x) +# define be32toh(x) OSSwapBigToHostInt32(x) +# define le32toh(x) OSSwapLittleToHostInt32(x) + +# define htobe64(x) OSSwapHostToBigInt64(x) +# define htole64(x) OSSwapHostToLittleInt64(x) +# define be64toh(x) OSSwapBigToHostInt64(x) +# define le64toh(x) OSSwapLittleToHostInt64(x) + +# define __BYTE_ORDER BYTE_ORDER +# define __BIG_ENDIAN BIG_ENDIAN +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __PDP_ENDIAN PDP_ENDIAN + +#elif defined(__OpenBSD__) || defined(__NetBSD__) || defined(__FreeBSD__) || defined(__DragonFly__) + +# include <sys/endian.h> + +/* For functions still missing, try to substitute 'historic' OpenBSD names */ +#ifndef be16toh +# define be16toh(x) betoh16(x) +#endif +#ifndef le16toh +# define le16toh(x) letoh16(x) +#endif +#ifndef be32toh +# define be32toh(x) betoh32(x) +#endif +#ifndef le32toh +# define le32toh(x) letoh32(x) +#endif +#ifndef be64toh +# define be64toh(x) betoh64(x) +#endif +#ifndef le64toh +# define le64toh(x) letoh64(x) +#endif + +#elif defined(__WINDOWS__) + +# include <winsock2.h> +# ifndef _MSC_VER +# include <sys/param.h> +# endif + +# if BYTE_ORDER == LITTLE_ENDIAN + +# define htobe16(x) htons(x) +# define htole16(x) (x) +# define be16toh(x) ntohs(x) +# define le16toh(x) (x) + +# define htobe32(x) htonl(x) +# define htole32(x) (x) +# define be32toh(x) ntohl(x) +# define le32toh(x) (x) + +# define htobe64(x) htonll(x) +# define htole64(x) (x) +# define be64toh(x) ntohll(x) +# define le64toh(x) (x) + +# elif BYTE_ORDER == BIG_ENDIAN + + /* that would be xbox 360 */ +# define htobe16(x) (x) +# define htole16(x) __builtin_bswap16(x) +# define be16toh(x) (x) +# define le16toh(x) __builtin_bswap16(x) + +# define htobe32(x) (x) +# define htole32(x) __builtin_bswap32(x) +# define be32toh(x) (x) +# define le32toh(x) __builtin_bswap32(x) + +# define htobe64(x) (x) +# define htole64(x) __builtin_bswap64(x) +# define be64toh(x) (x) +# define le64toh(x) __builtin_bswap64(x) + +# else + +# error byte order not supported + +# endif + +# define __BYTE_ORDER BYTE_ORDER +# define __BIG_ENDIAN BIG_ENDIAN +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __PDP_ENDIAN PDP_ENDIAN + +#else + +# error platform not supported + +#endif + +#endif + """ + cdef uint16_t htobe16(uint16_t host_16bits) + cdef uint16_t htole16(uint16_t host_16bits) + cdef uint16_t be16toh(uint16_t 
big_endian_16bits) + cdef uint16_t le16toh(uint16_t little_endian_16bits) + + cdef uint32_t htobe32(uint32_t host_32bits) + cdef uint32_t htole32(uint32_t host_32bits) + cdef uint32_t be32toh(uint32_t big_endian_32bits) + cdef uint32_t le32toh(uint32_t little_endian_32bits) + + cdef uint64_t htobe64(uint64_t host_64bits) + cdef uint64_t htole64(uint64_t host_64bits) + cdef uint64_t be64toh(uint64_t big_endian_64bits) + cdef uint64_t le64toh(uint64_t little_endian_64bits) diff --git a/psycopg_c/psycopg_c/_psycopg/generators.pyx b/psycopg_c/psycopg_c/_psycopg/generators.pyx new file mode 100644 index 0000000..9ce9e54 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/generators.pyx @@ -0,0 +1,276 @@ +""" +C implementation of generators for the communication protocols with the libpq +""" + +# Copyright (C) 2020 The Psycopg Team + +from cpython.object cimport PyObject_CallFunctionObjArgs + +from typing import List + +from psycopg import errors as e +from psycopg.pq import abc, error_message +from psycopg.abc import PipelineCommand, PQGen +from psycopg._enums import Wait, Ready +from psycopg._compat import Deque +from psycopg._encodings import conninfo_encoding + +cdef object WAIT_W = Wait.W +cdef object WAIT_R = Wait.R +cdef object WAIT_RW = Wait.RW +cdef object PY_READY_R = Ready.R +cdef object PY_READY_W = Ready.W +cdef object PY_READY_RW = Ready.RW +cdef int READY_R = Ready.R +cdef int READY_W = Ready.W +cdef int READY_RW = Ready.RW + +def connect(conninfo: str) -> PQGenConn[abc.PGconn]: + """ + Generator to create a database connection without blocking. + + """ + cdef pq.PGconn conn = pq.PGconn.connect_start(conninfo.encode()) + cdef libpq.PGconn *pgconn_ptr = conn._pgconn_ptr + cdef int conn_status = libpq.PQstatus(pgconn_ptr) + cdef int poll_status + + while True: + if conn_status == libpq.CONNECTION_BAD: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection is bad: {error_message(conn, encoding=encoding)}", + pgconn=conn + ) + + with nogil: + poll_status = libpq.PQconnectPoll(pgconn_ptr) + + if poll_status == libpq.PGRES_POLLING_OK: + break + elif poll_status == libpq.PGRES_POLLING_READING: + yield (libpq.PQsocket(pgconn_ptr), WAIT_R) + elif poll_status == libpq.PGRES_POLLING_WRITING: + yield (libpq.PQsocket(pgconn_ptr), WAIT_W) + elif poll_status == libpq.PGRES_POLLING_FAILED: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection failed: {error_message(conn, encoding=encoding)}", + pgconn=conn + ) + else: + raise e.InternalError( + f"unexpected poll status: {poll_status}", pgconn=conn + ) + + conn.nonblocking = 1 + return conn + + +def execute(pq.PGconn pgconn) -> PQGen[List[abc.PGresult]]: + """ + Generator sending a query and returning results without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + Return the list of results returned by the database (whether success + or error). + """ + yield from send(pgconn) + rv = yield from fetch_many(pgconn) + return rv + + +def send(pq.PGconn pgconn) -> PQGen[None]: + """ + Generator to send a query to the server without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + After this generator has finished you may want to cycle using `fetch()` + to retrieve the results available. 
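+
+    Illustrative driving of the generator (editor's sketch, not part of the
+    original docstring)::
+
+        pgconn.send_query(b"select 1")
+        wait_c(send(pgconn), pgconn.socket)                  # flush the query
+        results = wait_c(fetch_many(pgconn), pgconn.socket)  # then read results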
+ """ + cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr + cdef int status + cdef int cires + + while True: + if pgconn.flush() == 0: + break + + status = yield WAIT_RW + if status & READY_R: + with nogil: + # This call may read notifies which will be saved in the + # PGconn buffer and passed to Python later. + cires = libpq.PQconsumeInput(pgconn_ptr) + if 1 != cires: + raise e.OperationalError( + f"consuming input failed: {error_message(pgconn)}") + + +def fetch_many(pq.PGconn pgconn) -> PQGen[List[PGresult]]: + """ + Generator retrieving results from the database without blocking. + + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return the list of results returned by the database (whether success + or error). + """ + cdef list results = [] + cdef int status + cdef pq.PGresult result + cdef libpq.PGresult *pgres + + while True: + result = yield from fetch(pgconn) + if result is None: + break + results.append(result) + pgres = result._pgresult_ptr + + status = libpq.PQresultStatus(pgres) + if ( + status == libpq.PGRES_COPY_IN + or status == libpq.PGRES_COPY_OUT + or status == libpq.PGRES_COPY_BOTH + ): + # After entering copy mode the libpq will create a phony result + # for every request so let's break the endless loop. + break + + if status == libpq.PGRES_PIPELINE_SYNC: + # PIPELINE_SYNC is not followed by a NULL, but we return it alone + # similarly to other result sets. + break + + return results + + +def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]: + """ + Generator retrieving a single result from the database without blocking. + + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return a result from the database (whether success or error). + """ + cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr + cdef int cires, ibres + cdef libpq.PGresult *pgres + + with nogil: + ibres = libpq.PQisBusy(pgconn_ptr) + if ibres: + yield WAIT_R + while True: + with nogil: + cires = libpq.PQconsumeInput(pgconn_ptr) + if cires == 1: + ibres = libpq.PQisBusy(pgconn_ptr) + + if 1 != cires: + raise e.OperationalError( + f"consuming input failed: {error_message(pgconn)}") + if not ibres: + break + yield WAIT_R + + _consume_notifies(pgconn) + + with nogil: + pgres = libpq.PQgetResult(pgconn_ptr) + if pgres is NULL: + return None + return pq.PGresult._from_ptr(pgres) + + +def pipeline_communicate( + pq.PGconn pgconn, commands: Deque[PipelineCommand] +) -> PQGen[List[List[PGresult]]]: + """Generator to send queries from a connection in pipeline mode while also + receiving results. + + Return a list results, including single PIPELINE_SYNC elements. 
+ """ + cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr + cdef int cires + cdef int status + cdef int ready + cdef libpq.PGresult *pgres + cdef list res = [] + cdef list results = [] + cdef pq.PGresult r + + while True: + ready = yield WAIT_RW + + if ready & READY_R: + with nogil: + cires = libpq.PQconsumeInput(pgconn_ptr) + if 1 != cires: + raise e.OperationalError( + f"consuming input failed: {error_message(pgconn)}") + + _consume_notifies(pgconn) + + res: List[PGresult] = [] + while True: + with nogil: + ibres = libpq.PQisBusy(pgconn_ptr) + if ibres: + break + pgres = libpq.PQgetResult(pgconn_ptr) + + if pgres is NULL: + if not res: + break + results.append(res) + res = [] + else: + status = libpq.PQresultStatus(pgres) + r = pq.PGresult._from_ptr(pgres) + if status == libpq.PGRES_PIPELINE_SYNC: + results.append([r]) + break + else: + res.append(r) + + if ready & READY_W: + pgconn.flush() + if not commands: + break + commands.popleft()() + + return results + + +cdef int _consume_notifies(pq.PGconn pgconn) except -1: + cdef object notify_handler = pgconn.notify_handler + cdef libpq.PGconn *pgconn_ptr + cdef libpq.PGnotify *notify + + if notify_handler is not None: + while True: + pynotify = pgconn.notifies() + if pynotify is None: + break + PyObject_CallFunctionObjArgs( + notify_handler, <PyObject *>pynotify, NULL + ) + else: + pgconn_ptr = pgconn._pgconn_ptr + while True: + notify = libpq.PQnotifies(pgconn_ptr) + if notify is NULL: + break + libpq.PQfreemem(notify) + + return 0 diff --git a/psycopg_c/psycopg_c/_psycopg/oids.pxd b/psycopg_c/psycopg_c/_psycopg/oids.pxd new file mode 100644 index 0000000..2a864c4 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/oids.pxd @@ -0,0 +1,92 @@ +""" +Constants to refer to OIDS in C +""" + +# Copyright (C) 2020 The Psycopg Team + +# Use tools/update_oids.py to update this data. 
+ +cdef enum: + INVALID_OID = 0 + + # autogenerated: start + + # Generated from PostgreSQL 15.0 + + ACLITEM_OID = 1033 + BIT_OID = 1560 + BOOL_OID = 16 + BOX_OID = 603 + BPCHAR_OID = 1042 + BYTEA_OID = 17 + CHAR_OID = 18 + CID_OID = 29 + CIDR_OID = 650 + CIRCLE_OID = 718 + DATE_OID = 1082 + DATEMULTIRANGE_OID = 4535 + DATERANGE_OID = 3912 + FLOAT4_OID = 700 + FLOAT8_OID = 701 + GTSVECTOR_OID = 3642 + INET_OID = 869 + INT2_OID = 21 + INT2VECTOR_OID = 22 + INT4_OID = 23 + INT4MULTIRANGE_OID = 4451 + INT4RANGE_OID = 3904 + INT8_OID = 20 + INT8MULTIRANGE_OID = 4536 + INT8RANGE_OID = 3926 + INTERVAL_OID = 1186 + JSON_OID = 114 + JSONB_OID = 3802 + JSONPATH_OID = 4072 + LINE_OID = 628 + LSEG_OID = 601 + MACADDR_OID = 829 + MACADDR8_OID = 774 + MONEY_OID = 790 + NAME_OID = 19 + NUMERIC_OID = 1700 + NUMMULTIRANGE_OID = 4532 + NUMRANGE_OID = 3906 + OID_OID = 26 + OIDVECTOR_OID = 30 + PATH_OID = 602 + PG_LSN_OID = 3220 + POINT_OID = 600 + POLYGON_OID = 604 + RECORD_OID = 2249 + REFCURSOR_OID = 1790 + REGCLASS_OID = 2205 + REGCOLLATION_OID = 4191 + REGCONFIG_OID = 3734 + REGDICTIONARY_OID = 3769 + REGNAMESPACE_OID = 4089 + REGOPER_OID = 2203 + REGOPERATOR_OID = 2204 + REGPROC_OID = 24 + REGPROCEDURE_OID = 2202 + REGROLE_OID = 4096 + REGTYPE_OID = 2206 + TEXT_OID = 25 + TID_OID = 27 + TIME_OID = 1083 + TIMESTAMP_OID = 1114 + TIMESTAMPTZ_OID = 1184 + TIMETZ_OID = 1266 + TSMULTIRANGE_OID = 4533 + TSQUERY_OID = 3615 + TSRANGE_OID = 3908 + TSTZMULTIRANGE_OID = 4534 + TSTZRANGE_OID = 3910 + TSVECTOR_OID = 3614 + TXID_SNAPSHOT_OID = 2970 + UUID_OID = 2950 + VARBIT_OID = 1562 + VARCHAR_OID = 1043 + XID_OID = 28 + XID8_OID = 5069 + XML_OID = 142 + # autogenerated: end diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx new file mode 100644 index 0000000..fc69725 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -0,0 +1,640 @@ +""" +Helper object to transform values between Python and PostgreSQL + +Cython implementation: can access to lower level C features without creating +too many temporary Python objects and performing less memory copying. + +""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython +from cpython.ref cimport Py_INCREF, Py_DECREF +from cpython.set cimport PySet_Add, PySet_Contains +from cpython.dict cimport PyDict_GetItem, PyDict_SetItem +from cpython.list cimport ( + PyList_New, PyList_CheckExact, + PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE) +from cpython.bytes cimport PyBytes_AS_STRING +from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM +from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs + +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from psycopg import errors as e +from psycopg.pq import Format as PqFormat +from psycopg.rows import Row, RowMaker +from psycopg._encodings import pgconn_encoding + +NoneType = type(None) + +# internal structure: you are not supposed to know this. But it's worth some +# 10% of the innermost loop, so I'm willing to ask for forgiveness later... 
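+# (Editor's note) The two ctypedefs below mirror the first members of libpq's
+# private PGresult layout, letting load_row()/load_rows() reach each field's
+# value and length directly instead of calling PQgetvalue()/PQgetlength() for
+# every cell; a libpq that reordered these members would break the cast.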
+ +ctypedef struct PGresAttValue: + int len + char *value + +ctypedef struct pg_result_int: + # NOTE: it would be advised that we don't know this structure's content + int ntups + int numAttributes + libpq.PGresAttDesc *attDescs + PGresAttValue **tuples + # ...more members, which we ignore + + +@cython.freelist(16) +cdef class RowLoader: + cdef CLoader cloader + cdef object pyloader + cdef object loadfunc + + +@cython.freelist(16) +cdef class RowDumper: + cdef CDumper cdumper + cdef object pydumper + cdef object dumpfunc + cdef object oid + cdef object format + + +cdef class Transformer: + """ + An object that can adapt efficiently between Python and PostgreSQL. + + The life cycle of the object is the query, so it is assumed that attributes + such as the server version or the connection encoding will not change. The + object have its state so adapting several values of the same type can be + optimised. + + """ + + cdef readonly object connection + cdef readonly object adapters + cdef readonly object types + cdef readonly object formats + cdef str _encoding + cdef int _none_oid + + # mapping class -> Dumper instance (auto, text, binary) + cdef dict _auto_dumpers + cdef dict _text_dumpers + cdef dict _binary_dumpers + + # mapping oid -> Loader instance (text, binary) + cdef dict _text_loaders + cdef dict _binary_loaders + + # mapping oid -> Dumper instance (text, binary) + cdef dict _oid_text_dumpers + cdef dict _oid_binary_dumpers + + cdef pq.PGresult _pgresult + cdef int _nfields, _ntuples + cdef list _row_dumpers + cdef list _row_loaders + + cdef dict _oid_types + + def __cinit__(self, context: Optional["AdaptContext"] = None): + if context is not None: + self.adapters = context.adapters + self.connection = context.connection + else: + from psycopg import postgres + self.adapters = postgres.adapters + self.connection = None + + self.types = self.formats = None + self._none_oid = -1 + + @classmethod + def from_context(cls, context: Optional["AdaptContext"]): + """ + Return a Transformer from an AdaptContext. + + If the context is a Transformer instance, just return it. 
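+
+        Illustrative use (editor's sketch)::
+
+            tx = Transformer.from_context(conn)   # conn: any AdaptContext
+            tx is Transformer.from_context(tx)    # True: returned unchanged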
+ """ + return _tx_from_context(context) + + @property + def encoding(self) -> str: + if not self._encoding: + conn = self.connection + self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8" + return self._encoding + + @property + def pgresult(self) -> Optional[PGresult]: + return self._pgresult + + cpdef set_pgresult( + self, + pq.PGresult result, + object set_loaders = True, + object format = None + ): + self._pgresult = result + + if result is None: + self._nfields = self._ntuples = 0 + if set_loaders: + self._row_loaders = [] + return + + cdef libpq.PGresult *res = self._pgresult._pgresult_ptr + self._nfields = libpq.PQnfields(res) + self._ntuples = libpq.PQntuples(res) + + if not set_loaders: + return + + if not self._nfields: + self._row_loaders = [] + return + + if format is None: + format = libpq.PQfformat(res, 0) + + cdef list loaders = PyList_New(self._nfields) + cdef PyObject *row_loader + cdef object oid + + cdef int i + for i in range(self._nfields): + oid = libpq.PQftype(res, i) + row_loader = self._c_get_loader(<PyObject *>oid, <PyObject *>format) + Py_INCREF(<object>row_loader) + PyList_SET_ITEM(loaders, i, <object>row_loader) + + self._row_loaders = loaders + + def set_dumper_types(self, types: Sequence[int], format: Format) -> None: + cdef Py_ssize_t ntypes = len(types) + dumpers = PyList_New(ntypes) + cdef int i + for i in range(ntypes): + oid = types[i] + dumper_ptr = self.get_dumper_by_oid( + <PyObject *>oid, <PyObject *>format) + Py_INCREF(<object>dumper_ptr) + PyList_SET_ITEM(dumpers, i, <object>dumper_ptr) + + self._row_dumpers = dumpers + self.types = tuple(types) + self.formats = [format] * ntypes + + def set_loader_types(self, types: Sequence[int], format: Format) -> None: + self._c_loader_types(len(types), types, format) + + cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, object format): + cdef list loaders = PyList_New(ntypes) + + # these are used more as Python object than C + cdef PyObject *oid + cdef PyObject *row_loader + for i in range(ntypes): + oid = PyList_GET_ITEM(types, i) + row_loader = self._c_get_loader(oid, <PyObject *>format) + Py_INCREF(<object>row_loader) + PyList_SET_ITEM(loaders, i, <object>row_loader) + + self._row_loaders = loaders + + cpdef as_literal(self, obj): + cdef PyObject *row_dumper = self.get_row_dumper( + <PyObject *>obj, <PyObject *>PG_TEXT) + + if (<RowDumper>row_dumper).cdumper is not None: + dumper = (<RowDumper>row_dumper).cdumper + else: + dumper = (<RowDumper>row_dumper).pydumper + + rv = dumper.quote(obj) + oid = dumper.oid + # If the result is quoted and the oid not unknown or text, + # add an explicit type cast. + # Check the last char because the first one might be 'E'. 
+ if oid and oid != oids.TEXT_OID and rv and rv[-1] == 39: + if self._oid_types is None: + self._oid_types = {} + type_ptr = PyDict_GetItem(<object>self._oid_types, oid) + if type_ptr == NULL: + type_sql = b"" + ti = self.adapters.types.get(oid) + if ti is not None: + if oid < 8192: + # builtin: prefer "timestamptz" to "timestamp with time zone" + type_sql = ti.name.encode(self.encoding) + else: + type_sql = ti.regtype.encode(self.encoding) + if oid == ti.array_oid: + type_sql += b"[]" + + type_ptr = <PyObject *>type_sql + PyDict_SetItem(<object>self._oid_types, oid, type_sql) + + if <object>type_ptr: + rv = b"%s::%s" % (rv, <object>type_ptr) + + return rv + + def get_dumper(self, obj, format) -> "Dumper": + cdef PyObject *row_dumper = self.get_row_dumper( + <PyObject *>obj, <PyObject *>format) + return (<RowDumper>row_dumper).pydumper + + cdef PyObject *get_row_dumper(self, PyObject *obj, PyObject *fmt) except NULL: + """ + Return a borrowed reference to the RowDumper for the given obj/fmt. + """ + # Fast path: return a Dumper class already instantiated from the same type + cdef PyObject *cache + cdef PyObject *ptr + cdef PyObject *ptr1 + cdef RowDumper row_dumper + + # Normally, the type of the object dictates how to dump it + key = type(<object>obj) + + # Establish where would the dumper be cached + bfmt = PyUnicode_AsUTF8String(<object>fmt) + cdef char cfmt = PyBytes_AS_STRING(bfmt)[0] + if cfmt == b's': + if self._auto_dumpers is None: + self._auto_dumpers = {} + cache = <PyObject *>self._auto_dumpers + elif cfmt == b'b': + if self._binary_dumpers is None: + self._binary_dumpers = {} + cache = <PyObject *>self._binary_dumpers + elif cfmt == b't': + if self._text_dumpers is None: + self._text_dumpers = {} + cache = <PyObject *>self._text_dumpers + else: + raise ValueError( + f"format should be a psycopg.adapt.Format, not {<object>fmt}") + + # Reuse an existing Dumper class for objects of the same type + ptr = PyDict_GetItem(<object>cache, key) + if ptr == NULL: + dcls = PyObject_CallFunctionObjArgs( + self.adapters.get_dumper, <PyObject *>key, fmt, NULL) + dumper = PyObject_CallFunctionObjArgs( + dcls, <PyObject *>key, <PyObject *>self, NULL) + + row_dumper = _as_row_dumper(dumper) + PyDict_SetItem(<object>cache, key, row_dumper) + ptr = <PyObject *>row_dumper + + # Check if the dumper requires an upgrade to handle this specific value + if (<RowDumper>ptr).cdumper is not None: + key1 = (<RowDumper>ptr).cdumper.get_key(<object>obj, <object>fmt) + else: + key1 = PyObject_CallFunctionObjArgs( + (<RowDumper>ptr).pydumper.get_key, obj, fmt, NULL) + if key1 is key: + return ptr + + # If it does, ask the dumper to create its own upgraded version + ptr1 = PyDict_GetItem(<object>cache, key1) + if ptr1 != NULL: + return ptr1 + + if (<RowDumper>ptr).cdumper is not None: + dumper = (<RowDumper>ptr).cdumper.upgrade(<object>obj, <object>fmt) + else: + dumper = PyObject_CallFunctionObjArgs( + (<RowDumper>ptr).pydumper.upgrade, obj, fmt, NULL) + + row_dumper = _as_row_dumper(dumper) + PyDict_SetItem(<object>cache, key1, row_dumper) + return <PyObject *>row_dumper + + cdef PyObject *get_dumper_by_oid(self, PyObject *oid, PyObject *fmt) except NULL: + """ + Return a borrowed reference to the RowDumper for the given oid/fmt. 
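+
+        (Editor's note: this is the path taken by `set_dumper_types()` above,
+        i.e. when the caller already knows the OIDs of the target columns.)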
+ """ + cdef PyObject *ptr + cdef PyObject *cache + cdef RowDumper row_dumper + + # Establish where would the dumper be cached + cdef int cfmt = <object>fmt + if cfmt == 0: + if self._oid_text_dumpers is None: + self._oid_text_dumpers = {} + cache = <PyObject *>self._oid_text_dumpers + elif cfmt == 1: + if self._oid_binary_dumpers is None: + self._oid_binary_dumpers = {} + cache = <PyObject *>self._oid_binary_dumpers + else: + raise ValueError( + f"format should be a psycopg.pq.Format, not {<object>fmt}") + + # Reuse an existing Dumper class for objects of the same type + ptr = PyDict_GetItem(<object>cache, <object>oid) + if ptr == NULL: + dcls = PyObject_CallFunctionObjArgs( + self.adapters.get_dumper_by_oid, oid, fmt, NULL) + dumper = PyObject_CallFunctionObjArgs( + dcls, <PyObject *>NoneType, <PyObject *>self, NULL) + + row_dumper = _as_row_dumper(dumper) + PyDict_SetItem(<object>cache, <object>oid, row_dumper) + ptr = <PyObject *>row_dumper + + return ptr + + cpdef dump_sequence(self, object params, object formats): + # Verify that they are not none and that PyList_GET_ITEM won't blow up + cdef Py_ssize_t nparams = len(params) + cdef list out = PyList_New(nparams) + + cdef int i + cdef PyObject *dumper_ptr # borrowed pointer to row dumper + cdef object dumped + cdef Py_ssize_t size + + if self._none_oid < 0: + self._none_oid = self.adapters.get_dumper(NoneType, "s").oid + + dumpers = self._row_dumpers + + if dumpers: + for i in range(nparams): + param = params[i] + if param is not None: + dumper_ptr = PyList_GET_ITEM(dumpers, i) + if (<RowDumper>dumper_ptr).cdumper is not None: + dumped = PyByteArray_FromStringAndSize("", 0) + size = (<RowDumper>dumper_ptr).cdumper.cdump( + param, <bytearray>dumped, 0) + PyByteArray_Resize(dumped, size) + else: + dumped = PyObject_CallFunctionObjArgs( + (<RowDumper>dumper_ptr).dumpfunc, + <PyObject *>param, NULL) + else: + dumped = None + + Py_INCREF(dumped) + PyList_SET_ITEM(out, i, dumped) + + return out + + cdef tuple types = PyTuple_New(nparams) + cdef list pqformats = PyList_New(nparams) + + for i in range(nparams): + param = params[i] + if param is not None: + dumper_ptr = self.get_row_dumper( + <PyObject *>param, <PyObject *>formats[i]) + if (<RowDumper>dumper_ptr).cdumper is not None: + dumped = PyByteArray_FromStringAndSize("", 0) + size = (<RowDumper>dumper_ptr).cdumper.cdump( + param, <bytearray>dumped, 0) + PyByteArray_Resize(dumped, size) + else: + dumped = PyObject_CallFunctionObjArgs( + (<RowDumper>dumper_ptr).dumpfunc, + <PyObject *>param, NULL) + oid = (<RowDumper>dumper_ptr).oid + fmt = (<RowDumper>dumper_ptr).format + else: + dumped = None + oid = self._none_oid + fmt = PQ_TEXT + + Py_INCREF(dumped) + PyList_SET_ITEM(out, i, dumped) + + Py_INCREF(oid) + PyTuple_SET_ITEM(types, i, oid) + + Py_INCREF(fmt) + PyList_SET_ITEM(pqformats, i, fmt) + + self.types = types + self.formats = pqformats + return out + + def load_rows(self, int row0, int row1, object make_row) -> List[Row]: + if self._pgresult is None: + raise e.InterfaceError("result not set") + + if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples): + raise e.InterfaceError( + f"rows must be included between 0 and {self._ntuples}" + ) + + cdef libpq.PGresult *res = self._pgresult._pgresult_ptr + # cheeky access to the internal PGresult structure + cdef pg_result_int *ires = <pg_result_int*>res + + cdef int row + cdef int col + cdef PGresAttValue *attval + cdef object record # not 'tuple' as it would check on assignment + + cdef object records = PyList_New(row1 - 
row0) + for row in range(row0, row1): + record = PyTuple_New(self._nfields) + Py_INCREF(record) + PyList_SET_ITEM(records, row - row0, record) + + cdef PyObject *loader # borrowed RowLoader + cdef PyObject *brecord # borrowed + row_loaders = self._row_loaders # avoid an incref/decref per item + + for col in range(self._nfields): + loader = PyList_GET_ITEM(row_loaders, col) + if (<RowLoader>loader).cloader is not None: + for row in range(row0, row1): + brecord = PyList_GET_ITEM(records, row - row0) + attval = &(ires.tuples[row][col]) + if attval.len == -1: # NULL_LEN + pyval = None + else: + pyval = (<RowLoader>loader).cloader.cload( + attval.value, attval.len) + + Py_INCREF(pyval) + PyTuple_SET_ITEM(<object>brecord, col, pyval) + + else: + for row in range(row0, row1): + brecord = PyList_GET_ITEM(records, row - row0) + attval = &(ires.tuples[row][col]) + if attval.len == -1: # NULL_LEN + pyval = None + else: + b = PyMemoryView_FromObject( + ViewBuffer._from_buffer( + self._pgresult, + <unsigned char *>attval.value, attval.len)) + pyval = PyObject_CallFunctionObjArgs( + (<RowLoader>loader).loadfunc, <PyObject *>b, NULL) + + Py_INCREF(pyval) + PyTuple_SET_ITEM(<object>brecord, col, pyval) + + if make_row is not tuple: + for i in range(row1 - row0): + brecord = PyList_GET_ITEM(records, i) + record = PyObject_CallFunctionObjArgs( + make_row, <PyObject *>brecord, NULL) + Py_INCREF(record) + PyList_SET_ITEM(records, i, record) + Py_DECREF(<object>brecord) + return records + + def load_row(self, int row, object make_row) -> Optional[Row]: + if self._pgresult is None: + return None + + if not 0 <= row < self._ntuples: + return None + + cdef libpq.PGresult *res = self._pgresult._pgresult_ptr + # cheeky access to the internal PGresult structure + cdef pg_result_int *ires = <pg_result_int*>res + + cdef PyObject *loader # borrowed RowLoader + cdef int col + cdef PGresAttValue *attval + cdef object record # not 'tuple' as it would check on assignment + + record = PyTuple_New(self._nfields) + row_loaders = self._row_loaders # avoid an incref/decref per item + + for col in range(self._nfields): + attval = &(ires.tuples[row][col]) + if attval.len == -1: # NULL_LEN + pyval = None + else: + loader = PyList_GET_ITEM(row_loaders, col) + if (<RowLoader>loader).cloader is not None: + pyval = (<RowLoader>loader).cloader.cload( + attval.value, attval.len) + else: + b = PyMemoryView_FromObject( + ViewBuffer._from_buffer( + self._pgresult, + <unsigned char *>attval.value, attval.len)) + pyval = PyObject_CallFunctionObjArgs( + (<RowLoader>loader).loadfunc, <PyObject *>b, NULL) + + Py_INCREF(pyval) + PyTuple_SET_ITEM(record, col, pyval) + + if make_row is not tuple: + record = PyObject_CallFunctionObjArgs( + make_row, <PyObject *>record, NULL) + return record + + cpdef object load_sequence(self, record: Sequence[Optional[Buffer]]): + cdef Py_ssize_t nfields = len(record) + out = PyTuple_New(nfields) + cdef PyObject *loader # borrowed RowLoader + cdef int col + cdef char *ptr + cdef Py_ssize_t size + + row_loaders = self._row_loaders # avoid an incref/decref per item + if PyList_GET_SIZE(row_loaders) != nfields: + raise e.ProgrammingError( + f"cannot load sequence of {nfields} items:" + f" {len(self._row_loaders)} loaders registered") + + for col in range(nfields): + item = record[col] + if item is None: + Py_INCREF(None) + PyTuple_SET_ITEM(out, col, None) + continue + + loader = PyList_GET_ITEM(row_loaders, col) + if (<RowLoader>loader).cloader is not None: + _buffer_as_string_and_size(item, &ptr, &size) + pyval = 
(<RowLoader>loader).cloader.cload(ptr, size) + else: + pyval = PyObject_CallFunctionObjArgs( + (<RowLoader>loader).loadfunc, <PyObject *>item, NULL) + + Py_INCREF(pyval) + PyTuple_SET_ITEM(out, col, pyval) + + return out + + def get_loader(self, oid: int, format: pq.Format) -> "Loader": + cdef PyObject *row_loader = self._c_get_loader( + <PyObject *>oid, <PyObject *>format) + return (<RowLoader>row_loader).pyloader + + cdef PyObject *_c_get_loader(self, PyObject *oid, PyObject *fmt) except NULL: + """ + Return a borrowed reference to the RowLoader instance for given oid/fmt + """ + cdef PyObject *ptr + cdef PyObject *cache + + if <object>fmt == PQ_TEXT: + if self._text_loaders is None: + self._text_loaders = {} + cache = <PyObject *>self._text_loaders + elif <object>fmt == PQ_BINARY: + if self._binary_loaders is None: + self._binary_loaders = {} + cache = <PyObject *>self._binary_loaders + else: + raise ValueError( + f"format should be a psycopg.pq.Format, not {format}") + + ptr = PyDict_GetItem(<object>cache, <object>oid) + if ptr != NULL: + return ptr + + loader_cls = self.adapters.get_loader(<object>oid, <object>fmt) + if loader_cls is None: + loader_cls = self.adapters.get_loader(oids.INVALID_OID, <object>fmt) + if loader_cls is None: + raise e.InterfaceError("unknown oid loader not found") + + loader = PyObject_CallFunctionObjArgs( + loader_cls, oid, <PyObject *>self, NULL) + + cdef RowLoader row_loader = RowLoader() + row_loader.pyloader = loader + row_loader.loadfunc = loader.load + if isinstance(loader, CLoader): + row_loader.cloader = <CLoader>loader + + PyDict_SetItem(<object>cache, <object>oid, row_loader) + return <PyObject *>row_loader + + +cdef object _as_row_dumper(object dumper): + cdef RowDumper row_dumper = RowDumper() + + row_dumper.pydumper = dumper + row_dumper.dumpfunc = dumper.dump + row_dumper.oid = dumper.oid + row_dumper.format = dumper.format + + if isinstance(dumper, CDumper): + row_dumper.cdumper = <CDumper>dumper + + return row_dumper + + +cdef Transformer _tx_from_context(object context): + if isinstance(context, Transformer): + return context + else: + return Transformer(context) diff --git a/psycopg_c/psycopg_c/_psycopg/waiting.pyx b/psycopg_c/psycopg_c/_psycopg/waiting.pyx new file mode 100644 index 0000000..0af6c57 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/waiting.pyx @@ -0,0 +1,197 @@ +""" +C implementation of waiting functions +""" + +# Copyright (C) 2022 The Psycopg Team + +from cpython.object cimport PyObject_CallFunctionObjArgs + +cdef extern from *: + """ +#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL) + +#if defined(HAVE_POLL_H) +#include <poll.h> +#elif defined(HAVE_SYS_POLL_H) +#include <sys/poll.h> +#endif + +#else /* no poll available */ + +#ifdef MS_WINDOWS +#include <winsock2.h> +#else +#include <sys/select.h> +#endif + +#endif /* HAVE_POLL */ + +#define SELECT_EV_READ 1 +#define SELECT_EV_WRITE 2 + +#define SEC_TO_MS 1000 +#define SEC_TO_US (1000 * 1000) + +/* Use select to wait for readiness on fileno. + * + * - Return SELECT_EV_* if the file is ready + * - Return 0 on timeout + * - Return -1 (and set an exception) on error. 
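+ * - The "wait" argument and the returned value are bitmasks combining
+ *   SELECT_EV_READ and SELECT_EV_WRITE.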
+ *
+ * The wisdom of this function comes from:
+ *
+ * - PostgreSQL libpq (see src/interfaces/libpq/fe-misc.c)
+ * - Python select module (see Modules/selectmodule.c)
+ */
+static int
+wait_c_impl(int fileno, int wait, float timeout)
+{
+    int select_rv;
+    int rv = 0;
+
+#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
+
+    struct pollfd input_fd;
+    int timeout_ms;
+
+    input_fd.fd = fileno;
+    input_fd.events = POLLERR;
+    input_fd.revents = 0;
+
+    if (wait & SELECT_EV_READ) { input_fd.events |= POLLIN; }
+    if (wait & SELECT_EV_WRITE) { input_fd.events |= POLLOUT; }
+
+    if (timeout < 0.0) {
+        timeout_ms = -1;
+    } else {
+        timeout_ms = (int)(timeout * SEC_TO_MS);
+    }
+
+    Py_BEGIN_ALLOW_THREADS
+    errno = 0;
+    select_rv = poll(&input_fd, 1, timeout_ms);
+    Py_END_ALLOW_THREADS
+
+    if (PyErr_CheckSignals()) { goto finally; }
+
+    if (select_rv < 0) {
+        goto error;
+    }
+
+    if (input_fd.revents & POLLIN) { rv |= SELECT_EV_READ; }
+    if (input_fd.revents & POLLOUT) { rv |= SELECT_EV_WRITE; }
+
+#else
+
+    fd_set ifds;
+    fd_set ofds;
+    fd_set efds;
+    struct timeval tv, *tvptr;
+
+#ifndef MS_WINDOWS
+    if (fileno >= 1024) {
+        PyErr_SetString(
+            PyExc_ValueError, /* same exception as Python's 'select.select()' */
+            "connection file descriptor out of range for 'select()'");
+        return -1;
+    }
+#endif
+
+    FD_ZERO(&ifds);
+    FD_ZERO(&ofds);
+    FD_ZERO(&efds);
+
+    if (wait & SELECT_EV_READ) { FD_SET(fileno, &ifds); }
+    if (wait & SELECT_EV_WRITE) { FD_SET(fileno, &ofds); }
+    FD_SET(fileno, &efds);
+
+    /* Compute appropriate timeout interval */
+    if (timeout < 0.0) {
+        tvptr = NULL;
+    }
+    else {
+        tv.tv_sec = (int)timeout;
+        tv.tv_usec = (int)(((long)(timeout * SEC_TO_US)) % SEC_TO_US);
+        tvptr = &tv;
+    }
+
+    Py_BEGIN_ALLOW_THREADS
+    errno = 0;
+    select_rv = select(fileno + 1, &ifds, &ofds, &efds, tvptr);
+    Py_END_ALLOW_THREADS
+
+    if (PyErr_CheckSignals()) { goto finally; }
+
+    if (select_rv < 0) {
+        goto error;
+    }
+
+    if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
+    if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+
+#endif /* HAVE_POLL */
+
+    return rv;
+
+error:
+
+#ifdef MS_WINDOWS
+    if (select_rv == SOCKET_ERROR) {
+        PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError());
+    }
+#else
+    if (select_rv < 0) {
+        PyErr_SetFromErrno(PyExc_OSError);
+    }
+#endif
+    else {
+        PyErr_SetString(PyExc_OSError, "unexpected error from select()");
+    }
+
+finally:
+
+    return -1;
+
+}
+    """
+    cdef int wait_c_impl(int fileno, int wait, float timeout) except -1
+
+
+def wait_c(gen: PQGen[RV], int fileno, timeout = None) -> RV:
+    """
+    Wait for a generator using poll or select.
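+
+    The protocol is roughly the one used by the pure Python implementation:
+    the generator yields an integer "wait" state (a combination of the read
+    and write flags), the caller waits on the connection file descriptor and
+    sends the resulting "ready" state back into the generator; when the
+    generator returns, the value carried by StopIteration is used as the
+    overall return value.
+
+    A minimal sketch of a compatible generator (names are illustrative, not
+    part of this module's API)::
+
+        def fetch_gen(pgconn):
+            while True:
+                yield WAIT_R              # ask the caller to wait for reads
+                pgconn.consume_input()    # data arrived: feed libpq
+                if not pgconn.is_busy():
+                    return pgconn.get_result()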
+ """ + cdef float ctimeout + cdef int wait, ready + cdef PyObject *pyready + + if timeout is None: + ctimeout = -1.0 + else: + ctimeout = float(timeout) + if ctimeout < 0.0: + ctimeout = -1.0 + + send = gen.send + + try: + wait = next(gen) + + while True: + ready = wait_c_impl(fileno, wait, ctimeout) + if ready == 0: + continue + elif ready == READY_R: + pyready = <PyObject *>PY_READY_R + elif ready == READY_RW: + pyready = <PyObject *>PY_READY_RW + elif ready == READY_W: + pyready = <PyObject *>PY_READY_W + else: + raise AssertionError(f"unexpected ready value: {ready}") + + wait = PyObject_CallFunctionObjArgs(send, pyready, NULL) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv diff --git a/psycopg_c/psycopg_c/pq.pxd b/psycopg_c/psycopg_c/pq.pxd new file mode 100644 index 0000000..57825dd --- /dev/null +++ b/psycopg_c/psycopg_c/pq.pxd @@ -0,0 +1,78 @@ +# Include pid_t but Windows doesn't have it +# Don't use "IF" so that the generated C is portable and can be included +# in the sdist. +cdef extern from * nogil: + """ +#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS) + typedef signed pid_t; +#else + #include <fcntl.h> +#endif + """ + ctypedef signed pid_t + +from psycopg_c.pq cimport libpq + +ctypedef char *(*conn_bytes_f) (const libpq.PGconn *) +ctypedef int(*conn_int_f) (const libpq.PGconn *) + + +cdef class PGconn: + cdef libpq.PGconn* _pgconn_ptr + cdef object __weakref__ + cdef public object notice_handler + cdef public object notify_handler + cdef pid_t _procpid + + @staticmethod + cdef PGconn _from_ptr(libpq.PGconn *ptr) + + cpdef int flush(self) except -1 + cpdef object notifies(self) + + +cdef class PGresult: + cdef libpq.PGresult* _pgresult_ptr + + @staticmethod + cdef PGresult _from_ptr(libpq.PGresult *ptr) + + +cdef class PGcancel: + cdef libpq.PGcancel* pgcancel_ptr + + @staticmethod + cdef PGcancel _from_ptr(libpq.PGcancel *ptr) + + +cdef class Escaping: + cdef PGconn conn + + cpdef escape_literal(self, data) + cpdef escape_identifier(self, data) + cpdef escape_string(self, data) + cpdef escape_bytea(self, data) + cpdef unescape_bytea(self, const unsigned char *data) + + +cdef class PQBuffer: + cdef unsigned char *buf + cdef Py_ssize_t len + + @staticmethod + cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t length) + + +cdef class ViewBuffer: + cdef unsigned char *buf + cdef Py_ssize_t len + cdef object obj + + @staticmethod + cdef ViewBuffer _from_buffer( + object obj, unsigned char *buf, Py_ssize_t length) + + +cdef int _buffer_as_string_and_size( + data: "Buffer", char **ptr, Py_ssize_t *length +) except -1 diff --git a/psycopg_c/psycopg_c/pq.pyx b/psycopg_c/psycopg_c/pq.pyx new file mode 100644 index 0000000..d397c17 --- /dev/null +++ b/psycopg_c/psycopg_c/pq.pyx @@ -0,0 +1,38 @@ +""" +libpq Python wrapper using cython bindings. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg_c.pq cimport libpq + +import logging + +from psycopg import errors as e +from psycopg.pq import Format +from psycopg.pq.misc import error_message + +logger = logging.getLogger("psycopg") + +__impl__ = 'c' +__build_version__ = libpq.PG_VERSION_NUM + + +def version(): + return libpq.PQlibVersion() + + +include "pq/pgconn.pyx" +include "pq/pgresult.pyx" +include "pq/pgcancel.pyx" +include "pq/conninfo.pyx" +include "pq/escaping.pyx" +include "pq/pqbuffer.pyx" + + +# importing the ssl module sets up Python's libcrypto callbacks +import ssl # noqa + +# disable libcrypto setup in libpq, so it won't stomp on the callbacks +# that have already been set up +libpq.PQinitOpenSSL(1, 0) diff --git a/psycopg_c/psycopg_c/pq/__init__.pxd b/psycopg_c/psycopg_c/pq/__init__.pxd new file mode 100644 index 0000000..ce8c60c --- /dev/null +++ b/psycopg_c/psycopg_c/pq/__init__.pxd @@ -0,0 +1,9 @@ +""" +psycopg_c.pq cython module. + +This file is necessary to allow c-importing pxd files from this directory. +""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg_c.pq cimport libpq diff --git a/psycopg_c/psycopg_c/pq/conninfo.pyx b/psycopg_c/psycopg_c/pq/conninfo.pyx new file mode 100644 index 0000000..3443de1 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/conninfo.pyx @@ -0,0 +1,61 @@ +""" +psycopg_c.pq.Conninfo object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg.pq.misc import ConninfoOption + + +class Conninfo: + @classmethod + def get_defaults(cls) -> List[ConninfoOption]: + cdef libpq.PQconninfoOption *opts = libpq.PQconndefaults() + if opts is NULL : + raise MemoryError("couldn't allocate connection defaults") + rv = _options_from_array(opts) + libpq.PQconninfoFree(opts) + return rv + + @classmethod + def parse(cls, const char *conninfo) -> List[ConninfoOption]: + cdef char *errmsg = NULL + cdef libpq.PQconninfoOption *opts = libpq.PQconninfoParse(conninfo, &errmsg) + if opts is NULL: + if errmsg is NULL: + raise MemoryError("couldn't allocate on conninfo parse") + else: + exc = e.OperationalError(errmsg.decode("utf8", "replace")) + libpq.PQfreemem(errmsg) + raise exc + + rv = _options_from_array(opts) + libpq.PQconninfoFree(opts) + return rv + + def __repr__(self): + return f"<{type(self).__name__} ({self.keyword.decode('ascii')})>" + + +cdef _options_from_array(libpq.PQconninfoOption *opts): + rv = [] + cdef int i = 0 + cdef libpq.PQconninfoOption* opt + while True: + opt = opts + i + if opt.keyword is NULL: + break + rv.append( + ConninfoOption( + keyword=opt.keyword, + envvar=opt.envvar if opt.envvar is not NULL else None, + compiled=opt.compiled if opt.compiled is not NULL else None, + val=opt.val if opt.val is not NULL else None, + label=opt.label if opt.label is not NULL else None, + dispchar=opt.dispchar if opt.dispchar is not NULL else None, + dispsize=opt.dispsize, + ) + ) + i += 1 + + return rv diff --git a/psycopg_c/psycopg_c/pq/escaping.pyx b/psycopg_c/psycopg_c/pq/escaping.pyx new file mode 100644 index 0000000..f0a44d3 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/escaping.pyx @@ -0,0 +1,132 @@ +""" +psycopg_c.pq.Escaping object implementation. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +from libc.string cimport strlen +from cpython.mem cimport PyMem_Malloc, PyMem_Free + + +cdef class Escaping: + def __init__(self, PGconn conn = None): + self.conn = conn + + cpdef escape_literal(self, data): + cdef char *out + cdef char *ptr + cdef Py_ssize_t length + + if self.conn is None: + raise e.OperationalError("escape_literal failed: no connection provided") + if self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + _buffer_as_string_and_size(data, &ptr, &length) + + out = libpq.PQescapeLiteral(self.conn._pgconn_ptr, ptr, length) + if out is NULL: + raise e.OperationalError( + f"escape_literal failed: {error_message(self.conn)}" + ) + + rv = out[:strlen(out)] + libpq.PQfreemem(out) + return rv + + cpdef escape_identifier(self, data): + cdef char *out + cdef char *ptr + cdef Py_ssize_t length + + _buffer_as_string_and_size(data, &ptr, &length) + + if self.conn is None: + raise e.OperationalError("escape_identifier failed: no connection provided") + if self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + out = libpq.PQescapeIdentifier(self.conn._pgconn_ptr, ptr, length) + if out is NULL: + raise e.OperationalError( + f"escape_identifier failed: {error_message(self.conn)}" + ) + + rv = out[:strlen(out)] + libpq.PQfreemem(out) + return rv + + cpdef escape_string(self, data): + cdef int error + cdef size_t len_out + cdef char *ptr + cdef char *buf_out + cdef Py_ssize_t length + + _buffer_as_string_and_size(data, &ptr, &length) + + if self.conn is not None: + if self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + buf_out = <char *>PyMem_Malloc(length * 2 + 1) + len_out = libpq.PQescapeStringConn( + self.conn._pgconn_ptr, buf_out, ptr, length, &error + ) + if error: + PyMem_Free(buf_out) + raise e.OperationalError( + f"escape_string failed: {error_message(self.conn)}" + ) + + else: + buf_out = <char *>PyMem_Malloc(length * 2 + 1) + len_out = libpq.PQescapeString(buf_out, ptr, length) + + rv = buf_out[:len_out] + PyMem_Free(buf_out) + return rv + + cpdef escape_bytea(self, data): + cdef size_t len_out + cdef unsigned char *out + cdef char *ptr + cdef Py_ssize_t length + + if self.conn is not None and self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + _buffer_as_string_and_size(data, &ptr, &length) + + if self.conn is not None: + out = libpq.PQescapeByteaConn( + self.conn._pgconn_ptr, <unsigned char *>ptr, length, &len_out) + else: + out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_out) + + if out is NULL: + raise MemoryError( + f"couldn't allocate for escape_bytea of {len(data)} bytes" + ) + + rv = out[:len_out - 1] # out includes final 0 + libpq.PQfreemem(out) + return rv + + cpdef unescape_bytea(self, const unsigned char *data): + # not needed, but let's keep it symmetric with the escaping: + # if a connection is passed in, it must be valid. 
+ if self.conn is not None: + if self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + cdef size_t len_out + cdef unsigned char *out = libpq.PQunescapeBytea(data, &len_out) + if out is NULL: + raise MemoryError( + f"couldn't allocate for unescape_bytea of {len(data)} bytes" + ) + + rv = out[:len_out] + libpq.PQfreemem(out) + return rv diff --git a/psycopg_c/psycopg_c/pq/libpq.pxd b/psycopg_c/psycopg_c/pq/libpq.pxd new file mode 100644 index 0000000..5e05e40 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/libpq.pxd @@ -0,0 +1,321 @@ +""" +Libpq header definition for the cython psycopg.pq implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +cdef extern from "stdio.h": + + ctypedef struct FILE: + pass + +cdef extern from "pg_config.h": + + int PG_VERSION_NUM + + +cdef extern from "libpq-fe.h": + + # structures and types + + ctypedef unsigned int Oid + + ctypedef struct PGconn: + pass + + ctypedef struct PGresult: + pass + + ctypedef struct PQconninfoOption: + char *keyword + char *envvar + char *compiled + char *val + char *label + char *dispchar + int dispsize + + ctypedef struct PGnotify: + char *relname + int be_pid + char *extra + + ctypedef struct PGcancel: + pass + + ctypedef struct PGresAttDesc: + char *name + Oid tableid + int columnid + int format + Oid typid + int typlen + int atttypmod + + # enums + + ctypedef enum PostgresPollingStatusType: + PGRES_POLLING_FAILED = 0 + PGRES_POLLING_READING + PGRES_POLLING_WRITING + PGRES_POLLING_OK + PGRES_POLLING_ACTIVE + + + ctypedef enum PGPing: + PQPING_OK + PQPING_REJECT + PQPING_NO_RESPONSE + PQPING_NO_ATTEMPT + + ctypedef enum ConnStatusType: + CONNECTION_OK + CONNECTION_BAD + CONNECTION_STARTED + CONNECTION_MADE + CONNECTION_AWAITING_RESPONSE + CONNECTION_AUTH_OK + CONNECTION_SETENV + CONNECTION_SSL_STARTUP + CONNECTION_NEEDED + CONNECTION_CHECK_WRITABLE + CONNECTION_GSS_STARTUP + # CONNECTION_CHECK_TARGET PG 12 + + ctypedef enum PGTransactionStatusType: + PQTRANS_IDLE + PQTRANS_ACTIVE + PQTRANS_INTRANS + PQTRANS_INERROR + PQTRANS_UNKNOWN + + ctypedef enum ExecStatusType: + PGRES_EMPTY_QUERY = 0 + PGRES_COMMAND_OK + PGRES_TUPLES_OK + PGRES_COPY_OUT + PGRES_COPY_IN + PGRES_BAD_RESPONSE + PGRES_NONFATAL_ERROR + PGRES_FATAL_ERROR + PGRES_COPY_BOTH + PGRES_SINGLE_TUPLE + PGRES_PIPELINE_SYNC + PGRES_PIPELINE_ABORT + + # 33.1. Database Connection Control Functions + PGconn *PQconnectdb(const char *conninfo) + PGconn *PQconnectStart(const char *conninfo) + PostgresPollingStatusType PQconnectPoll(PGconn *conn) nogil + PQconninfoOption *PQconndefaults() + PQconninfoOption *PQconninfo(PGconn *conn) + PQconninfoOption *PQconninfoParse(const char *conninfo, char **errmsg) + void PQfinish(PGconn *conn) + void PQreset(PGconn *conn) + int PQresetStart(PGconn *conn) + PostgresPollingStatusType PQresetPoll(PGconn *conn) + PGPing PQping(const char *conninfo) + + # 33.2. 
Connection Status Functions + char *PQdb(const PGconn *conn) + char *PQuser(const PGconn *conn) + char *PQpass(const PGconn *conn) + char *PQhost(const PGconn *conn) + char *PQhostaddr(const PGconn *conn) + char *PQport(const PGconn *conn) + char *PQtty(const PGconn *conn) + char *PQoptions(const PGconn *conn) + ConnStatusType PQstatus(const PGconn *conn) + PGTransactionStatusType PQtransactionStatus(const PGconn *conn) + const char *PQparameterStatus(const PGconn *conn, const char *paramName) + int PQprotocolVersion(const PGconn *conn) + int PQserverVersion(const PGconn *conn) + char *PQerrorMessage(const PGconn *conn) + int PQsocket(const PGconn *conn) nogil + int PQbackendPID(const PGconn *conn) + int PQconnectionNeedsPassword(const PGconn *conn) + int PQconnectionUsedPassword(const PGconn *conn) + int PQsslInUse(PGconn *conn) # TODO: const in PG 12 docs - verify/report + # TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl + + # 33.3. Command Execution Functions + PGresult *PQexec(PGconn *conn, const char *command) nogil + PGresult *PQexecParams(PGconn *conn, + const char *command, + int nParams, + const Oid *paramTypes, + const char * const *paramValues, + const int *paramLengths, + const int *paramFormats, + int resultFormat) nogil + PGresult *PQprepare(PGconn *conn, + const char *stmtName, + const char *query, + int nParams, + const Oid *paramTypes) nogil + PGresult *PQexecPrepared(PGconn *conn, + const char *stmtName, + int nParams, + const char * const *paramValues, + const int *paramLengths, + const int *paramFormats, + int resultFormat) nogil + PGresult *PQdescribePrepared(PGconn *conn, const char *stmtName) nogil + PGresult *PQdescribePortal(PGconn *conn, const char *portalName) nogil + ExecStatusType PQresultStatus(const PGresult *res) nogil + # PQresStatus: not needed, we have pretty enums + char *PQresultErrorMessage(const PGresult *res) nogil + # TODO: PQresultVerboseErrorMessage + char *PQresultErrorField(const PGresult *res, int fieldcode) nogil + void PQclear(PGresult *res) nogil + + # 33.3.2. Retrieving Query Result Information + int PQntuples(const PGresult *res) + int PQnfields(const PGresult *res) + char *PQfname(const PGresult *res, int column_number) + int PQfnumber(const PGresult *res, const char *column_name) + Oid PQftable(const PGresult *res, int column_number) + int PQftablecol(const PGresult *res, int column_number) + int PQfformat(const PGresult *res, int column_number) + Oid PQftype(const PGresult *res, int column_number) + int PQfmod(const PGresult *res, int column_number) + int PQfsize(const PGresult *res, int column_number) + int PQbinaryTuples(const PGresult *res) + char *PQgetvalue(const PGresult *res, int row_number, int column_number) + int PQgetisnull(const PGresult *res, int row_number, int column_number) + int PQgetlength(const PGresult *res, int row_number, int column_number) + int PQnparams(const PGresult *res) + Oid PQparamtype(const PGresult *res, int param_number) + # PQprint: pretty useless + + # 33.3.3. Retrieving Other Result Information + char *PQcmdStatus(PGresult *res) + char *PQcmdTuples(PGresult *res) + Oid PQoidValue(const PGresult *res) + + # 33.3.4. 
Escaping Strings for Inclusion in SQL Commands + char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length) + char *PQescapeLiteral(PGconn *conn, const char *str, size_t length) + size_t PQescapeStringConn(PGconn *conn, + char *to, const char *from_, size_t length, + int *error) + size_t PQescapeString(char *to, const char *from_, size_t length) + unsigned char *PQescapeByteaConn(PGconn *conn, + const unsigned char *src, + size_t from_length, + size_t *to_length) + unsigned char *PQescapeBytea(const unsigned char *src, + size_t from_length, + size_t *to_length) + unsigned char *PQunescapeBytea(const unsigned char *src, size_t *to_length) + + + # 33.4. Asynchronous Command Processing + int PQsendQuery(PGconn *conn, const char *command) nogil + int PQsendQueryParams(PGconn *conn, + const char *command, + int nParams, + const Oid *paramTypes, + const char * const *paramValues, + const int *paramLengths, + const int *paramFormats, + int resultFormat) nogil + int PQsendPrepare(PGconn *conn, + const char *stmtName, + const char *query, + int nParams, + const Oid *paramTypes) nogil + int PQsendQueryPrepared(PGconn *conn, + const char *stmtName, + int nParams, + const char * const *paramValues, + const int *paramLengths, + const int *paramFormats, + int resultFormat) nogil + int PQsendDescribePrepared(PGconn *conn, const char *stmtName) nogil + int PQsendDescribePortal(PGconn *conn, const char *portalName) nogil + PGresult *PQgetResult(PGconn *conn) nogil + int PQconsumeInput(PGconn *conn) nogil + int PQisBusy(PGconn *conn) nogil + int PQsetnonblocking(PGconn *conn, int arg) nogil + int PQisnonblocking(const PGconn *conn) + int PQflush(PGconn *conn) nogil + + # 33.5. Retrieving Query Results Row-by-Row + int PQsetSingleRowMode(PGconn *conn) + + # 33.6. Canceling Queries in Progress + PGcancel *PQgetCancel(PGconn *conn) + void PQfreeCancel(PGcancel *cancel) + int PQcancel(PGcancel *cancel, char *errbuf, int errbufsize) + + # 33.8. Asynchronous Notification + PGnotify *PQnotifies(PGconn *conn) nogil + + # 33.9. Functions Associated with the COPY Command + int PQputCopyData(PGconn *conn, const char *buffer, int nbytes) nogil + int PQputCopyEnd(PGconn *conn, const char *errormsg) nogil + int PQgetCopyData(PGconn *conn, char **buffer, int async) nogil + + # 33.10. Control Functions + void PQtrace(PGconn *conn, FILE *stream); + void PQsetTraceFlags(PGconn *conn, int flags); + void PQuntrace(PGconn *conn); + + # 33.11. Miscellaneous Functions + void PQfreemem(void *ptr) nogil + void PQconninfoFree(PQconninfoOption *connOptions) + char *PQencryptPasswordConn( + PGconn *conn, const char *passwd, const char *user, const char *algorithm); + PGresult *PQmakeEmptyPGresult(PGconn *conn, ExecStatusType status) + int PQsetResultAttrs(PGresult *res, int numAttributes, PGresAttDesc *attDescs) + int PQlibVersion() + + # 33.12. Notice Processing + ctypedef void (*PQnoticeReceiver)(void *arg, const PGresult *res) + PQnoticeReceiver PQsetNoticeReceiver( + PGconn *conn, PQnoticeReceiver prog, void *arg) + + # 33.18. 
SSL Support + void PQinitOpenSSL(int do_ssl, int do_crypto) + + # 34.5 Pipeline Mode + + ctypedef enum PGpipelineStatus: + PQ_PIPELINE_OFF + PQ_PIPELINE_ON + PQ_PIPELINE_ABORTED + + PGpipelineStatus PQpipelineStatus(const PGconn *conn) + int PQenterPipelineMode(PGconn *conn) + int PQexitPipelineMode(PGconn *conn) + int PQpipelineSync(PGconn *conn) + int PQsendFlushRequest(PGconn *conn) + +cdef extern from *: + """ +/* Hack to allow the use of old libpq versions */ +#if PG_VERSION_NUM < 100000 +#define PQencryptPasswordConn(conn, passwd, user, algorithm) NULL +#endif + +#if PG_VERSION_NUM < 120000 +#define PQhostaddr(conn) NULL +#endif + +#if PG_VERSION_NUM < 140000 +#define PGRES_PIPELINE_SYNC 10 +#define PGRES_PIPELINE_ABORTED 11 +typedef enum { + PQ_PIPELINE_OFF, + PQ_PIPELINE_ON, + PQ_PIPELINE_ABORTED +} PGpipelineStatus; +#define PQpipelineStatus(conn) PQ_PIPELINE_OFF +#define PQenterPipelineMode(conn) 0 +#define PQexitPipelineMode(conn) 1 +#define PQpipelineSync(conn) 0 +#define PQsendFlushRequest(conn) 0 +#define PQsetTraceFlags(conn, stream) do {} while (0) +#endif +""" diff --git a/psycopg_c/psycopg_c/pq/pgcancel.pyx b/psycopg_c/psycopg_c/pq/pgcancel.pyx new file mode 100644 index 0000000..b7cbb70 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/pgcancel.pyx @@ -0,0 +1,32 @@ +""" +psycopg_c.pq.PGcancel object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + + +cdef class PGcancel: + def __cinit__(self): + self.pgcancel_ptr = NULL + + @staticmethod + cdef PGcancel _from_ptr(libpq.PGcancel *ptr): + cdef PGcancel rv = PGcancel.__new__(PGcancel) + rv.pgcancel_ptr = ptr + return rv + + def __dealloc__(self) -> None: + self.free() + + def free(self) -> None: + if self.pgcancel_ptr is not NULL: + libpq.PQfreeCancel(self.pgcancel_ptr) + self.pgcancel_ptr = NULL + + def cancel(self) -> None: + cdef char buf[256] + cdef int res = libpq.PQcancel(self.pgcancel_ptr, buf, sizeof(buf)) + if not res: + raise e.OperationalError( + f"cancel failed: {buf.decode('utf8', 'ignore')}" + ) diff --git a/psycopg_c/psycopg_c/pq/pgconn.pyx b/psycopg_c/psycopg_c/pq/pgconn.pyx new file mode 100644 index 0000000..4a60530 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/pgconn.pyx @@ -0,0 +1,733 @@ +""" +psycopg_c.pq.PGconn object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +cdef extern from * nogil: + """ +#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS) + /* We don't need a real definition for this because Windows is not affected + * by the issue caused by closing the fds after fork. + */ + #define getpid() (0) +#else + #include <unistd.h> +#endif + """ + pid_t getpid() + +from libc.stdio cimport fdopen +from cpython.mem cimport PyMem_Malloc, PyMem_Free +from cpython.bytes cimport PyBytes_AsString +from cpython.memoryview cimport PyMemoryView_FromObject + +import sys + +from psycopg.pq import Format as PqFormat, Trace +from psycopg.pq.misc import PGnotify, connection_summary +from psycopg_c.pq cimport PQBuffer + + +cdef class PGconn: + @staticmethod + cdef PGconn _from_ptr(libpq.PGconn *ptr): + cdef PGconn rv = PGconn.__new__(PGconn) + rv._pgconn_ptr = ptr + + libpq.PQsetNoticeReceiver(ptr, notice_receiver, <void *>rv) + return rv + + def __cinit__(self): + self._pgconn_ptr = NULL + self._procpid = getpid() + + def __dealloc__(self): + # Close the connection only if it was created in this process, + # not if this object is being GC'd after fork. 
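+        # _procpid is recorded in __cinit__(): after a fork the child
+        # process would otherwise close the libpq socket it shares with the
+        # parent while garbage-collecting its copy of this wrapper.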
+ if self._procpid == getpid(): + self.finish() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self) + return f"<{cls} {info} at 0x{id(self):x}>" + + @classmethod + def connect(cls, const char *conninfo) -> PGconn: + cdef libpq.PGconn* pgconn = libpq.PQconnectdb(conninfo) + if not pgconn: + raise MemoryError("couldn't allocate PGconn") + + return PGconn._from_ptr(pgconn) + + @classmethod + def connect_start(cls, const char *conninfo) -> PGconn: + cdef libpq.PGconn* pgconn = libpq.PQconnectStart(conninfo) + if not pgconn: + raise MemoryError("couldn't allocate PGconn") + + return PGconn._from_ptr(pgconn) + + def connect_poll(self) -> int: + return _call_int(self, <conn_int_f>libpq.PQconnectPoll) + + def finish(self) -> None: + if self._pgconn_ptr is not NULL: + libpq.PQfinish(self._pgconn_ptr) + self._pgconn_ptr = NULL + + @property + def pgconn_ptr(self) -> Optional[int]: + if self._pgconn_ptr: + return <long long><void *>self._pgconn_ptr + else: + return None + + @property + def info(self) -> List["ConninfoOption"]: + _ensure_pgconn(self) + cdef libpq.PQconninfoOption *opts = libpq.PQconninfo(self._pgconn_ptr) + if opts is NULL: + raise MemoryError("couldn't allocate connection info") + rv = _options_from_array(opts) + libpq.PQconninfoFree(opts) + return rv + + def reset(self) -> None: + _ensure_pgconn(self) + libpq.PQreset(self._pgconn_ptr) + + def reset_start(self) -> None: + if not libpq.PQresetStart(self._pgconn_ptr): + raise e.OperationalError("couldn't reset connection") + + def reset_poll(self) -> int: + return _call_int(self, <conn_int_f>libpq.PQresetPoll) + + @classmethod + def ping(self, const char *conninfo) -> int: + return libpq.PQping(conninfo) + + @property + def db(self) -> bytes: + return _call_bytes(self, libpq.PQdb) + + @property + def user(self) -> bytes: + return _call_bytes(self, libpq.PQuser) + + @property + def password(self) -> bytes: + return _call_bytes(self, libpq.PQpass) + + @property + def host(self) -> bytes: + return _call_bytes(self, libpq.PQhost) + + @property + def hostaddr(self) -> bytes: + if libpq.PG_VERSION_NUM < 120000: + raise e.NotSupportedError( + f"PQhostaddr requires libpq from PostgreSQL 12," + f" {libpq.PG_VERSION_NUM} available instead" + ) + + _ensure_pgconn(self) + cdef char *rv = libpq.PQhostaddr(self._pgconn_ptr) + assert rv is not NULL + return rv + + @property + def port(self) -> bytes: + return _call_bytes(self, libpq.PQport) + + @property + def tty(self) -> bytes: + return _call_bytes(self, libpq.PQtty) + + @property + def options(self) -> bytes: + return _call_bytes(self, libpq.PQoptions) + + @property + def status(self) -> int: + return libpq.PQstatus(self._pgconn_ptr) + + @property + def transaction_status(self) -> int: + return libpq.PQtransactionStatus(self._pgconn_ptr) + + def parameter_status(self, const char *name) -> Optional[bytes]: + _ensure_pgconn(self) + cdef const char *rv = libpq.PQparameterStatus(self._pgconn_ptr, name) + if rv is not NULL: + return rv + else: + return None + + @property + def error_message(self) -> bytes: + return libpq.PQerrorMessage(self._pgconn_ptr) + + @property + def protocol_version(self) -> int: + return _call_int(self, libpq.PQprotocolVersion) + + @property + def server_version(self) -> int: + return _call_int(self, libpq.PQserverVersion) + + @property + def socket(self) -> int: + rv = _call_int(self, libpq.PQsocket) + if rv == -1: + raise e.OperationalError("the connection is lost") + return rv + + @property + 
def backend_pid(self) -> int: + return _call_int(self, libpq.PQbackendPID) + + @property + def needs_password(self) -> bool: + return bool(libpq.PQconnectionNeedsPassword(self._pgconn_ptr)) + + @property + def used_password(self) -> bool: + return bool(libpq.PQconnectionUsedPassword(self._pgconn_ptr)) + + @property + def ssl_in_use(self) -> bool: + return bool(_call_int(self, <conn_int_f>libpq.PQsslInUse)) + + def exec_(self, const char *command) -> PGresult: + _ensure_pgconn(self) + cdef libpq.PGresult *pgresult + with nogil: + pgresult = libpq.PQexec(self._pgconn_ptr, command) + if pgresult is NULL: + raise MemoryError("couldn't allocate PGresult") + + return PGresult._from_ptr(pgresult) + + def send_query(self, const char *command) -> None: + _ensure_pgconn(self) + cdef int rv + with nogil: + rv = libpq.PQsendQuery(self._pgconn_ptr, command) + if not rv: + raise e.OperationalError(f"sending query failed: {error_message(self)}") + + def exec_params( + self, + const char *command, + param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + int result_format = PqFormat.TEXT, + ) -> PGresult: + _ensure_pgconn(self) + + cdef Py_ssize_t cnparams + cdef libpq.Oid *ctypes + cdef char *const *cvalues + cdef int *clengths + cdef int *cformats + cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( + param_values, param_types, param_formats) + + cdef libpq.PGresult *pgresult + with nogil: + pgresult = libpq.PQexecParams( + self._pgconn_ptr, command, <int>cnparams, ctypes, + <const char *const *>cvalues, clengths, cformats, result_format) + _clear_query_params(ctypes, cvalues, clengths, cformats) + if pgresult is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(pgresult) + + def send_query_params( + self, + const char *command, + param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + int result_format = PqFormat.TEXT, + ) -> None: + _ensure_pgconn(self) + + cdef Py_ssize_t cnparams + cdef libpq.Oid *ctypes + cdef char *const *cvalues + cdef int *clengths + cdef int *cformats + cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( + param_values, param_types, param_formats) + + cdef int rv + with nogil: + rv = libpq.PQsendQueryParams( + self._pgconn_ptr, command, <int>cnparams, ctypes, + <const char *const *>cvalues, clengths, cformats, result_format) + _clear_query_params(ctypes, cvalues, clengths, cformats) + if not rv: + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_prepare( + self, + const char *name, + const char *command, + param_types: Optional[Sequence[int]] = None, + ) -> None: + _ensure_pgconn(self) + + cdef int i + cdef Py_ssize_t nparams = len(param_types) if param_types else 0 + cdef libpq.Oid *atypes = NULL + if nparams: + atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid)) + for i in range(nparams): + atypes[i] = param_types[i] + + cdef int rv + with nogil: + rv = libpq.PQsendPrepare( + self._pgconn_ptr, name, command, <int>nparams, atypes + ) + PyMem_Free(atypes) + if not rv: + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_query_prepared( + self, + const char *name, + param_values: Optional[Sequence[Optional[bytes]]], + param_formats: Optional[Sequence[int]] = None, + int result_format = PqFormat.TEXT, + ) -> None: + 
_ensure_pgconn(self) + + cdef Py_ssize_t cnparams + cdef libpq.Oid *ctypes + cdef char *const *cvalues + cdef int *clengths + cdef int *cformats + cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( + param_values, None, param_formats) + + cdef int rv + with nogil: + rv = libpq.PQsendQueryPrepared( + self._pgconn_ptr, name, <int>cnparams, <const char *const *>cvalues, + clengths, cformats, result_format) + _clear_query_params(ctypes, cvalues, clengths, cformats) + if not rv: + raise e.OperationalError( + f"sending prepared query failed: {error_message(self)}" + ) + + def prepare( + self, + const char *name, + const char *command, + param_types: Optional[Sequence[int]] = None, + ) -> PGresult: + _ensure_pgconn(self) + + cdef int i + cdef Py_ssize_t nparams = len(param_types) if param_types else 0 + cdef libpq.Oid *atypes = NULL + if nparams: + atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid)) + for i in range(nparams): + atypes[i] = param_types[i] + + cdef libpq.PGresult *rv + with nogil: + rv = libpq.PQprepare( + self._pgconn_ptr, name, command, <int>nparams, atypes) + PyMem_Free(atypes) + if rv is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(rv) + + def exec_prepared( + self, + const char *name, + param_values: Optional[Sequence[bytes]], + param_formats: Optional[Sequence[int]] = None, + int result_format = PqFormat.TEXT, + ) -> PGresult: + _ensure_pgconn(self) + + cdef Py_ssize_t cnparams + cdef libpq.Oid *ctypes + cdef char *const *cvalues + cdef int *clengths + cdef int *cformats + cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( + param_values, None, param_formats) + + cdef libpq.PGresult *rv + with nogil: + rv = libpq.PQexecPrepared( + self._pgconn_ptr, name, <int>cnparams, + <const char *const *>cvalues, + clengths, cformats, result_format) + + _clear_query_params(ctypes, cvalues, clengths, cformats) + if rv is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(rv) + + def describe_prepared(self, const char *name) -> PGresult: + _ensure_pgconn(self) + cdef libpq.PGresult *rv = libpq.PQdescribePrepared(self._pgconn_ptr, name) + if rv is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(rv) + + def send_describe_prepared(self, const char *name) -> None: + _ensure_pgconn(self) + cdef int rv = libpq.PQsendDescribePrepared(self._pgconn_ptr, name) + if not rv: + raise e.OperationalError( + f"sending describe prepared failed: {error_message(self)}" + ) + + def describe_portal(self, const char *name) -> PGresult: + _ensure_pgconn(self) + cdef libpq.PGresult *rv = libpq.PQdescribePortal(self._pgconn_ptr, name) + if rv is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(rv) + + def send_describe_portal(self, const char *name) -> None: + _ensure_pgconn(self) + cdef int rv = libpq.PQsendDescribePortal(self._pgconn_ptr, name) + if not rv: + raise e.OperationalError( + f"sending describe prepared failed: {error_message(self)}" + ) + + def get_result(self) -> Optional["PGresult"]: + cdef libpq.PGresult *pgresult = libpq.PQgetResult(self._pgconn_ptr) + if pgresult is NULL: + return None + return PGresult._from_ptr(pgresult) + + def consume_input(self) -> None: + if 1 != libpq.PQconsumeInput(self._pgconn_ptr): + raise e.OperationalError(f"consuming input failed: {error_message(self)}") + + def is_busy(self) -> int: + cdef int rv + with nogil: + rv = libpq.PQisBusy(self._pgconn_ptr) + return rv + + @property + def 
nonblocking(self) -> int: + return libpq.PQisnonblocking(self._pgconn_ptr) + + @nonblocking.setter + def nonblocking(self, int arg) -> None: + if 0 > libpq.PQsetnonblocking(self._pgconn_ptr, arg): + raise e.OperationalError(f"setting nonblocking failed: {error_message(self)}") + + cpdef int flush(self) except -1: + if self._pgconn_ptr == NULL: + raise e.OperationalError(f"flushing failed: the connection is closed") + cdef int rv = libpq.PQflush(self._pgconn_ptr) + if rv < 0: + raise e.OperationalError(f"flushing failed: {error_message(self)}") + return rv + + def set_single_row_mode(self) -> None: + if not libpq.PQsetSingleRowMode(self._pgconn_ptr): + raise e.OperationalError("setting single row mode failed") + + def get_cancel(self) -> PGcancel: + cdef libpq.PGcancel *ptr = libpq.PQgetCancel(self._pgconn_ptr) + if not ptr: + raise e.OperationalError("couldn't create cancel object") + return PGcancel._from_ptr(ptr) + + cpdef object notifies(self): + cdef libpq.PGnotify *ptr + with nogil: + ptr = libpq.PQnotifies(self._pgconn_ptr) + if ptr: + ret = PGnotify(ptr.relname, ptr.be_pid, ptr.extra) + libpq.PQfreemem(ptr) + return ret + else: + return None + + def put_copy_data(self, buffer) -> int: + cdef int rv + cdef char *cbuffer + cdef Py_ssize_t length + + _buffer_as_string_and_size(buffer, &cbuffer, &length) + rv = libpq.PQputCopyData(self._pgconn_ptr, cbuffer, <int>length) + if rv < 0: + raise e.OperationalError(f"sending copy data failed: {error_message(self)}") + return rv + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + cdef int rv + cdef const char *cerr = NULL + if error is not None: + cerr = PyBytes_AsString(error) + rv = libpq.PQputCopyEnd(self._pgconn_ptr, cerr) + if rv < 0: + raise e.OperationalError(f"sending copy end failed: {error_message(self)}") + return rv + + def get_copy_data(self, int async_) -> Tuple[int, memoryview]: + cdef char *buffer_ptr = NULL + cdef int nbytes + nbytes = libpq.PQgetCopyData(self._pgconn_ptr, &buffer_ptr, async_) + if nbytes == -2: + raise e.OperationalError(f"receiving copy data failed: {error_message(self)}") + if buffer_ptr is not NULL: + data = PyMemoryView_FromObject( + PQBuffer._from_buffer(<unsigned char *>buffer_ptr, nbytes)) + return nbytes, data + else: + return nbytes, b"" # won't parse it, doesn't really be memoryview + + def trace(self, fileno: int) -> None: + if sys.platform != "linux": + raise e.NotSupportedError("currently only supported on Linux") + stream = fdopen(fileno, b"w") + libpq.PQtrace(self._pgconn_ptr, stream) + + def set_trace_flags(self, flags: Trace) -> None: + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQsetTraceFlags requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + libpq.PQsetTraceFlags(self._pgconn_ptr, flags) + + def untrace(self) -> None: + libpq.PQuntrace(self._pgconn_ptr) + + def encrypt_password( + self, const char *passwd, const char *user, algorithm = None + ) -> bytes: + if libpq.PG_VERSION_NUM < 100000: + raise e.NotSupportedError( + f"PQencryptPasswordConn requires libpq from PostgreSQL 10," + f" {libpq.PG_VERSION_NUM} available instead" + ) + + cdef char *out + cdef const char *calgo = NULL + if algorithm: + calgo = algorithm + out = libpq.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, calgo) + if not out: + raise e.OperationalError( + f"password encryption failed: {error_message(self)}" + ) + + rv = bytes(out) + libpq.PQfreemem(out) + return rv + + def make_empty_result(self, int exec_status) -> PGresult: + cdef 
libpq.PGresult *rv = libpq.PQmakeEmptyPGresult( + self._pgconn_ptr, <libpq.ExecStatusType>exec_status) + if not rv: + raise MemoryError("couldn't allocate empty PGresult") + return PGresult._from_ptr(rv) + + @property + def pipeline_status(self) -> int: + """The current pipeline mode status. + + For libpq < 14.0, always return 0 (PQ_PIPELINE_OFF). + """ + if libpq.PG_VERSION_NUM < 140000: + return libpq.PQ_PIPELINE_OFF + cdef int status = libpq.PQpipelineStatus(self._pgconn_ptr) + return status + + def enter_pipeline_mode(self) -> None: + """Enter pipeline mode. + + :raises ~e.OperationalError: in case of failure to enter the pipeline + mode. + """ + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQenterPipelineMode requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + if libpq.PQenterPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError("failed to enter pipeline mode") + + def exit_pipeline_mode(self) -> None: + """Exit pipeline mode. + + :raises ~e.OperationalError: in case of failure to exit the pipeline + mode. + """ + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQexitPipelineMode requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + if libpq.PQexitPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError(error_message(self)) + + def pipeline_sync(self) -> None: + """Mark a synchronization point in a pipeline. + + :raises ~e.OperationalError: if the connection is not in pipeline mode + or if sync failed. + """ + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQpipelineSync requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + rv = libpq.PQpipelineSync(self._pgconn_ptr) + if rv == 0: + raise e.OperationalError("connection not in pipeline mode") + if rv != 1: + raise e.OperationalError("failed to sync pipeline") + + def send_flush_request(self) -> None: + """Sends a request for the server to flush its output buffer. + + :raises ~e.OperationalError: if the flush request failed. + """ + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQsendFlushRequest requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + cdef int rv = libpq.PQsendFlushRequest(self._pgconn_ptr) + if rv == 0: + raise e.OperationalError(f"flush request failed: {error_message(self)}") + + +cdef int _ensure_pgconn(PGconn pgconn) except 0: + if pgconn._pgconn_ptr is not NULL: + return 1 + + raise e.OperationalError("the connection is closed") + + +cdef char *_call_bytes(PGconn pgconn, conn_bytes_f func) except NULL: + """ + Call one of the pgconn libpq functions returning a bytes pointer. + """ + if not _ensure_pgconn(pgconn): + return NULL + cdef char *rv = func(pgconn._pgconn_ptr) + assert rv is not NULL + return rv + + +cdef int _call_int(PGconn pgconn, conn_int_f func) except -2: + """ + Call one of the pgconn libpq functions returning an int. 
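+
+    Declared "except -2" because -1 is a legitimate return value for some of
+    the wrapped functions (e.g. PQsocket() on a closed connection), so -1
+    cannot double as the error marker.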
+ """ + if not _ensure_pgconn(pgconn): + return -2 + return func(pgconn._pgconn_ptr) + + +cdef void notice_receiver(void *arg, const libpq.PGresult *res_ptr) with gil: + cdef PGconn pgconn = <object>arg + if pgconn.notice_handler is None: + return + + cdef PGresult res = PGresult._from_ptr(<libpq.PGresult *>res_ptr) + try: + pgconn.notice_handler(res) + except Exception as e: + logger.exception("error in notice receiver: %s", e) + finally: + res._pgresult_ptr = NULL # avoid destroying the pgresult_ptr + + +cdef (Py_ssize_t, libpq.Oid *, char * const*, int *, int *) _query_params_args( + list param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]], + list param_formats: Optional[Sequence[int]], +) except *: + cdef int i + + # the PostgresQuery converts the param_types to tuple, so this operation + # is most often no-op + cdef tuple tparam_types + if param_types is not None and not isinstance(param_types, tuple): + tparam_types = tuple(param_types) + else: + tparam_types = param_types + + cdef Py_ssize_t nparams = len(param_values) if param_values else 0 + if tparam_types is not None and len(tparam_types) != nparams: + raise ValueError( + "got %d param_values but %d param_types" + % (nparams, len(tparam_types)) + ) + if param_formats is not None and len(param_formats) != nparams: + raise ValueError( + "got %d param_values but %d param_formats" + % (nparams, len(param_formats)) + ) + + cdef char **aparams = NULL + cdef int *alenghts = NULL + cdef char *ptr + cdef Py_ssize_t length + + if nparams: + aparams = <char **>PyMem_Malloc(nparams * sizeof(char *)) + alenghts = <int *>PyMem_Malloc(nparams * sizeof(int)) + for i in range(nparams): + obj = param_values[i] + if obj is None: + aparams[i] = NULL + alenghts[i] = 0 + else: + # TODO: it is a leak if this fails (but it should only fail + # on internal error, e.g. if obj is not a buffer) + _buffer_as_string_and_size(obj, &ptr, &length) + aparams[i] = ptr + alenghts[i] = <int>length + + cdef libpq.Oid *atypes = NULL + if tparam_types: + atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid)) + for i in range(nparams): + atypes[i] = tparam_types[i] + + cdef int *aformats = NULL + if param_formats is not None: + aformats = <int *>PyMem_Malloc(nparams * sizeof(int *)) + for i in range(nparams): + aformats[i] = param_formats[i] + + return (nparams, atypes, aparams, alenghts, aformats) + + +cdef void _clear_query_params( + libpq.Oid *ctypes, char *const *cvalues, int *clenghst, int *cformats +): + PyMem_Free(ctypes) + PyMem_Free(<char **>cvalues) + PyMem_Free(clenghst) + PyMem_Free(cformats) diff --git a/psycopg_c/psycopg_c/pq/pgresult.pyx b/psycopg_c/psycopg_c/pq/pgresult.pyx new file mode 100644 index 0000000..6df42e8 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/pgresult.pyx @@ -0,0 +1,157 @@ +""" +psycopg_c.pq.PGresult object implementation. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython +from cpython.mem cimport PyMem_Malloc, PyMem_Free + +from psycopg.pq.misc import PGresAttDesc +from psycopg.pq._enums import ExecStatus + + +@cython.freelist(8) +cdef class PGresult: + def __cinit__(self): + self._pgresult_ptr = NULL + + @staticmethod + cdef PGresult _from_ptr(libpq.PGresult *ptr): + cdef PGresult rv = PGresult.__new__(PGresult) + rv._pgresult_ptr = ptr + return rv + + def __dealloc__(self) -> None: + self.clear() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + status = ExecStatus(self.status) + return f"<{cls} [{status.name}] at 0x{id(self):x}>" + + def clear(self) -> None: + if self._pgresult_ptr is not NULL: + libpq.PQclear(self._pgresult_ptr) + self._pgresult_ptr = NULL + + @property + def pgresult_ptr(self) -> Optional[int]: + if self._pgresult_ptr: + return <long long><void *>self._pgresult_ptr + else: + return None + + @property + def status(self) -> int: + return libpq.PQresultStatus(self._pgresult_ptr) + + @property + def error_message(self) -> bytes: + return libpq.PQresultErrorMessage(self._pgresult_ptr) + + def error_field(self, int fieldcode) -> Optional[bytes]: + cdef char * rv = libpq.PQresultErrorField(self._pgresult_ptr, fieldcode) + if rv is not NULL: + return rv + else: + return None + + @property + def ntuples(self) -> int: + return libpq.PQntuples(self._pgresult_ptr) + + @property + def nfields(self) -> int: + return libpq.PQnfields(self._pgresult_ptr) + + def fname(self, int column_number) -> Optional[bytes]: + cdef char *rv = libpq.PQfname(self._pgresult_ptr, column_number) + if rv is not NULL: + return rv + else: + return None + + def ftable(self, int column_number) -> int: + return libpq.PQftable(self._pgresult_ptr, column_number) + + def ftablecol(self, int column_number) -> int: + return libpq.PQftablecol(self._pgresult_ptr, column_number) + + def fformat(self, int column_number) -> int: + return libpq.PQfformat(self._pgresult_ptr, column_number) + + def ftype(self, int column_number) -> int: + return libpq.PQftype(self._pgresult_ptr, column_number) + + def fmod(self, int column_number) -> int: + return libpq.PQfmod(self._pgresult_ptr, column_number) + + def fsize(self, int column_number) -> int: + return libpq.PQfsize(self._pgresult_ptr, column_number) + + @property + def binary_tuples(self) -> int: + return libpq.PQbinaryTuples(self._pgresult_ptr) + + def get_value(self, int row_number, int column_number) -> Optional[bytes]: + cdef int crow = row_number + cdef int ccol = column_number + cdef int length = libpq.PQgetlength(self._pgresult_ptr, crow, ccol) + cdef char *v + if length: + v = libpq.PQgetvalue(self._pgresult_ptr, crow, ccol) + # TODO: avoid copy + return v[:length] + else: + if libpq.PQgetisnull(self._pgresult_ptr, crow, ccol): + return None + else: + return b"" + + @property + def nparams(self) -> int: + return libpq.PQnparams(self._pgresult_ptr) + + def param_type(self, int param_number) -> int: + return libpq.PQparamtype(self._pgresult_ptr, param_number) + + @property + def command_status(self) -> Optional[bytes]: + cdef char *rv = libpq.PQcmdStatus(self._pgresult_ptr) + if rv is not NULL: + return rv + else: + return None + + @property + def command_tuples(self) -> Optional[int]: + cdef char *rv = libpq.PQcmdTuples(self._pgresult_ptr) + if rv is NULL: + return None + cdef bytes brv = rv + return int(brv) if brv else None + + @property + def oid_value(self) -> int: + return libpq.PQoidValue(self._pgresult_ptr) + + 
def set_attributes(self, descriptions: List[PGresAttDesc]): + cdef Py_ssize_t num = len(descriptions) + cdef libpq.PGresAttDesc *attrs = <libpq.PGresAttDesc *>PyMem_Malloc( + num * sizeof(libpq.PGresAttDesc)) + + for i in range(num): + descr = descriptions[i] + attrs[i].name = descr.name + attrs[i].tableid = descr.tableid + attrs[i].columnid = descr.columnid + attrs[i].format = descr.format + attrs[i].typid = descr.typid + attrs[i].typlen = descr.typlen + attrs[i].atttypmod = descr.atttypmod + + cdef int res = libpq.PQsetResultAttrs(self._pgresult_ptr, <int>num, attrs) + PyMem_Free(attrs) + if (res == 0): + raise e.OperationalError("PQsetResultAttrs failed") diff --git a/psycopg_c/psycopg_c/pq/pqbuffer.pyx b/psycopg_c/psycopg_c/pq/pqbuffer.pyx new file mode 100644 index 0000000..eb5d648 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/pqbuffer.pyx @@ -0,0 +1,111 @@ +""" +PQbuffer object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython +from cpython.bytes cimport PyBytes_AsStringAndSize +from cpython.buffer cimport PyObject_CheckBuffer, PyBUF_SIMPLE +from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release + + +@cython.freelist(32) +cdef class PQBuffer: + """ + Wrap a chunk of memory allocated by the libpq and expose it as memoryview. + """ + @staticmethod + cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t length): + cdef PQBuffer rv = PQBuffer.__new__(PQBuffer) + rv.buf = buf + rv.len = length + return rv + + def __cinit__(self): + self.buf = NULL + self.len = 0 + + def __dealloc__(self): + if self.buf: + libpq.PQfreemem(self.buf) + + def __repr__(self): + return ( + f"{self.__class__.__module__}.{self.__class__.__qualname__}" + f"({bytes(self)})" + ) + + def __getbuffer__(self, Py_buffer *buffer, int flags): + buffer.buf = self.buf + buffer.obj = self + buffer.len = self.len + buffer.itemsize = sizeof(unsigned char) + buffer.readonly = 1 + buffer.ndim = 1 + buffer.format = NULL # unsigned char + buffer.shape = &self.len + buffer.strides = NULL + buffer.suboffsets = NULL + buffer.internal = NULL + + def __releasebuffer__(self, Py_buffer *buffer): + pass + + +@cython.freelist(32) +cdef class ViewBuffer: + """ + Wrap a chunk of memory owned by a different object. 
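+
+    Unlike PQBuffer, deallocating the view doesn't free the memory: the
+    object keeps a reference to its owner (e.g. the PGresult the data comes
+    from), so the underlying buffer stays valid as long as the view is alive.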
+ """ + @staticmethod + cdef ViewBuffer _from_buffer( + object obj, unsigned char *buf, Py_ssize_t length + ): + cdef ViewBuffer rv = ViewBuffer.__new__(ViewBuffer) + rv.obj = obj + rv.buf = buf + rv.len = length + return rv + + def __cinit__(self): + self.buf = NULL + self.len = 0 + + def __repr__(self): + return ( + f"{self.__class__.__module__}.{self.__class__.__qualname__}" + f"({bytes(self)})" + ) + + def __getbuffer__(self, Py_buffer *buffer, int flags): + buffer.buf = self.buf + buffer.obj = self + buffer.len = self.len + buffer.itemsize = sizeof(unsigned char) + buffer.readonly = 1 + buffer.ndim = 1 + buffer.format = NULL # unsigned char + buffer.shape = &self.len + buffer.strides = NULL + buffer.suboffsets = NULL + buffer.internal = NULL + + def __releasebuffer__(self, Py_buffer *buffer): + pass + + +cdef int _buffer_as_string_and_size( + data: "Buffer", char **ptr, Py_ssize_t *length +) except -1: + cdef Py_buffer buf + + if isinstance(data, bytes): + PyBytes_AsStringAndSize(data, ptr, length) + elif PyObject_CheckBuffer(data): + PyObject_GetBuffer(data, &buf, PyBUF_SIMPLE) + ptr[0] = <char *>buf.buf + length[0] = buf.len + PyBuffer_Release(&buf) + else: + raise TypeError(f"bytes or buffer expected, got {type(data)}") diff --git a/psycopg_c/psycopg_c/py.typed b/psycopg_c/psycopg_c/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/psycopg_c/psycopg_c/py.typed diff --git a/psycopg_c/psycopg_c/types/array.pyx b/psycopg_c/psycopg_c/types/array.pyx new file mode 100644 index 0000000..9abaef9 --- /dev/null +++ b/psycopg_c/psycopg_c/types/array.pyx @@ -0,0 +1,276 @@ +""" +C optimized functions to manipulate arrays +""" + +# Copyright (C) 2022 The Psycopg Team + +import cython + +from libc.stdint cimport int32_t, uint32_t +from libc.string cimport memset, strchr +from cpython.mem cimport PyMem_Realloc, PyMem_Free +from cpython.ref cimport Py_INCREF +from cpython.list cimport PyList_New,PyList_Append, PyList_GetSlice +from cpython.list cimport PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE +from cpython.object cimport PyObject + +from psycopg_c.pq cimport _buffer_as_string_and_size +from psycopg_c.pq.libpq cimport Oid +from psycopg_c._psycopg cimport endian + +from psycopg import errors as e + +cdef extern from *: + """ +/* Defined in PostgreSQL in src/include/utils/array.h */ +#define MAXDIM 6 + """ + const int MAXDIM + + +cdef class ArrayLoader(_CRecursiveLoader): + + format = PQ_TEXT + base_oid = 0 + delimiter = b"," + + cdef PyObject *row_loader + cdef char cdelim + + # A memory area which used to unescape elements. + # Keep it here to avoid a malloc per element and to set up exceptions + # to make sure to free it on error. 
+ cdef char *scratch + cdef size_t sclen + + cdef object cload(self, const char *data, size_t length): + if self.cdelim == b"\x00": + self.row_loader = self._tx._c_get_loader( + <PyObject *>self.base_oid, <PyObject *>PQ_TEXT) + self.cdelim = self.delimiter[0] + + return _array_load_text( + data, length, self.row_loader, self.cdelim, + &(self.scratch), &(self.sclen)) + + def __dealloc__(self): + PyMem_Free(self.scratch) + + +@cython.final +cdef class ArrayBinaryLoader(_CRecursiveLoader): + + format = PQ_BINARY + + cdef PyObject *row_loader + + cdef object cload(self, const char *data, size_t length): + rv = _array_load_binary(data, length, self._tx, &(self.row_loader)) + return rv + + +cdef object _array_load_text( + const char *buf, size_t length, PyObject *row_loader, char cdelim, + char **scratch, size_t *sclen +): + if length == 0: + raise e.DataError("malformed array: empty data") + + cdef const char *end = buf + length + + # Remove the dimensions information prefix (``[...]=``) + if buf[0] == b"[": + buf = strchr(buf + 1, b'=') + if buf == NULL: + raise e.DataError("malformed array: no '=' after dimension information") + buf += 1 + + # TODO: further optimization: pre-scan the array to find the array + # dimensions, so that we can preallocate the list sized instead of calling + # append, which is the dominating operation + + cdef list stack = [] + cdef list a = [] + rv = a + cdef PyObject *tmp + + cdef CLoader cloader = None + cdef object pyload = None + if (<RowLoader>row_loader).cloader is not None: + cloader = (<RowLoader>row_loader).cloader + else: + pyload = (<RowLoader>row_loader).loadfunc + + while buf < end: + if buf[0] == b'{': + if stack: + tmp = PyList_GET_ITEM(stack, PyList_GET_SIZE(stack) - 1) + PyList_Append(<object>tmp, a) + PyList_Append(stack, a) + a = [] + buf += 1 + + elif buf[0] == b'}': + if not stack: + raise e.DataError("malformed array: unexpected '}'") + rv = stack.pop() + buf += 1 + + elif buf[0] == cdelim: + buf += 1 + + else: + v = _parse_token( + &buf, end, cdelim, scratch, sclen, cloader, pyload) + if not stack: + raise e.DataError("malformed array: missing initial '{'") + tmp = PyList_GET_ITEM(stack, PyList_GET_SIZE(stack) - 1) + PyList_Append(<object>tmp, v) + + return rv + + +cdef object _parse_token( + const char **bufptr, const char *bufend, char cdelim, + char **scratch, size_t *sclen, CLoader cloader, object load +): + cdef const char *start = bufptr[0] + cdef int has_quotes = start[0] == b'"' + cdef int quoted = has_quotes + cdef int num_escapes = 0 + cdef int escaped = 0 + + if has_quotes: + start += 1 + cdef const char *end = start + + while end < bufend: + if (end[0] == cdelim or end[0] == b'}') and not quoted: + break + elif end[0] == b'\\' and not escaped: + num_escapes += 1 + escaped = 1 + end += 1 + continue + elif end[0] == b'"' and not escaped: + quoted = 0 + escaped = 0 + end += 1 + else: + raise e.DataError("malformed array: hit the end of the buffer") + + # Return the new position for the buffer + bufptr[0] = end + if has_quotes: + end -= 1 + + cdef int length = (end - start) + if length == 4 and not has_quotes \ + and start[0] == b'N' and start[1] == b'U' \ + and start[2] == b'L' and start[3] == b'L': + return None + + cdef const char *src + cdef char *tgt + cdef size_t unesclen + + if not num_escapes: + if cloader is not None: + return cloader.cload(start, length) + else: + b = start[:length] + return load(b) + + else: + unesclen = length - num_escapes + 1 + if unesclen > sclen[0]: + scratch[0] = <char *>PyMem_Realloc(scratch[0], 
unesclen) + sclen[0] = unesclen + + src = start + tgt = scratch[0] + while src < end: + if src[0] == b'\\': + src += 1 + tgt[0] = src[0] + src += 1 + tgt += 1 + + tgt[0] = b'\x00' + + if cloader is not None: + return cloader.cload(scratch[0], length - num_escapes) + else: + b = scratch[0][:length - num_escapes] + return load(b) + + +@cython.cdivision(True) +cdef object _array_load_binary( + const char *buf, size_t length, Transformer tx, PyObject **row_loader_ptr +): + # head is ndims, hasnull, elem oid + cdef uint32_t *buf32 = <uint32_t *>buf + cdef int ndims = endian.be32toh(buf32[0]) + + if ndims <= 0: + return [] + elif ndims > MAXDIM: + raise e.DataError( + r"unexpected number of dimensions %s exceeding the maximum allowed %s" + % (ndims, MAXDIM) + ) + + cdef object oid + if row_loader_ptr[0] == NULL: + oid = <Oid>endian.be32toh(buf32[2]) + row_loader_ptr[0] = tx._c_get_loader(<PyObject *>oid, <PyObject *>PQ_BINARY) + + cdef Py_ssize_t[MAXDIM] dims + cdef int i + for i in range(ndims): + # Every dimension is dim, lower bound + dims[i] = endian.be32toh(buf32[3 + 2 * i]) + + buf += (3 + 2 * ndims) * sizeof(uint32_t) + out = _array_load_binary_rec(ndims, dims, &buf, row_loader_ptr[0]) + return out + + +cdef object _array_load_binary_rec( + Py_ssize_t ndims, Py_ssize_t *dims, const char **bufptr, PyObject *row_loader +): + cdef const char *buf + cdef int i + cdef int32_t size + cdef object val + + cdef Py_ssize_t nelems = dims[0] + cdef list out = PyList_New(nelems) + + if ndims == 1: + buf = bufptr[0] + for i in range(nelems): + size = <int32_t>endian.be32toh((<uint32_t *>buf)[0]) + buf += sizeof(uint32_t) + if size == -1: + val = None + else: + if (<RowLoader>row_loader).cloader is not None: + val = (<RowLoader>row_loader).cloader.cload(buf, size) + else: + val = (<RowLoader>row_loader).loadfunc(buf[:size]) + buf += size + + Py_INCREF(val) + PyList_SET_ITEM(out, i, val) + + bufptr[0] = buf + + else: + for i in range(nelems): + val = _array_load_binary_rec(ndims - 1, dims + 1, bufptr, row_loader) + Py_INCREF(val) + PyList_SET_ITEM(out, i, val) + + return out diff --git a/psycopg_c/psycopg_c/types/bool.pyx b/psycopg_c/psycopg_c/types/bool.pyx new file mode 100644 index 0000000..86cf88e --- /dev/null +++ b/psycopg_c/psycopg_c/types/bool.pyx @@ -0,0 +1,78 @@ +""" +Cython adapters for boolean. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython + + +@cython.final +cdef class BoolDumper(CDumper): + + format = PQ_TEXT + oid = oids.BOOL_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef char *buf = CDumper.ensure_size(rv, offset, 1) + + # Fast paths, just a pointer comparison + if obj is True: + buf[0] = b"t" + elif obj is False: + buf[0] = b"f" + elif obj: + buf[0] = b"t" + else: + buf[0] = b"f" + + return 1 + + def quote(self, obj: bool) -> bytes: + if obj is True: + return b"true" + elif obj is False: + return b"false" + else: + return b"true" if obj else b"false" + + +@cython.final +cdef class BoolBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.BOOL_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef char *buf = CDumper.ensure_size(rv, offset, 1) + + # Fast paths, just a pointer comparison + if obj is True: + buf[0] = b"\x01" + elif obj is False: + buf[0] = b"\x00" + elif obj: + buf[0] = b"\x01" + else: + buf[0] = b"\x00" + + return 1 + + +@cython.final +cdef class BoolLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + # this creates better C than `return data[0] == b't'` + return True if data[0] == b't' else False + + +@cython.final +cdef class BoolBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return True if data[0] else False diff --git a/psycopg_c/psycopg_c/types/datetime.pyx b/psycopg_c/psycopg_c/types/datetime.pyx new file mode 100644 index 0000000..51e7dcf --- /dev/null +++ b/psycopg_c/psycopg_c/types/datetime.pyx @@ -0,0 +1,1136 @@ +""" +Cython adapters for date/time types. +""" + +# Copyright (C) 2021 The Psycopg Team + +from libc.string cimport memset, strchr +from cpython cimport datetime as cdt +from cpython.dict cimport PyDict_GetItem +from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs + +cdef extern from "Python.h": + const char *PyUnicode_AsUTF8AndSize(unicode obj, Py_ssize_t *size) except NULL + object PyTimeZone_FromOffset(object offset) + +cdef extern from *: + """ +/* Multipliers from fraction of seconds to microseconds */ +static int _uspad[] = {0, 100000, 10000, 1000, 100, 10, 1}; + """ + cdef int *_uspad + +from datetime import date, time, timedelta, datetime, timezone + +from psycopg_c._psycopg cimport endian + +from psycopg import errors as e +from psycopg._compat import ZoneInfo + + +# Initialise the datetime C API +cdt.import_datetime() + +cdef enum: + ORDER_YMD = 0 + ORDER_DMY = 1 + ORDER_MDY = 2 + ORDER_PGDM = 3 + ORDER_PGMD = 4 + +cdef enum: + INTERVALSTYLE_OTHERS = 0 + INTERVALSTYLE_SQL_STANDARD = 1 + INTERVALSTYLE_POSTGRES = 2 + +cdef enum: + PG_DATE_EPOCH_DAYS = 730120 # date(2000, 1, 1).toordinal() + PY_DATE_MIN_DAYS = 1 # date.min.toordinal() + +cdef object date_toordinal = date.toordinal +cdef object date_fromordinal = date.fromordinal +cdef object datetime_astimezone = datetime.astimezone +cdef object time_utcoffset = time.utcoffset +cdef object timedelta_total_seconds = timedelta.total_seconds +cdef object timezone_utc = timezone.utc +cdef object pg_datetime_epoch = datetime(2000, 1, 1) +cdef object pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=timezone.utc) + +cdef object _month_abbr = { + n: i + for i, n in enumerate( + b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1 + ) +} + + +@cython.final +cdef class DateDumper(CDumper): + + format = PQ_TEXT + oid = oids.DATE_OID + + cdef Py_ssize_t 
cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size; + cdef const char *src + + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + cdef str s = str(obj) + src = PyUnicode_AsUTF8AndSize(s, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class DateBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.DATE_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int32_t days = PyObject_CallFunctionObjArgs( + date_toordinal, <PyObject *>obj, NULL) + days -= PG_DATE_EPOCH_DAYS + cdef int32_t *buf = <int32_t *>CDumper.ensure_size( + rv, offset, sizeof(int32_t)) + buf[0] = endian.htobe32(days) + return sizeof(int32_t) + + +cdef class _BaseTimeDumper(CDumper): + + cpdef get_key(self, obj, format): + # Use (cls,) to report the need to upgrade to a dumper for timetz (the + # Frankenstein of the data types). + if not obj.tzinfo: + return self.cls + else: + return (self.cls,) + + cpdef upgrade(self, obj: time, format): + raise NotImplementedError + + +cdef class _BaseTimeTextDumper(_BaseTimeDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size; + cdef const char *src + + cdef str s = str(obj) + src = PyUnicode_AsUTF8AndSize(s, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class TimeDumper(_BaseTimeTextDumper): + + oid = oids.TIME_OID + + cpdef upgrade(self, obj, format): + if not obj.tzinfo: + return self + else: + return TimeTzDumper(self.cls) + + +@cython.final +cdef class TimeTzDumper(_BaseTimeTextDumper): + + oid = oids.TIMETZ_OID + + +@cython.final +cdef class TimeBinaryDumper(_BaseTimeDumper): + + format = PQ_BINARY + oid = oids.TIME_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int64_t micros = cdt.time_microsecond(obj) + 1000000 * ( + cdt.time_second(obj) + + 60 * (cdt.time_minute(obj) + 60 * <int64_t>cdt.time_hour(obj)) + ) + + cdef int64_t *buf = <int64_t *>CDumper.ensure_size( + rv, offset, sizeof(int64_t)) + buf[0] = endian.htobe64(micros) + return sizeof(int64_t) + + cpdef upgrade(self, obj, format): + if not obj.tzinfo: + return self + else: + return TimeTzBinaryDumper(self.cls) + + +@cython.final +cdef class TimeTzBinaryDumper(_BaseTimeDumper): + + format = PQ_BINARY + oid = oids.TIMETZ_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int64_t micros = cdt.time_microsecond(obj) + 1_000_000 * ( + cdt.time_second(obj) + + 60 * (cdt.time_minute(obj) + 60 * <int64_t>cdt.time_hour(obj)) + ) + + off = PyObject_CallFunctionObjArgs(time_utcoffset, <PyObject *>obj, NULL) + cdef int32_t offsec = int(PyObject_CallFunctionObjArgs( + timedelta_total_seconds, <PyObject *>off, NULL)) + + cdef char *buf = CDumper.ensure_size( + rv, offset, sizeof(int64_t) + sizeof(int32_t)) + (<int64_t *>buf)[0] = endian.htobe64(micros) + (<int32_t *>(buf + sizeof(int64_t)))[0] = endian.htobe32(-offsec) + + return sizeof(int64_t) + sizeof(int32_t) + + +cdef class _BaseDatetimeDumper(CDumper): + + cpdef get_key(self, obj, format): + # Use (cls,) to report the need to upgrade (downgrade, actually) to a + # dumper for naive timestamp. 
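+        # Returning a key other than self.cls makes the adaptation layer call
+        # upgrade(), which picks the matching *NoTz* dumper defined below.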
+ if obj.tzinfo: + return self.cls + else: + return (self.cls,) + + cpdef upgrade(self, obj: time, format): + raise NotImplementedError + + +cdef class _BaseDatetimeTextDumper(_BaseDatetimeDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size; + cdef const char *src + + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + cdef str s = str(obj) + src = PyUnicode_AsUTF8AndSize(s, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class DatetimeDumper(_BaseDatetimeTextDumper): + + oid = oids.TIMESTAMPTZ_OID + + cpdef upgrade(self, obj, format): + if obj.tzinfo: + return self + else: + return DatetimeNoTzDumper(self.cls) + + +@cython.final +cdef class DatetimeNoTzDumper(_BaseDatetimeTextDumper): + + oid = oids.TIMESTAMP_OID + + +@cython.final +cdef class DatetimeBinaryDumper(_BaseDatetimeDumper): + + format = PQ_BINARY + oid = oids.TIMESTAMPTZ_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + delta = obj - pg_datetimetz_epoch + + cdef int64_t micros = cdt.timedelta_microseconds(delta) + 1_000_000 * ( + 86_400 * <int64_t>cdt.timedelta_days(delta) + + <int64_t>cdt.timedelta_seconds(delta)) + + cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t)) + (<int64_t *>buf)[0] = endian.htobe64(micros) + return sizeof(int64_t) + + cpdef upgrade(self, obj, format): + if obj.tzinfo: + return self + else: + return DatetimeNoTzBinaryDumper(self.cls) + + +@cython.final +cdef class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper): + + format = PQ_BINARY + oid = oids.TIMESTAMP_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + delta = obj - pg_datetime_epoch + + cdef int64_t micros = cdt.timedelta_microseconds(delta) + 1_000_000 * ( + 86_400 * <int64_t>cdt.timedelta_days(delta) + + <int64_t>cdt.timedelta_seconds(delta)) + + cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t)) + (<int64_t *>buf)[0] = endian.htobe64(micros) + return sizeof(int64_t) + + +@cython.final +cdef class TimedeltaDumper(CDumper): + + format = PQ_TEXT + oid = oids.INTERVAL_OID + cdef int _style + + def __cinit__(self, cls, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_intervalstyle(self._pgconn) + if ds[0] == b's': # sql_standard + self._style = INTERVALSTYLE_SQL_STANDARD + else: # iso_8601, postgres, postgres_verbose + self._style = INTERVALSTYLE_OTHERS + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size; + cdef const char *src + + cdef str s + if self._style == INTERVALSTYLE_OTHERS: + # The comma is parsed ok by PostgreSQL but it's not documented + # and it seems brittle to rely on it. CRDB doesn't consume it well. 
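+            # e.g. str(timedelta(days=1, seconds=5)) is "1 day, 0:00:05";
+            # without the comma it becomes "1 day 0:00:05", which is accepted.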
+ s = str(obj).replace(",", "") + else: + # sql_standard format needs explicit signs + # otherwise -1 day 1 sec will mean -1 sec + s = "%+d day %+d second %+d microsecond" % ( + obj.days, obj.seconds, obj.microseconds) + + src = PyUnicode_AsUTF8AndSize(s, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class TimedeltaBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.INTERVAL_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int64_t micros = ( + 1_000_000 * <int64_t>cdt.timedelta_seconds(obj) + + cdt.timedelta_microseconds(obj)) + cdef int32_t days = cdt.timedelta_days(obj) + + cdef char *buf = CDumper.ensure_size( + rv, offset, sizeof(int64_t) + sizeof(int32_t) + sizeof(int32_t)) + (<int64_t *>buf)[0] = endian.htobe64(micros) + (<int32_t *>(buf + sizeof(int64_t)))[0] = endian.htobe32(days) + (<int32_t *>(buf + sizeof(int64_t) + sizeof(int32_t)))[0] = 0 + + return sizeof(int64_t) + sizeof(int32_t) + sizeof(int32_t) + + +@cython.final +cdef class DateLoader(CLoader): + + format = PQ_TEXT + cdef int _order + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_datestyle(self._pgconn) + if ds[0] == b'I': # ISO + self._order = ORDER_YMD + elif ds[0] == b'G': # German + self._order = ORDER_DMY + elif ds[0] == b'S': # SQL, DMY / MDY + self._order = ORDER_DMY if ds[5] == b'D' else ORDER_MDY + elif ds[0] == b'P': # Postgres, DMY / MDY + self._order = ORDER_DMY if ds[10] == b'D' else ORDER_MDY + else: + raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + cdef object _error_date(self, const char *data, str msg): + s = bytes(data).decode("utf8", "replace") + if s == "infinity" or len(s.split()[0]) > 10: + raise e.DataError(f"date too large (after year 10K): {s!r}") from None + elif s == "-infinity" or "BC" in s: + raise e.DataError(f"date too small (before year 1): {s!r}") from None + else: + raise e.DataError(f"can't parse date {s!r}: {msg}") from None + + cdef object cload(self, const char *data, size_t length): + if length != 10: + self._error_date(data, "unexpected length") + + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + + cdef const char *ptr + cdef const char *end = data + length + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse date {s!r}") + + try: + if self._order == ORDER_YMD: + return cdt.date_new(vals[0], vals[1], vals[2]) + elif self._order == ORDER_DMY: + return cdt.date_new(vals[2], vals[1], vals[0]) + else: + return cdt.date_new(vals[2], vals[0], vals[1]) + except ValueError as ex: + self._error_date(data, str(ex)) + + +@cython.final +cdef class DateBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int days = endian.be32toh((<uint32_t *>data)[0]) + cdef object pydays = days + PG_DATE_EPOCH_DAYS + try: + return PyObject_CallFunctionObjArgs( + date_fromordinal, <PyObject *>pydays, NULL) + except ValueError: + if days < PY_DATE_MIN_DAYS: + raise e.DataError("date too small (before year 1)") from None + else: + raise e.DataError("date too large (after year 10K)") from None + + +@cython.final +cdef class TimeLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + cdef const char *ptr + cdef const char *end = data + 
length + + # Parse the first 3 groups of digits + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse time {s!r}") + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + try: + return cdt.time_new(vals[0], vals[1], vals[2], us, None) + except ValueError as ex: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse time {s!r}: {ex}") from None + + +@cython.final +cdef class TimeBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int h, m, s, us + + with cython.cdivision(True): + us = val % 1_000_000 + val //= 1_000_000 + + s = val % 60 + val //= 60 + + m = val % 60 + h = <int>(val // 60) + + try: + return cdt.time_new(h, m, s, us, None) + except ValueError: + raise e.DataError( + f"time not supported by Python: hour={h}" + ) from None + + +@cython.final +cdef class TimetzLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + cdef const char *ptr + cdef const char *end = data + length + + # Parse the first 3 groups of digits (time) + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse timetz {s!r}") + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + # Parse the timezone + cdef int offsecs = _parse_timezone_to_seconds(&ptr, end) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse timetz {s!r}") + + tz = _timezone_from_seconds(offsecs) + try: + return cdt.time_new(vals[0], vals[1], vals[2], us, tz) + except ValueError as ex: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse timetz {s!r}: {ex}") from None + + +@cython.final +cdef class TimetzBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int32_t off = endian.be32toh((<uint32_t *>(data + sizeof(int64_t)))[0]) + cdef int h, m, s, us + + with cython.cdivision(True): + us = val % 1_000_000 + val //= 1_000_000 + + s = val % 60 + val //= 60 + + m = val % 60 + h = <int>(val // 60) + + tz = _timezone_from_seconds(-off) + try: + return cdt.time_new(h, m, s, us, tz) + except ValueError: + raise e.DataError( + f"time not supported by Python: hour={h}" + ) from None + + +@cython.final +cdef class TimestampLoader(CLoader): + + format = PQ_TEXT + cdef int _order + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_datestyle(self._pgconn) + if ds[0] == b'I': # ISO + self._order = ORDER_YMD + elif ds[0] == b'G': # German + self._order = ORDER_DMY + elif ds[0] == b'S': # SQL, DMY / MDY + self._order = ORDER_DMY if ds[5] == b'D' else ORDER_MDY + elif ds[0] == b'P': # Postgres, DMY / MDY + self._order = ORDER_PGDM if ds[10] == b'D' else ORDER_PGMD + else: + raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + cdef object cload(self, const char *data, size_t length): + cdef const char *end = data + length + if end[-1] == b'C': # ends with BC + raise _get_timestamp_load_error(self._pgconn, data) from None + + if self._order == ORDER_PGDM or 
self._order == ORDER_PGMD: + return self._cload_pg(data, end) + + cdef int vals[6] + memset(vals, 0, sizeof(vals)) + cdef const char *ptr + + # Parse the first 6 groups of digits (date and time) + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + # Resolve the YMD order + cdef int y, m, d + if self._order == ORDER_YMD: + y, m, d = vals[0], vals[1], vals[2] + elif self._order == ORDER_DMY: + d, m, y = vals[0], vals[1], vals[2] + else: # self._order == ORDER_MDY + m, d, y = vals[0], vals[1], vals[2] + + try: + return cdt.datetime_new( + y, m, d, vals[3], vals[4], vals[5], us, None) + except ValueError as ex: + raise _get_timestamp_load_error(self._pgconn, data, ex) from None + + cdef object _cload_pg(self, const char *data, const char *end): + cdef int vals[4] + memset(vals, 0, sizeof(vals)) + cdef const char *ptr + + # Find Wed Jun 02 or Wed 02 Jun + cdef char *seps[3] + seps[0] = strchr(data, b' ') + seps[1] = strchr(seps[0] + 1, b' ') if seps[0] != NULL else NULL + seps[2] = strchr(seps[1] + 1, b' ') if seps[1] != NULL else NULL + if seps[2] == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Parse the following 3 groups of digits (time) + ptr = _parse_date_values(seps[2] + 1, end, vals, 3) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + # Parse the year + ptr = _parse_date_values(ptr + 1, end, vals + 3, 1) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Resolve the MD order + cdef int m, d + try: + if self._order == ORDER_PGDM: + d = int(seps[0][1 : seps[1] - seps[0]]) + m = _month_abbr[seps[1][1 : seps[2] - seps[1]]] + else: # self._order == ORDER_PGMD + m = _month_abbr[seps[0][1 : seps[1] - seps[0]]] + d = int(seps[1][1 : seps[2] - seps[1]]) + except (KeyError, ValueError) as ex: + raise _get_timestamp_load_error(self._pgconn, data, ex) from None + + try: + return cdt.datetime_new( + vals[3], m, d, vals[0], vals[1], vals[2], us, None) + except ValueError as ex: + raise _get_timestamp_load_error(self._pgconn, data, ex) from None + + +@cython.final +cdef class TimestampBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int64_t micros, secs, days + + # Work only with positive values as the cdivision behaves differently + # with negative values, and cdivision=False adds overhead. 
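+        # (cdivision truncates towards zero while Python floors: -1 // 86_400
+        # is -1 in Python but 0 in C, so negative values are handled apart.)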
+ cdef int64_t aval = val if val >= 0 else -val + + # Group the micros in biggers stuff or timedelta_new might overflow + with cython.cdivision(True): + secs = aval // 1_000_000 + micros = aval % 1_000_000 + + days = secs // 86_400 + secs %= 86_400 + + try: + delta = cdt.timedelta_new(<int>days, <int>secs, <int>micros) + if val > 0: + return pg_datetime_epoch + delta + else: + return pg_datetime_epoch - delta + + except OverflowError: + if val <= 0: + raise e.DataError("timestamp too small (before year 1)") from None + else: + raise e.DataError("timestamp too large (after year 10K)") from None + + +cdef class _BaseTimestamptzLoader(CLoader): + cdef object _time_zone + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + self._time_zone = _timezone_from_connection(self._pgconn) + + +@cython.final +cdef class TimestamptzLoader(_BaseTimestamptzLoader): + + format = PQ_TEXT + cdef int _order + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_datestyle(self._pgconn) + if ds[0] == b'I': # ISO + self._order = ORDER_YMD + else: # Not true, but any non-YMD will do. + self._order = ORDER_DMY + + cdef object cload(self, const char *data, size_t length): + if self._order != ORDER_YMD: + return self._cload_notimpl(data, length) + + cdef const char *end = data + length + if end[-1] == b'C': # ends with BC + raise _get_timestamp_load_error(self._pgconn, data) from None + + cdef int vals[6] + memset(vals, 0, sizeof(vals)) + + # Parse the first 6 groups of digits (date and time) + cdef const char *ptr + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + # Resolve the YMD order + cdef int y, m, d + if self._order == ORDER_YMD: + y, m, d = vals[0], vals[1], vals[2] + elif self._order == ORDER_DMY: + d, m, y = vals[0], vals[1], vals[2] + else: # self._order == ORDER_MDY + m, d, y = vals[0], vals[1], vals[2] + + # Parse the timezone + cdef int offsecs = _parse_timezone_to_seconds(&ptr, end) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + tzoff = cdt.timedelta_new(0, offsecs, 0) + + # The return value is a datetime with the timezone of the connection + # (in order to be consistent with the binary loader, which is the only + # thing it can return). So create a temporary datetime object, in utc, + # shift it by the offset parsed from the timestamp, and then move it to + # the connection timezone. + dt = None + try: + dt = cdt.datetime_new( + y, m, d, vals[3], vals[4], vals[5], us, timezone_utc) + dt -= tzoff + return PyObject_CallFunctionObjArgs(datetime_astimezone, + <PyObject *>dt, <PyObject *>self._time_zone, NULL) + except OverflowError as ex: + # If we have created the temporary 'dt' it means that we have a + # datetime close to max, the shift pushed it past max, overflowing. + # In this case return the datetime in a fixed offset timezone. 
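+            # (timezone(tzoff) is the fixed offset parsed from the text, e.g.
+            # timezone(timedelta(hours=-5)) for a value ending in "-05".)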
+ if dt is not None: + return dt.replace(tzinfo=timezone(tzoff)) + else: + ex1 = ex + except ValueError as ex: + ex1 = ex + + raise _get_timestamp_load_error(self._pgconn, data, ex1) from None + + cdef object _cload_notimpl(self, const char *data, size_t length): + s = bytes(data)[:length].decode("utf8", "replace") + ds = _get_datestyle(self._pgconn).decode() + raise NotImplementedError( + f"can't parse timestamptz with DateStyle {ds!r}: {s!r}" + ) + + +@cython.final +cdef class TimestamptzBinaryLoader(_BaseTimestamptzLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int64_t micros, secs, days + + # Work only with positive values as the cdivision behaves differently + # with negative values, and cdivision=False adds overhead. + cdef int64_t aval = val if val >= 0 else -val + + # Group the micros in biggers stuff or timedelta_new might overflow + with cython.cdivision(True): + secs = aval // 1_000_000 + micros = aval % 1_000_000 + + days = secs // 86_400 + secs %= 86_400 + + try: + delta = cdt.timedelta_new(<int>days, <int>secs, <int>micros) + if val > 0: + dt = pg_datetimetz_epoch + delta + else: + dt = pg_datetimetz_epoch - delta + return PyObject_CallFunctionObjArgs(datetime_astimezone, + <PyObject *>dt, <PyObject *>self._time_zone, NULL) + + except OverflowError: + # If we were asked about a timestamp which would overflow in UTC, + # but not in the desired timezone (e.g. datetime.max at Chicago + # timezone) we can still save the day by shifting the value by the + # timezone offset and then replacing the timezone. + if self._time_zone is not None: + utcoff = self._time_zone.utcoffset( + datetime.min if val < 0 else datetime.max + ) + if utcoff: + usoff = 1_000_000 * int(utcoff.total_seconds()) + try: + ts = pg_datetime_epoch + timedelta( + microseconds=val + usoff + ) + except OverflowError: + pass # will raise downstream + else: + return ts.replace(tzinfo=self._time_zone) + + if val <= 0: + raise e.DataError( + "timestamp too small (before year 1)" + ) from None + else: + raise e.DataError( + "timestamp too large (after year 10K)" + ) from None + + +@cython.final +cdef class IntervalLoader(CLoader): + + format = PQ_TEXT + cdef int _style + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_intervalstyle(self._pgconn) + if ds[0] == b'p' and ds[8] == 0: # postgres + self._style = INTERVALSTYLE_POSTGRES + else: # iso_8601, sql_standard, postgres_verbose + self._style = INTERVALSTYLE_OTHERS + + cdef object cload(self, const char *data, size_t length): + if self._style == INTERVALSTYLE_OTHERS: + return self._cload_notimpl(data, length) + + cdef int days = 0, secs = 0, us = 0 + cdef char sign + cdef int val + cdef const char *ptr = data + cdef const char *sep + cdef const char *end = ptr + length + + # If there are spaces, there is a [+|-]n [days|months|years] + while True: + if ptr[0] == b'-' or ptr[0] == b'+': + sign = ptr[0] + ptr += 1 + else: + sign = 0 + + sep = strchr(ptr, b' ') + if sep == NULL or sep > end: + break + + val = 0 + ptr = _parse_date_values(ptr, end, &val, 1) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse interval {s!r}") + + if sign == b'-': + val = -val + + if ptr[1] == b'y': + days = 365 * val + elif ptr[1] == b'm': + days = 30 * val + elif ptr[1] == b'd': + days = val + else: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse 
interval {s!r}") + + # Skip the date part word. + ptr = strchr(ptr + 1, b' ') + if ptr != NULL and ptr < end: + ptr += 1 + else: + break + + # Parse the time part. An eventual sign was already consumed in the loop + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + if ptr != NULL: + ptr = _parse_date_values(ptr, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse interval {s!r}") + + secs = vals[2] + 60 * (vals[1] + 60 * vals[0]) + + if ptr[0] == b'.': + ptr = _parse_micros(ptr + 1, &us) + + if sign == b'-': + secs = -secs + us = -us + + try: + return cdt.timedelta_new(days, secs, us) + except OverflowError as ex: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse interval {s!r}: {ex}") from None + + cdef object _cload_notimpl(self, const char *data, size_t length): + s = bytes(data).decode("utf8", "replace") + style = _get_intervalstyle(self._pgconn).decode() + raise NotImplementedError( + f"can't parse interval with IntervalStyle {style!r}: {s!r}" + ) + + +@cython.final +cdef class IntervalBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int32_t days = endian.be32toh( + (<uint32_t *>(data + sizeof(int64_t)))[0]) + cdef int32_t months = endian.be32toh( + (<uint32_t *>(data + sizeof(int64_t) + sizeof(int32_t)))[0]) + + cdef int years + with cython.cdivision(True): + if months > 0: + years = months // 12 + months %= 12 + days += 30 * months + 365 * years + elif months < 0: + months = -months + years = months // 12 + months %= 12 + days -= 30 * months + 365 * years + + # Work only with positive values as the cdivision behaves differently + # with negative values, and cdivision=False adds overhead. + cdef int64_t aval = val if val >= 0 else -val + cdef int us, ussecs, usdays + + # Group the micros in biggers stuff or timedelta_new might overflow + with cython.cdivision(True): + ussecs = <int>(aval // 1_000_000) + us = aval % 1_000_000 + + usdays = ussecs // 86_400 + ussecs %= 86_400 + + if val < 0: + ussecs = -ussecs + usdays = -usdays + us = -us + + try: + return cdt.timedelta_new(days + usdays, ussecs, us) + except OverflowError as ex: + raise e.DataError(f"can't parse interval: {ex}") + + +cdef const char *_parse_date_values( + const char *ptr, const char *end, int *vals, int nvals +): + """ + Parse *nvals* numeric values separated by non-numeric chars. + + Write the result in the *vals* array (assumed zeroed) starting from *start*. + + Return the pointer at the separator after the final digit. + """ + cdef int ival = 0 + while ptr < end: + if b'0' <= ptr[0] <= b'9': + vals[ival] = vals[ival] * 10 + (ptr[0] - <char>b'0') + else: + ival += 1 + if ival >= nvals: + break + + ptr += 1 + + return ptr + + +cdef const char *_parse_micros(const char *start, int *us): + """ + Parse microseconds from a string. + + Micros are assumed up to 6 digit chars separated by a non-digit. + + Return the pointer at the separator after the final digit. + """ + cdef const char *ptr = start + while ptr[0]: + if b'0' <= ptr[0] <= b'9': + us[0] = us[0] * 10 + (ptr[0] - <char>b'0') + else: + break + + ptr += 1 + + # Pad the fraction of second to get millis + if us[0] and ptr - start < 6: + us[0] *= _uspad[ptr - start] + + return ptr + + +cdef int _parse_timezone_to_seconds(const char **bufptr, const char *end): + """ + Parse a timezone from a string, return Python timezone object. 
+ + Modify the buffer pointer to point at the first character after the + timezone parsed. In case of parse error make it NULL. + """ + cdef const char *ptr = bufptr[0] + cdef char sgn = ptr[0] + + # Parse at most three groups of digits + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + + ptr = _parse_date_values(ptr + 1, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + return 0 + + cdef int off = 60 * (60 * vals[0] + vals[1]) + vals[2] + return -off if sgn == b"-" else off + + +cdef object _timezone_from_seconds(int sec, __cache={}): + cdef object pysec = sec + cdef PyObject *ptr = PyDict_GetItem(__cache, pysec) + if ptr != NULL: + return <object>ptr + + delta = cdt.timedelta_new(0, sec, 0) + tz = timezone(delta) + __cache[pysec] = tz + return tz + + +cdef object _get_timestamp_load_error( + pq.PGconn pgconn, const char *data, ex: Optional[Exception] = None +): + s = bytes(data).decode("utf8", "replace") + + def is_overflow(s): + if not s: + return False + + ds = _get_datestyle(pgconn) + if not ds.startswith(b"P"): # Postgres + return len(s.split()[0]) > 10 # date is first token + else: + return len(s.split()[-1]) > 4 # year is last token + + if s == "-infinity" or s.endswith("BC"): + return e.DataError("timestamp too small (before year 1): {s!r}") + elif s == "infinity" or is_overflow(s): + return e.DataError(f"timestamp too large (after year 10K): {s!r}") + else: + return e.DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}") + + +cdef _timezones = {} +_timezones[None] = timezone_utc +_timezones[b"UTC"] = timezone_utc + + +cdef object _timezone_from_connection(pq.PGconn pgconn): + """Return the Python timezone info of the connection's timezone.""" + if pgconn is None: + return timezone_utc + + cdef bytes tzname = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"TimeZone") + cdef PyObject *ptr = PyDict_GetItem(_timezones, tzname) + if ptr != NULL: + return <object>ptr + + sname = tzname.decode() if tzname else "UTC" + try: + zi = ZoneInfo(sname) + except (KeyError, OSError): + logger.warning( + "unknown PostgreSQL timezone: %r; will use UTC", sname + ) + zi = timezone_utc + except Exception as ex: + logger.warning( + "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)", + sname, + type(ex).__name__, + ex, + ) + zi = timezone.utc + + _timezones[tzname] = zi + return zi + + +cdef const char *_get_datestyle(pq.PGconn pgconn): + cdef const char *ds + if pgconn is not None: + ds = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"DateStyle") + if ds is not NULL and ds[0]: + return ds + + return b"ISO, DMY" + + +cdef const char *_get_intervalstyle(pq.PGconn pgconn): + cdef const char *ds + if pgconn is not None: + ds = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"IntervalStyle") + if ds is not NULL and ds[0]: + return ds + + return b"postgres" diff --git a/psycopg_c/psycopg_c/types/numeric.pyx b/psycopg_c/psycopg_c/types/numeric.pyx new file mode 100644 index 0000000..893bdc2 --- /dev/null +++ b/psycopg_c/psycopg_c/types/numeric.pyx @@ -0,0 +1,715 @@ +""" +Cython adapters for numeric types. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython + +from libc.stdint cimport * +from libc.string cimport memcpy, strlen +from cpython.mem cimport PyMem_Free +from cpython.dict cimport PyDict_GetItem, PyDict_SetItem +from cpython.long cimport ( + PyLong_FromString, PyLong_FromLong, PyLong_FromLongLong, + PyLong_FromUnsignedLong, PyLong_AsLongLong) +from cpython.bytes cimport PyBytes_AsStringAndSize +from cpython.float cimport PyFloat_FromDouble, PyFloat_AsDouble +from cpython.unicode cimport PyUnicode_DecodeUTF8 + +from decimal import Decimal, Context, DefaultContext + +from psycopg_c._psycopg cimport endian +from psycopg import errors as e +from psycopg._wrappers import Int2, Int4, Int8, IntNumeric + +cdef extern from "Python.h": + # work around https://github.com/cython/cython/issues/3909 + double PyOS_string_to_double( + const char *s, char **endptr, PyObject *overflow_exception) except? -1.0 + char *PyOS_double_to_string( + double val, char format_code, int precision, int flags, int *ptype + ) except NULL + int Py_DTSF_ADD_DOT_0 + long long PyLong_AsLongLongAndOverflow(object pylong, int *overflow) except? -1 + + # Missing in cpython/unicode.pxd + const char *PyUnicode_AsUTF8(object unicode) except NULL + + +# defined in numutils.c +cdef extern from *: + """ +int pg_lltoa(int64_t value, char *a); +#define MAXINT8LEN 20 + """ + int pg_lltoa(int64_t value, char *a) + const int MAXINT8LEN + + +cdef class _NumberDumper(CDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + return dump_int_to_text(obj, rv, offset) + + def quote(self, obj) -> bytearray: + cdef Py_ssize_t length + + rv = PyByteArray_FromStringAndSize("", 0) + if obj >= 0: + length = self.cdump(obj, rv, 0) + else: + PyByteArray_Resize(rv, 23) + rv[0] = b' ' + length = 1 + self.cdump(obj, rv, 1) + + PyByteArray_Resize(rv, length) + return rv + + +@cython.final +cdef class Int2Dumper(_NumberDumper): + + oid = oids.INT2_OID + + +@cython.final +cdef class Int4Dumper(_NumberDumper): + + oid = oids.INT4_OID + + +@cython.final +cdef class Int8Dumper(_NumberDumper): + + oid = oids.INT8_OID + + +@cython.final +cdef class IntNumericDumper(_NumberDumper): + + oid = oids.NUMERIC_OID + + +@cython.final +cdef class Int2BinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.INT2_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int16_t *buf = <int16_t *>CDumper.ensure_size( + rv, offset, sizeof(int16_t)) + cdef int16_t val = <int16_t>PyLong_AsLongLong(obj) + # swap bytes if needed + cdef uint16_t *ptvar = <uint16_t *>(&val) + buf[0] = endian.htobe16(ptvar[0]) + return sizeof(int16_t) + + +@cython.final +cdef class Int4BinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.INT4_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int32_t *buf = <int32_t *>CDumper.ensure_size( + rv, offset, sizeof(int32_t)) + cdef int32_t val = <int32_t>PyLong_AsLongLong(obj) + # swap bytes if needed + cdef uint32_t *ptvar = <uint32_t *>(&val) + buf[0] = endian.htobe32(ptvar[0]) + return sizeof(int32_t) + + +@cython.final +cdef class Int8BinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.INT8_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int64_t *buf = <int64_t *>CDumper.ensure_size( + rv, offset, sizeof(int64_t)) + cdef int64_t val = PyLong_AsLongLong(obj) + # swap bytes if needed + cdef uint64_t *ptvar = <uint64_t *>(&val) + buf[0] = 
endian.htobe64(ptvar[0]) + return sizeof(int64_t) + + +cdef extern from *: + """ +/* Ratio between number of bits required to store a number and number of pg + * decimal digits required (log(2) / log(10_000)). + */ +#define BIT_PER_PGDIGIT 0.07525749891599529 + +/* decimal digits per Postgres "digit" */ +#define DEC_DIGITS 4 + +#define NUMERIC_POS 0x0000 +#define NUMERIC_NEG 0x4000 +#define NUMERIC_NAN 0xC000 +#define NUMERIC_PINF 0xD000 +#define NUMERIC_NINF 0xF000 +""" + const double BIT_PER_PGDIGIT + const int DEC_DIGITS + const int NUMERIC_POS + const int NUMERIC_NEG + const int NUMERIC_NAN + const int NUMERIC_PINF + const int NUMERIC_NINF + + +@cython.final +cdef class IntNumericBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + return dump_int_to_numeric_binary(obj, rv, offset) + + +cdef class IntDumper(_NumberDumper): + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + raise TypeError( + f"{type(self).__name__} is a dispatcher to other dumpers:" + " dump() is not supposed to be called" + ) + + cpdef get_key(self, obj, format): + cdef long long val + cdef int overflow + + val = PyLong_AsLongLongAndOverflow(obj, &overflow) + if overflow: + return IntNumeric + + if INT32_MIN <= obj <= INT32_MAX: + if INT16_MIN <= obj <= INT16_MAX: + return Int2 + else: + return Int4 + else: + if INT64_MIN <= obj <= INT64_MAX: + return Int8 + else: + return IntNumeric + + _int2_dumper = Int2Dumper + _int4_dumper = Int4Dumper + _int8_dumper = Int8Dumper + _int_numeric_dumper = IntNumericDumper + + cpdef upgrade(self, obj, format): + cdef long long val + cdef int overflow + + val = PyLong_AsLongLongAndOverflow(obj, &overflow) + if overflow: + return self._int_numeric_dumper(IntNumeric) + + if INT32_MIN <= obj <= INT32_MAX: + if INT16_MIN <= obj <= INT16_MAX: + return self._int2_dumper(Int2) + else: + return self._int4_dumper(Int4) + else: + if INT64_MIN <= obj <= INT64_MAX: + return self._int8_dumper(Int8) + else: + return self._int_numeric_dumper(IntNumeric) + + +@cython.final +cdef class IntBinaryDumper(IntDumper): + + format = PQ_BINARY + + _int2_dumper = Int2BinaryDumper + _int4_dumper = Int4BinaryDumper + _int8_dumper = Int8BinaryDumper + _int_numeric_dumper = IntNumericBinaryDumper + + +@cython.final +cdef class IntLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + # if the number ends with a 0 we don't need a copy + if data[length] == b'\0': + return PyLong_FromString(data, NULL, 10) + + # Otherwise we have to copy it aside + if length > MAXINT8LEN: + raise ValueError("string too big for an int") + + cdef char[MAXINT8LEN + 1] buf + memcpy(buf, data, length) + buf[length] = 0 + return PyLong_FromString(buf, NULL, 10) + + + +@cython.final +cdef class Int2BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return PyLong_FromLong(<int16_t>endian.be16toh((<uint16_t *>data)[0])) + + +@cython.final +cdef class Int4BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return PyLong_FromLong(<int32_t>endian.be32toh((<uint32_t *>data)[0])) + + +@cython.final +cdef class Int8BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return PyLong_FromLongLong(<int64_t>endian.be64toh((<uint64_t *>data)[0])) + + +@cython.final +cdef class OidBinaryLoader(CLoader): + + 
format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return PyLong_FromUnsignedLong(endian.be32toh((<uint32_t *>data)[0])) + + +cdef class _FloatDumper(CDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef double d = PyFloat_AsDouble(obj) + cdef char *out = PyOS_double_to_string( + d, b'r', 0, Py_DTSF_ADD_DOT_0, NULL) + cdef Py_ssize_t length = strlen(out) + cdef char *tgt = CDumper.ensure_size(rv, offset, length) + memcpy(tgt, out, length) + PyMem_Free(out) + return length + + def quote(self, obj) -> bytes: + value = bytes(self.dump(obj)) + cdef PyObject *ptr = PyDict_GetItem(_special_float, value) + if ptr != NULL: + return <object>ptr + + return value if obj >= 0 else b" " + value + +cdef dict _special_float = { + b"inf": b"'Infinity'::float8", + b"-inf": b"'-Infinity'::float8", + b"nan": b"'NaN'::float8", +} + + +@cython.final +cdef class FloatDumper(_FloatDumper): + + oid = oids.FLOAT8_OID + + +@cython.final +cdef class Float4Dumper(_FloatDumper): + + oid = oids.FLOAT4_OID + + +@cython.final +cdef class FloatBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.FLOAT8_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef double d = PyFloat_AsDouble(obj) + cdef uint64_t *intptr = <uint64_t *>&d + cdef uint64_t *buf = <uint64_t *>CDumper.ensure_size( + rv, offset, sizeof(uint64_t)) + buf[0] = endian.htobe64(intptr[0]) + return sizeof(uint64_t) + + +@cython.final +cdef class Float4BinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.FLOAT4_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef float f = <float>PyFloat_AsDouble(obj) + cdef uint32_t *intptr = <uint32_t *>&f + cdef uint32_t *buf = <uint32_t *>CDumper.ensure_size( + rv, offset, sizeof(uint32_t)) + buf[0] = endian.htobe32(intptr[0]) + return sizeof(uint32_t) + + +@cython.final +cdef class FloatLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + cdef char *endptr + cdef double d = PyOS_string_to_double( + data, &endptr, <PyObject *>OverflowError) + return PyFloat_FromDouble(d) + + +@cython.final +cdef class Float4BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef uint32_t asint = endian.be32toh((<uint32_t *>data)[0]) + # avoid warning: + # dereferencing type-punned pointer will break strict-aliasing rules + cdef char *swp = <char *>&asint + return PyFloat_FromDouble((<float *>swp)[0]) + + +@cython.final +cdef class Float8BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef uint64_t asint = endian.be64toh((<uint64_t *>data)[0]) + cdef char *swp = <char *>&asint + return PyFloat_FromDouble((<double *>swp)[0]) + + +@cython.final +cdef class DecimalDumper(CDumper): + + format = PQ_TEXT + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + return dump_decimal_to_text(obj, rv, offset) + + def quote(self, obj) -> bytes: + value = bytes(self.dump(obj)) + cdef PyObject *ptr = PyDict_GetItem(_special_decimal, value) + if ptr != NULL: + return <object>ptr + + return value if obj >= 0 else b" " + value + +cdef dict _special_decimal = { + b"Infinity": b"'Infinity'::numeric", + b"-Infinity": b"'-Infinity'::numeric", + b"NaN": b"'NaN'::numeric", +} + + +@cython.final +cdef class NumericLoader(CLoader): + + format = PQ_TEXT + + cdef object 
cload(self, const char *data, size_t length): + s = PyUnicode_DecodeUTF8(<char *>data, length, NULL) + return Decimal(s) + + +cdef dict _decimal_special = { + NUMERIC_NAN: Decimal("NaN"), + NUMERIC_PINF: Decimal("Infinity"), + NUMERIC_NINF: Decimal("-Infinity"), +} + +cdef dict _contexts = {} +for _i in range(DefaultContext.prec): + _contexts[_i] = DefaultContext + + +@cython.final +cdef class NumericBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + + cdef uint16_t *data16 = <uint16_t *>data + cdef uint16_t ndigits = endian.be16toh(data16[0]) + cdef int16_t weight = <int16_t>endian.be16toh(data16[1]) + cdef uint16_t sign = endian.be16toh(data16[2]) + cdef uint16_t dscale = endian.be16toh(data16[3]) + cdef int shift + cdef int i + cdef PyObject *pctx + cdef object key + + if sign == NUMERIC_POS or sign == NUMERIC_NEG: + if length != (4 + ndigits) * sizeof(uint16_t): + raise e.DataError("bad ndigits in numeric binary representation") + + val = 0 + for i in range(ndigits): + val *= 10_000 + val += endian.be16toh(data16[i + 4]) + + shift = dscale - (ndigits - weight - 1) * DEC_DIGITS + + key = (weight + 2) * DEC_DIGITS + dscale + pctx = PyDict_GetItem(_contexts, key) + if pctx == NULL: + ctx = Context(prec=key) + PyDict_SetItem(_contexts, key, ctx) + pctx = <PyObject *>ctx + + return ( + Decimal(val if sign == NUMERIC_POS else -val) + .scaleb(-dscale, <object>pctx) + .shift(shift, <object>pctx) + ) + else: + try: + return _decimal_special[sign] + except KeyError: + raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") + + +@cython.final +cdef class DecimalBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + return dump_decimal_to_numeric_binary(obj, rv, offset) + + +@cython.final +cdef class NumericDumper(CDumper): + + format = PQ_TEXT + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + if isinstance(obj, int): + return dump_int_to_text(obj, rv, offset) + else: + return dump_decimal_to_text(obj, rv, offset) + + +@cython.final +cdef class NumericBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + if isinstance(obj, int): + return dump_int_to_numeric_binary(obj, rv, offset) + else: + return dump_decimal_to_numeric_binary(obj, rv, offset) + + +cdef Py_ssize_t dump_decimal_to_text(obj, bytearray rv, Py_ssize_t offset) except -1: + cdef char *src + cdef Py_ssize_t length + cdef char *buf + + b = bytes(str(obj), "utf-8") + PyBytes_AsStringAndSize(b, &src, &length) + + if src[0] != b's': + buf = CDumper.ensure_size(rv, offset, length) + memcpy(buf, src, length) + + else: # convert sNaN to NaN + length = 3 # NaN + buf = CDumper.ensure_size(rv, offset, length) + memcpy(buf, b"NaN", length) + + return length + + +cdef extern from *: + """ +/* Weights of py digits into a pg digit according to their positions. */ +static const int pydigit_weights[] = {1000, 100, 10, 1}; +""" + const int[4] pydigit_weights + + +@cython.cdivision(True) +cdef Py_ssize_t dump_decimal_to_numeric_binary( + obj, bytearray rv, Py_ssize_t offset +) except -1: + + # TODO: this implementation is about 30% slower than the text dump. 
+ # This might be probably optimised by accessing the C structure of + # the Decimal object, if available, which would save the creation of + # several intermediate Python objects (the DecimalTuple, the digits + # tuple, and then accessing them). + + cdef object t = obj.as_tuple() + cdef int sign = t[0] + cdef tuple digits = t[1] + cdef uint16_t *buf + cdef Py_ssize_t length + + cdef object pyexp = t[2] + cdef const char *bexp + if not isinstance(pyexp, int): + # Handle inf, nan + length = 4 * sizeof(uint16_t) + buf = <uint16_t *>CDumper.ensure_size(rv, offset, length) + buf[0] = 0 + buf[1] = 0 + buf[3] = 0 + bexp = PyUnicode_AsUTF8(pyexp) + if bexp[0] == b'n' or bexp[0] == b'N': + buf[2] = endian.htobe16(NUMERIC_NAN) + elif bexp[0] == b'F': + if sign: + buf[2] = endian.htobe16(NUMERIC_NINF) + else: + buf[2] = endian.htobe16(NUMERIC_PINF) + else: + raise e.DataError(f"unexpected decimal exponent: {pyexp}") + return length + + cdef int exp = pyexp + cdef uint16_t ndigits = <uint16_t>len(digits) + + # Find the last nonzero digit + cdef int nzdigits = ndigits + while nzdigits > 0 and digits[nzdigits - 1] == 0: + nzdigits -= 1 + + cdef uint16_t dscale + if exp <= 0: + dscale = -exp + else: + dscale = 0 + # align the py digits to the pg digits if there's some py exponent + ndigits += exp % DEC_DIGITS + + if nzdigits == 0: + length = 4 * sizeof(uint16_t) + buf = <uint16_t *>CDumper.ensure_size(rv, offset, length) + buf[0] = 0 # ndigits + buf[1] = 0 # weight + buf[2] = endian.htobe16(NUMERIC_POS) # sign + buf[3] = endian.htobe16(dscale) + return length + + # Equivalent of 0-padding left to align the py digits to the pg digits + # but without changing the digits tuple. + cdef int wi = 0 + cdef int mod = (ndigits - dscale) % DEC_DIGITS + if mod < 0: + # the difference between C and Py % operator + mod += 4 + if mod: + wi = DEC_DIGITS - mod + ndigits += wi + + cdef int tmp = nzdigits + wi + cdef int pgdigits = tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1) + length = (pgdigits + 4) * sizeof(uint16_t) + buf = <uint16_t*>CDumper.ensure_size(rv, offset, length) + buf[0] = endian.htobe16(pgdigits) + buf[1] = endian.htobe16(<int16_t>((ndigits + exp) // DEC_DIGITS - 1)) + buf[2] = endian.htobe16(NUMERIC_NEG) if sign else endian.htobe16(NUMERIC_POS) + buf[3] = endian.htobe16(dscale) + + cdef uint16_t pgdigit = 0 + cdef int bi = 4 + for i in range(nzdigits): + pgdigit += pydigit_weights[wi] * <int>(digits[i]) + wi += 1 + if wi >= DEC_DIGITS: + buf[bi] = endian.htobe16(pgdigit) + pgdigit = wi = 0 + bi += 1 + + if pgdigit: + buf[bi] = endian.htobe16(pgdigit) + + return length + + +cdef Py_ssize_t dump_int_to_text(obj, bytearray rv, Py_ssize_t offset) except -1: + cdef long long val + cdef int overflow + cdef char *buf + cdef char *src + cdef Py_ssize_t length + + # Ensure an int or a subclass. The 'is' type check is fast. + # Passing a float must give an error, but passing an Enum should work. 
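+    # (enum.IntEnum members are int subclasses, so isinstance() accepts them;
+    # a float fails both checks and raises DataError below.)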
+ if type(obj) is not int and not isinstance(obj, int): + raise e.DataError(f"integer expected, got {type(obj).__name__!r}") + + val = PyLong_AsLongLongAndOverflow(obj, &overflow) + if not overflow: + buf = CDumper.ensure_size(rv, offset, MAXINT8LEN + 1) + length = pg_lltoa(val, buf) + else: + b = bytes(str(obj), "utf-8") + PyBytes_AsStringAndSize(b, &src, &length) + buf = CDumper.ensure_size(rv, offset, length) + memcpy(buf, src, length) + + return length + + +cdef Py_ssize_t dump_int_to_numeric_binary(obj, bytearray rv, Py_ssize_t offset) except -1: + # Calculate the number of PG digits required to store the number + cdef uint16_t ndigits + ndigits = <uint16_t>((<int>obj.bit_length()) * BIT_PER_PGDIGIT) + 1 + + cdef uint16_t sign = NUMERIC_POS + if obj < 0: + sign = NUMERIC_NEG + obj = -obj + + cdef Py_ssize_t length = sizeof(uint16_t) * (ndigits + 4) + cdef uint16_t *buf + buf = <uint16_t *><void *>CDumper.ensure_size(rv, offset, length) + buf[0] = endian.htobe16(ndigits) + buf[1] = endian.htobe16(ndigits - 1) # weight + buf[2] = endian.htobe16(sign) + buf[3] = 0 # dscale + + cdef int i = 4 + ndigits - 1 + cdef uint16_t rem + while obj: + rem = obj % 10000 + obj //= 10000 + buf[i] = endian.htobe16(rem) + i -= 1 + while i > 3: + buf[i] = 0 + i -= 1 + + return length diff --git a/psycopg_c/psycopg_c/types/numutils.c b/psycopg_c/psycopg_c/types/numutils.c new file mode 100644 index 0000000..4be7108 --- /dev/null +++ b/psycopg_c/psycopg_c/types/numutils.c @@ -0,0 +1,243 @@ +/* + * Utilities to deal with numbers. + * + * Copyright (C) 2020 The Psycopg Team + * Portions Copyright (c) 1996-2020, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + */ + +#include <stdint.h> +#include <string.h> + +#include "pg_config.h" + + +/* + * 64-bit integers + */ +#ifdef HAVE_LONG_INT_64 +/* Plain "long int" fits, use it */ + +# ifndef HAVE_INT64 +typedef long int int64; +# endif +# ifndef HAVE_UINT64 +typedef unsigned long int uint64; +# endif +# define INT64CONST(x) (x##L) +# define UINT64CONST(x) (x##UL) +#elif defined(HAVE_LONG_LONG_INT_64) +/* We have working support for "long long int", use that */ + +# ifndef HAVE_INT64 +typedef long long int int64; +# endif +# ifndef HAVE_UINT64 +typedef unsigned long long int uint64; +# endif +# define INT64CONST(x) (x##LL) +# define UINT64CONST(x) (x##ULL) +#else +/* neither HAVE_LONG_INT_64 nor HAVE_LONG_LONG_INT_64 */ +# error must have a working 64-bit integer datatype +#endif + + +#ifndef HAVE__BUILTIN_CLZ +static const uint8_t pg_leftmost_one_pos[256] = { + 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 +}; +#endif + +static const char DIGIT_TABLE[200] = { + '0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0', + '7', '0', '8', '0', '9', '1', 
'0', '1', '1', '1', '2', '1', '3', '1', '4', + '1', '5', '1', '6', '1', '7', '1', '8', '1', '9', '2', '0', '2', '1', '2', + '2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9', + '3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3', + '7', '3', '8', '3', '9', '4', '0', '4', '1', '4', '2', '4', '3', '4', '4', + '4', '5', '4', '6', '4', '7', '4', '8', '4', '9', '5', '0', '5', '1', '5', + '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9', + '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6', + '7', '6', '8', '6', '9', '7', '0', '7', '1', '7', '2', '7', '3', '7', '4', + '7', '5', '7', '6', '7', '7', '7', '8', '7', '9', '8', '0', '8', '1', '8', + '2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9', + '9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9', + '7', '9', '8', '9', '9' +}; + + +/* + * pg_leftmost_one_pos64 + * As above, but for a 64-bit word. + */ +static inline int +pg_leftmost_one_pos64(uint64_t word) +{ +#ifdef HAVE__BUILTIN_CLZ +#if defined(HAVE_LONG_INT_64) + return 63 - __builtin_clzl(word); +#elif defined(HAVE_LONG_LONG_INT_64) + return 63 - __builtin_clzll(word); +#else +#error must have a working 64-bit integer datatype +#endif +#else /* !HAVE__BUILTIN_CLZ */ + int shift = 64 - 8; + + while ((word >> shift) == 0) + shift -= 8; + + return shift + pg_leftmost_one_pos[(word >> shift) & 255]; +#endif /* HAVE__BUILTIN_CLZ */ +} + + +static inline int +decimalLength64(const uint64_t v) +{ + int t; + static const uint64_t PowersOfTen[] = { + UINT64CONST(1), UINT64CONST(10), + UINT64CONST(100), UINT64CONST(1000), + UINT64CONST(10000), UINT64CONST(100000), + UINT64CONST(1000000), UINT64CONST(10000000), + UINT64CONST(100000000), UINT64CONST(1000000000), + UINT64CONST(10000000000), UINT64CONST(100000000000), + UINT64CONST(1000000000000), UINT64CONST(10000000000000), + UINT64CONST(100000000000000), UINT64CONST(1000000000000000), + UINT64CONST(10000000000000000), UINT64CONST(100000000000000000), + UINT64CONST(1000000000000000000), UINT64CONST(10000000000000000000) + }; + + /* + * Compute base-10 logarithm by dividing the base-2 logarithm by a + * good-enough approximation of the base-2 logarithm of 10 + */ + t = (pg_leftmost_one_pos64(v) + 1) * 1233 / 4096; + return t + (v >= PowersOfTen[t]); +} + + +/* + * Get the decimal representation, not NUL-terminated, and return the length of + * same. Caller must ensure that a points to at least MAXINT8LEN bytes. + */ +int +pg_ulltoa_n(uint64_t value, char *a) +{ + int olength, + i = 0; + uint32_t value2; + + /* Degenerate case */ + if (value == 0) + { + *a = '0'; + return 1; + } + + olength = decimalLength64(value); + + /* Compute the result string. 
*/ + while (value >= 100000000) + { + const uint64_t q = value / 100000000; + uint32_t value2 = (uint32_t) (value - 100000000 * q); + + const uint32_t c = value2 % 10000; + const uint32_t d = value2 / 10000; + const uint32_t c0 = (c % 100) << 1; + const uint32_t c1 = (c / 100) << 1; + const uint32_t d0 = (d % 100) << 1; + const uint32_t d1 = (d / 100) << 1; + + char *pos = a + olength - i; + + value = q; + + memcpy(pos - 2, DIGIT_TABLE + c0, 2); + memcpy(pos - 4, DIGIT_TABLE + c1, 2); + memcpy(pos - 6, DIGIT_TABLE + d0, 2); + memcpy(pos - 8, DIGIT_TABLE + d1, 2); + i += 8; + } + + /* Switch to 32-bit for speed */ + value2 = (uint32_t) value; + + if (value2 >= 10000) + { + const uint32_t c = value2 - 10000 * (value2 / 10000); + const uint32_t c0 = (c % 100) << 1; + const uint32_t c1 = (c / 100) << 1; + + char *pos = a + olength - i; + + value2 /= 10000; + + memcpy(pos - 2, DIGIT_TABLE + c0, 2); + memcpy(pos - 4, DIGIT_TABLE + c1, 2); + i += 4; + } + if (value2 >= 100) + { + const uint32_t c = (value2 % 100) << 1; + char *pos = a + olength - i; + + value2 /= 100; + + memcpy(pos - 2, DIGIT_TABLE + c, 2); + i += 2; + } + if (value2 >= 10) + { + const uint32_t c = value2 << 1; + char *pos = a + olength - i; + + memcpy(pos - 2, DIGIT_TABLE + c, 2); + } + else + *a = (char) ('0' + value2); + + return olength; +} + +/* + * pg_lltoa: converts a signed 64-bit integer to its string representation and + * returns strlen(a). + * + * Caller must ensure that 'a' points to enough memory to hold the result + * (at least MAXINT8LEN + 1 bytes, counting a leading sign and trailing NUL). + */ +int +pg_lltoa(int64_t value, char *a) +{ + uint64_t uvalue = value; + int len = 0; + + if (value < 0) + { + uvalue = (uint64_t) 0 - uvalue; + a[len++] = '-'; + } + + len += pg_ulltoa_n(uvalue, a + len); + a[len] = '\0'; + return len; +} diff --git a/psycopg_c/psycopg_c/types/string.pyx b/psycopg_c/psycopg_c/types/string.pyx new file mode 100644 index 0000000..da18b01 --- /dev/null +++ b/psycopg_c/psycopg_c/types/string.pyx @@ -0,0 +1,315 @@ +""" +Cython adapters for textual types. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython + +from libc.string cimport memcpy, memchr +from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize +from cpython.unicode cimport ( + PyUnicode_AsEncodedString, + PyUnicode_AsUTF8String, + PyUnicode_CheckExact, + PyUnicode_Decode, + PyUnicode_DecodeUTF8, +) + +from psycopg_c.pq cimport libpq, Escaping, _buffer_as_string_and_size + +from psycopg import errors as e +from psycopg._encodings import pg2pyenc + +cdef extern from "Python.h": + const char *PyUnicode_AsUTF8AndSize(unicode obj, Py_ssize_t *size) except NULL + + +cdef class _BaseStrDumper(CDumper): + cdef int is_utf8 + cdef char *encoding + cdef bytes _bytes_encoding # needed to keep `encoding` alive + + def __cinit__(self, cls, context: Optional[AdaptContext] = None): + + self.is_utf8 = 0 + self.encoding = "utf-8" + cdef const char *pgenc + + if self._pgconn is not None: + pgenc = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"client_encoding") + if pgenc == NULL or pgenc == b"UTF8": + self._bytes_encoding = b"utf-8" + self.is_utf8 = 1 + else: + self._bytes_encoding = pg2pyenc(pgenc).encode() + if self._bytes_encoding == b"ascii": + self.is_utf8 = 1 + self.encoding = PyBytes_AsString(self._bytes_encoding) + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + # the server will raise DataError subclass if the string contains 0x00 + cdef Py_ssize_t size; + cdef const char *src + + if self.is_utf8: + # Probably the fastest path, but doesn't work with subclasses + if PyUnicode_CheckExact(obj): + src = PyUnicode_AsUTF8AndSize(obj, &size) + else: + b = PyUnicode_AsUTF8String(obj) + PyBytes_AsStringAndSize(b, <char **>&src, &size) + else: + b = PyUnicode_AsEncodedString(obj, self.encoding, NULL) + PyBytes_AsStringAndSize(b, <char **>&src, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +cdef class _StrBinaryDumper(_BaseStrDumper): + + format = PQ_BINARY + + +@cython.final +cdef class StrBinaryDumper(_StrBinaryDumper): + + oid = oids.TEXT_OID + + +@cython.final +cdef class StrBinaryDumperVarchar(_StrBinaryDumper): + + oid = oids.VARCHAR_OID + + +@cython.final +cdef class StrBinaryDumperName(_StrBinaryDumper): + + oid = oids.NAME_OID + + +cdef class _StrDumper(_BaseStrDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size = StrBinaryDumper.cdump(self, obj, rv, offset) + + # Like the binary dump, but check for 0, or the string will be truncated + cdef const char *buf = PyByteArray_AS_STRING(rv) + if NULL != memchr(buf + offset, 0x00, size): + raise e.DataError( + "PostgreSQL text fields cannot contain NUL (0x00) bytes" + ) + return size + + +@cython.final +cdef class StrDumper(_StrDumper): + + oid = oids.TEXT_OID + + +@cython.final +cdef class StrDumperVarchar(_StrDumper): + + oid = oids.VARCHAR_OID + + +@cython.final +cdef class StrDumperName(_StrDumper): + + oid = oids.NAME_OID + + +@cython.final +cdef class StrDumperUnknown(_StrDumper): + pass + + +cdef class _TextLoader(CLoader): + + format = PQ_TEXT + + cdef int is_utf8 + cdef char *encoding + cdef bytes _bytes_encoding # needed to keep `encoding` alive + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + self.is_utf8 = 0 + self.encoding = "utf-8" + cdef const char *pgenc + + if self._pgconn is not None: + pgenc = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"client_encoding") + if pgenc == NULL or pgenc == 
b"UTF8": + self._bytes_encoding = b"utf-8" + self.is_utf8 = 1 + else: + self._bytes_encoding = pg2pyenc(pgenc).encode() + + if pgenc == b"SQL_ASCII": + self.encoding = NULL + else: + self.encoding = PyBytes_AsString(self._bytes_encoding) + + cdef object cload(self, const char *data, size_t length): + if self.is_utf8: + return PyUnicode_DecodeUTF8(<char *>data, length, NULL) + elif self.encoding: + return PyUnicode_Decode(<char *>data, length, self.encoding, NULL) + else: + return data[:length] + +@cython.final +cdef class TextLoader(_TextLoader): + + format = PQ_TEXT + + +@cython.final +cdef class TextBinaryLoader(_TextLoader): + + format = PQ_BINARY + + +@cython.final +cdef class BytesDumper(CDumper): + + format = PQ_TEXT + oid = oids.BYTEA_OID + + # 0: not set, 1: just single "'" quote, 3: " E'" quote + cdef int _qplen + + def __cinit__(self): + self._qplen = 0 + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + + cdef size_t len_out + cdef unsigned char *out + cdef char *ptr + cdef Py_ssize_t length + + _buffer_as_string_and_size(obj, &ptr, &length) + + if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL: + out = libpq.PQescapeByteaConn( + self._pgconn._pgconn_ptr, <unsigned char *>ptr, length, &len_out) + else: + out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_out) + + if out is NULL: + raise MemoryError( + f"couldn't allocate for escape_bytea of {length} bytes" + ) + + len_out -= 1 # out includes final 0 + cdef char *buf = CDumper.ensure_size(rv, offset, len_out) + memcpy(buf, out, len_out) + libpq.PQfreemem(out) + return len_out + + def quote(self, obj): + cdef size_t len_out + cdef unsigned char *out + cdef char *ptr + cdef Py_ssize_t length + cdef const char *scs + + escaped = self.dump(obj) + _buffer_as_string_and_size(escaped, &ptr, &length) + + rv = PyByteArray_FromStringAndSize("", 0) + + # We cannot use the base quoting because escape_bytea already returns + # the quotes content. if scs is off it will escape the backslashes in + # the format, otherwise it won't, but it doesn't tell us what quotes to + # use. + if self._pgconn is not None: + if not self._qplen: + scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, + b"standard_conforming_strings") + if scs and scs[0] == b'o' and scs[1] == b"n": # == "on" + self._qplen = 1 + else: + self._qplen = 3 + + PyByteArray_Resize(rv, length + self._qplen + 1) # Include quotes + ptr_out = PyByteArray_AS_STRING(rv) + if self._qplen == 1: + ptr_out[0] = b"'" + else: + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + self._qplen, ptr, length) + ptr_out[length + self._qplen] = b"'" + return rv + + # We don't have a connection, so someone is using us to generate a file + # to use off-line or something like that. PQescapeBytea, like its + # string counterpart, is not predictable whether it will escape + # backslashes. 
+ PyByteArray_Resize(rv, length + 4) # Include quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + 3, ptr, length) + ptr_out[length + 3] = b"'" + + esc = Escaping() + if esc.escape_bytea(b"\x00") == b"\\000": + rv = bytes(rv).replace(b"\\", b"\\\\") + + return rv + + +@cython.final +cdef class BytesBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.BYTEA_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef char *src + cdef Py_ssize_t size; + _buffer_as_string_and_size(obj, &src, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class ByteaLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + cdef size_t len_out + cdef unsigned char *out = libpq.PQunescapeBytea( + <const unsigned char *>data, &len_out) + if out is NULL: + raise MemoryError( + f"couldn't allocate for unescape_bytea of {len(data)} bytes" + ) + + rv = out[:len_out] + libpq.PQfreemem(out) + return rv + + +@cython.final +cdef class ByteaBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return data[:length] diff --git a/psycopg_c/psycopg_c/version.py b/psycopg_c/psycopg_c/version.py new file mode 100644 index 0000000..5c989c2 --- /dev/null +++ b/psycopg_c/psycopg_c/version.py @@ -0,0 +1,11 @@ +""" +psycopg-c distribution version file. +""" + +# Copyright (C) 2020 The Psycopg Team + +# Use a versioning scheme as defined in +# https://www.python.org/dev/peps/pep-0440/ +__version__ = "3.1.7" + +# also change psycopg/psycopg/version.py accordingly. diff --git a/psycopg_c/pyproject.toml b/psycopg_c/pyproject.toml new file mode 100644 index 0000000..f0d7a3f --- /dev/null +++ b/psycopg_c/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=49.2.0", "wheel>=0.37", "Cython>=3.0.0a11"] +build-backend = "setuptools.build_meta" diff --git a/psycopg_c/setup.cfg b/psycopg_c/setup.cfg new file mode 100644 index 0000000..6c5c93c --- /dev/null +++ b/psycopg_c/setup.cfg @@ -0,0 +1,57 @@ +[metadata] +name = psycopg-c +description = PostgreSQL database adapter for Python -- C optimisation distribution +url = https://psycopg.org/psycopg3/ +author = Daniele Varrazzo +author_email = daniele.varrazzo@gmail.com +license = GNU Lesser General Public License v3 (LGPLv3) + +project_urls = + Homepage = https://psycopg.org/ + Code = https://github.com/psycopg/psycopg + Issue Tracker = https://github.com/psycopg/psycopg/issues + Download = https://pypi.org/project/psycopg-c/ + +classifiers = + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) + Operating System :: MacOS :: MacOS X + Operating System :: Microsoft :: Windows + Operating System :: POSIX + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Topic :: Database + Topic :: Database :: Front-Ends + Topic :: Software Development + Topic :: Software Development :: Libraries :: Python Modules + +long_description = file: README.rst +long_description_content_type = text/x-rst +license_files = LICENSE.txt + +[options] +python_requires = >= 3.7 +setup_requires = Cython >= 3.0.0a11 +packages = find: 
+zip_safe = False + +[options.package_data] +# NOTE: do not include .pyx files: they shouldn't be in the sdist +# package, so that build is only performed from the .c files (which are +# distributed instead). +psycopg_c = + py.typed + *.pyi + *.pxd + _psycopg/*.pxd + pq/*.pxd + +# In the psycopg-binary distribution don't include cython-related files. +psycopg_binary = + py.typed + *.pyi diff --git a/psycopg_c/setup.py b/psycopg_c/setup.py new file mode 100644 index 0000000..c6da3a1 --- /dev/null +++ b/psycopg_c/setup.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +PostgreSQL database adapter for Python - optimisation package +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +import sys +import subprocess as sp + +from setuptools import setup, Extension +from distutils.command.build_ext import build_ext +from distutils import log + +# Move to the directory of setup.py: executing this file from another location +# (e.g. from the project root) will fail +here = os.path.abspath(os.path.dirname(__file__)) +if os.path.abspath(os.getcwd()) != here: + os.chdir(here) + +with open("psycopg_c/version.py") as f: + data = f.read() + m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data) + if m is None: + raise Exception(f"cannot find version in {f.name}") + version = m.group(1) + + +def get_config(what: str) -> str: + pg_config = "pg_config" + try: + out = sp.run([pg_config, f"--{what}"], stdout=sp.PIPE, check=True) + except Exception as e: + log.error(f"couldn't run {pg_config!r} --{what}: %s", e) + raise + else: + return out.stdout.strip().decode() + + +class psycopg_build_ext(build_ext): + def finalize_options(self) -> None: + self._setup_ext_build() + super().finalize_options() + + def _setup_ext_build(self) -> None: + cythonize = None + + # In the sdist there are not .pyx, only c, so we don't need Cython. + # Otherwise Cython is a requirement and it is used to compile pyx to c. + if os.path.exists("psycopg_c/_psycopg.pyx"): + from Cython.Build import cythonize + + # Add include and lib dir for the libpq. 
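+        # get_config() shells out to pg_config; for instance
+        # "pg_config --includedir" typically prints a path such as
+        # /usr/include/postgresql (the exact location depends on the system).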
+ includedir = get_config("includedir") + libdir = get_config("libdir") + for ext in self.distribution.ext_modules: + ext.include_dirs.append(includedir) + ext.library_dirs.append(libdir) + + if sys.platform == "win32": + # For __imp_htons and others + ext.libraries.append("ws2_32") + + if cythonize is not None: + for ext in self.distribution.ext_modules: + for i in range(len(ext.sources)): + base, fext = os.path.splitext(ext.sources[i]) + if fext == ".c" and os.path.exists(base + ".pyx"): + ext.sources[i] = base + ".pyx" + + self.distribution.ext_modules = cythonize( + self.distribution.ext_modules, + language_level=3, + compiler_directives={ + "always_allow_keywords": False, + }, + annotate=False, # enable to get an html view of the C module + ) + else: + self.distribution.ext_modules = [pgext, pqext] + + +# MSVC requires an explicit "libpq" +libpq = "pq" if sys.platform != "win32" else "libpq" + +# Some details missing, to be finished by psycopg_build_ext.finalize_options +pgext = Extension( + "psycopg_c._psycopg", + [ + "psycopg_c/_psycopg.c", + "psycopg_c/types/numutils.c", + ], + libraries=[libpq], + include_dirs=[], +) + +pqext = Extension( + "psycopg_c.pq", + ["psycopg_c/pq.c"], + libraries=[libpq], + include_dirs=[], +) + +setup( + version=version, + ext_modules=[pgext, pqext], + cmdclass={"build_ext": psycopg_build_ext}, +) diff --git a/psycopg_pool/.flake8 b/psycopg_pool/.flake8 new file mode 100644 index 0000000..2ae629c --- /dev/null +++ b/psycopg_pool/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +ignore = W503, E203 diff --git a/psycopg_pool/LICENSE.txt b/psycopg_pool/LICENSE.txt new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/psycopg_pool/LICENSE.txt @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. 
+ + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. 
+ + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/psycopg_pool/README.rst b/psycopg_pool/README.rst new file mode 100644 index 0000000..6e6b32c --- /dev/null +++ b/psycopg_pool/README.rst @@ -0,0 +1,24 @@ +Psycopg 3: PostgreSQL database adapter for Python - Connection Pool +=================================================================== + +This distribution contains the optional connection pool package +`psycopg_pool`__. + +.. __: https://www.psycopg.org/psycopg3/docs/advanced/pool.html + +This package is kept separate from the main ``psycopg`` package because it is +likely that it will follow a different release cycle. + +You can also install this package using:: + + pip install "psycopg[pool]" + +Please read `the project readme`__ and `the installation documentation`__ for +more details. + +.. __: https://github.com/psycopg/psycopg#readme +.. 
__: https://www.psycopg.org/psycopg3/docs/basic/install.html + #installing-the-connection-pool + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg_pool/psycopg_pool/__init__.py b/psycopg_pool/psycopg_pool/__init__.py new file mode 100644 index 0000000..e4d975f --- /dev/null +++ b/psycopg_pool/psycopg_pool/__init__.py @@ -0,0 +1,22 @@ +""" +psycopg connection pool package +""" + +# Copyright (C) 2021 The Psycopg Team + +from .pool import ConnectionPool +from .pool_async import AsyncConnectionPool +from .null_pool import NullConnectionPool +from .null_pool_async import AsyncNullConnectionPool +from .errors import PoolClosed, PoolTimeout, TooManyRequests +from .version import __version__ as __version__ # noqa: F401 + +__all__ = [ + "AsyncConnectionPool", + "AsyncNullConnectionPool", + "ConnectionPool", + "NullConnectionPool", + "PoolClosed", + "PoolTimeout", + "TooManyRequests", +] diff --git a/psycopg_pool/psycopg_pool/_compat.py b/psycopg_pool/psycopg_pool/_compat.py new file mode 100644 index 0000000..9fb2b9b --- /dev/null +++ b/psycopg_pool/psycopg_pool/_compat.py @@ -0,0 +1,51 @@ +""" +compatibility functions for different Python versions +""" + +# Copyright (C) 2021 The Psycopg Team + +import sys +import asyncio +from typing import Any, Awaitable, Generator, Optional, Union, Type, TypeVar +from typing_extensions import TypeAlias + +import psycopg.errors as e + +T = TypeVar("T") +FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]] + +if sys.version_info >= (3, 8): + create_task = asyncio.create_task + Task = asyncio.Task + +else: + + def create_task( + coro: FutureT[T], name: Optional[str] = None + ) -> "asyncio.Future[T]": + return asyncio.create_task(coro) + + Task = asyncio.Future + +if sys.version_info >= (3, 9): + from collections import Counter, deque as Deque +else: + from typing import Counter, Deque + +__all__ = [ + "Counter", + "Deque", + "Task", + "create_task", +] + +# Workaround for psycopg < 3.0.8. +# Timeout on NullPool connection mignt not work correctly. +try: + ConnectionTimeout: Type[e.OperationalError] = e.ConnectionTimeout +except AttributeError: + + class DummyConnectionTimeout(e.OperationalError): + pass + + ConnectionTimeout = DummyConnectionTimeout diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py new file mode 100644 index 0000000..298ea68 --- /dev/null +++ b/psycopg_pool/psycopg_pool/base.py @@ -0,0 +1,230 @@ +""" +psycopg connection pool base class and functionalities. 
+""" + +# Copyright (C) 2021 The Psycopg Team + +from time import monotonic +from random import random +from typing import Any, Callable, Dict, Generic, Optional, Tuple + +from psycopg import errors as e +from psycopg.abc import ConnectionType + +from .errors import PoolClosed +from ._compat import Counter, Deque + + +class BasePool(Generic[ConnectionType]): + + # Used to generate pool names + _num_pool = 0 + + # Stats keys + _POOL_MIN = "pool_min" + _POOL_MAX = "pool_max" + _POOL_SIZE = "pool_size" + _POOL_AVAILABLE = "pool_available" + _REQUESTS_WAITING = "requests_waiting" + _REQUESTS_NUM = "requests_num" + _REQUESTS_QUEUED = "requests_queued" + _REQUESTS_WAIT_MS = "requests_wait_ms" + _REQUESTS_ERRORS = "requests_errors" + _USAGE_MS = "usage_ms" + _RETURNS_BAD = "returns_bad" + _CONNECTIONS_NUM = "connections_num" + _CONNECTIONS_MS = "connections_ms" + _CONNECTIONS_ERRORS = "connections_errors" + _CONNECTIONS_LOST = "connections_lost" + + def __init__( + self, + conninfo: str = "", + *, + kwargs: Optional[Dict[str, Any]] = None, + min_size: int = 4, + max_size: Optional[int] = None, + open: bool = True, + name: Optional[str] = None, + timeout: float = 30.0, + max_waiting: int = 0, + max_lifetime: float = 60 * 60.0, + max_idle: float = 10 * 60.0, + reconnect_timeout: float = 5 * 60.0, + reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]] = None, + num_workers: int = 3, + ): + min_size, max_size = self._check_size(min_size, max_size) + + if not name: + num = BasePool._num_pool = BasePool._num_pool + 1 + name = f"pool-{num}" + + if num_workers < 1: + raise ValueError("num_workers must be at least 1") + + self.conninfo = conninfo + self.kwargs: Dict[str, Any] = kwargs or {} + self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None] + self._reconnect_failed = reconnect_failed or (lambda pool: None) + self.name = name + self._min_size = min_size + self._max_size = max_size + self.timeout = timeout + self.max_waiting = max_waiting + self.reconnect_timeout = reconnect_timeout + self.max_lifetime = max_lifetime + self.max_idle = max_idle + self.num_workers = num_workers + + self._nconns = min_size # currently in the pool, out, being prepared + self._pool = Deque[ConnectionType]() + self._stats = Counter[str]() + + # Min number of connections in the pool in a max_idle unit of time. + # It is reset periodically by the ShrinkPool scheduled task. + # It is used to shrink back the pool if maxcon > min_size and extra + # connections have been acquired, if we notice that in the last + # max_idle interval they weren't all used. + self._nconns_min = min_size + + # Flag to allow the pool to grow only one connection at time. In case + # of spike, if threads are allowed to grow in parallel and connection + # time is slow, there won't be any thread available to return the + # connections to the pool. 
+        self._growing = False
+
+        self._opened = False
+        self._closed = True
+
+    def __repr__(self) -> str:
+        return (
+            f"<{self.__class__.__module__}.{self.__class__.__name__}"
+            f" {self.name!r} at 0x{id(self):x}>"
+        )
+
+    @property
+    def min_size(self) -> int:
+        return self._min_size
+
+    @property
+    def max_size(self) -> int:
+        return self._max_size
+
+    @property
+    def closed(self) -> bool:
+        """`!True` if the pool is closed."""
+        return self._closed
+
+    def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
+        if max_size is None:
+            max_size = min_size
+
+        if min_size < 0:
+            raise ValueError("min_size cannot be negative")
+        if max_size < min_size:
+            raise ValueError("max_size must be greater than or equal to min_size")
+        if min_size == max_size == 0:
+            raise ValueError("if min_size is 0, max_size must be greater than 0")
+
+        return min_size, max_size
+
+    def _check_open(self) -> None:
+        if self._closed and self._opened:
+            raise e.OperationalError(
+                "pool has already been opened/closed and cannot be reused"
+            )
+
+    def _check_open_getconn(self) -> None:
+        if self._closed:
+            if self._opened:
+                raise PoolClosed(f"the pool {self.name!r} is already closed")
+            else:
+                raise PoolClosed(f"the pool {self.name!r} is not open yet")
+
+    def _check_pool_putconn(self, conn: ConnectionType) -> None:
+        pool = getattr(conn, "_pool", None)
+        if pool is self:
+            return
+
+        if pool:
+            msg = f"it comes from pool {pool.name!r}"
+        else:
+            msg = "it doesn't come from any pool"
+        raise ValueError(
+            f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+        )
+
+    def get_stats(self) -> Dict[str, int]:
+        """
+        Return current stats about the pool usage.
+        """
+        rv = dict(self._stats)
+        rv.update(self._get_measures())
+        return rv
+
+    def pop_stats(self) -> Dict[str, int]:
+        """
+        Return current stats about the pool usage.
+
+        After the call, all the counters are reset to zero.
+        """
+        stats, self._stats = self._stats, Counter()
+        rv = dict(stats)
+        rv.update(self._get_measures())
+        return rv
+
+    def _get_measures(self) -> Dict[str, int]:
+        """
+        Return immediate measures of the pool (not counters).
+        """
+        return {
+            self._POOL_MIN: self._min_size,
+            self._POOL_MAX: self._max_size,
+            self._POOL_SIZE: self._nconns,
+            self._POOL_AVAILABLE: len(self._pool),
+        }
+
+    @classmethod
+    def _jitter(cls, value: float, min_pc: float, max_pc: float) -> float:
+        """
+        Add a random value to *value* between *min_pc* and *max_pc* percent.
+        """
+        return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc)
+
+    def _set_connection_expiry_date(self, conn: ConnectionType) -> None:
+        """Set an expiry date on a connection.
+
+        Add some randomness to avoid mass reconnection.
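+
+        For example, with the default ``max_lifetime`` of 3600 seconds the
+        jitter below is between -5% and 0%, so the expiry date lands roughly
+        between 3420 and 3600 seconds from now.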
+ """ + conn._expire_at = monotonic() + self._jitter(self.max_lifetime, -0.05, 0.0) + + +class ConnectionAttempt: + """Keep the state of a connection attempt.""" + + INITIAL_DELAY = 1.0 + DELAY_JITTER = 0.1 + DELAY_BACKOFF = 2.0 + + def __init__(self, *, reconnect_timeout: float): + self.reconnect_timeout = reconnect_timeout + self.delay = 0.0 + self.give_up_at = 0.0 + + def update_delay(self, now: float) -> None: + """Calculate how long to wait for a new connection attempt""" + if self.delay == 0.0: + self.give_up_at = now + self.reconnect_timeout + self.delay = BasePool._jitter( + self.INITIAL_DELAY, -self.DELAY_JITTER, self.DELAY_JITTER + ) + else: + self.delay *= self.DELAY_BACKOFF + + if self.delay + now > self.give_up_at: + self.delay = max(0.0, self.give_up_at - now) + + def time_to_give_up(self, now: float) -> bool: + """Return True if we are tired of trying to connect. Meh.""" + return self.give_up_at > 0.0 and now >= self.give_up_at diff --git a/psycopg_pool/psycopg_pool/errors.py b/psycopg_pool/psycopg_pool/errors.py new file mode 100644 index 0000000..9e672ad --- /dev/null +++ b/psycopg_pool/psycopg_pool/errors.py @@ -0,0 +1,25 @@ +""" +Connection pool errors. +""" + +# Copyright (C) 2021 The Psycopg Team + +from psycopg import errors as e + + +class PoolClosed(e.OperationalError): + """Attempt to get a connection from a closed pool.""" + + __module__ = "psycopg_pool" + + +class PoolTimeout(e.OperationalError): + """The pool couldn't provide a connection in acceptable time.""" + + __module__ = "psycopg_pool" + + +class TooManyRequests(e.OperationalError): + """Too many requests in the queue waiting for a connection from the pool.""" + + __module__ = "psycopg_pool" diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py new file mode 100644 index 0000000..c0a77c2 --- /dev/null +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -0,0 +1,159 @@ +""" +Psycopg null connection pools +""" + +# Copyright (C) 2022 The Psycopg Team + +import logging +import threading +from typing import Any, Optional, Tuple + +from psycopg import Connection +from psycopg.pq import TransactionStatus + +from .pool import ConnectionPool, AddConnection +from .errors import PoolTimeout, TooManyRequests +from ._compat import ConnectionTimeout + +logger = logging.getLogger("psycopg.pool") + + +class _BaseNullConnectionPool: + def __init__( + self, conninfo: str = "", min_size: int = 0, *args: Any, **kwargs: Any + ): + super().__init__( # type: ignore[call-arg] + conninfo, *args, min_size=min_size, **kwargs + ) + + def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]: + if max_size is None: + max_size = min_size + + if min_size != 0: + raise ValueError("null pools must have min_size = 0") + if max_size < min_size: + raise ValueError("max_size must be greater or equal than min_size") + + return min_size, max_size + + def _start_initial_tasks(self) -> None: + # Null pools don't have background tasks to fill connections + # or to grow/shrink. + return + + def _maybe_grow_pool(self) -> None: + # null pools don't grow + pass + + +class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): + def wait(self, timeout: float = 30.0) -> None: + """ + Create a connection for test. + + Calling this function will verify that the connectivity with the + database works as expected. However the connection will not be stored + in the pool. + + Close the pool, and raise `PoolTimeout`, if not ready within *timeout* + sec. 
+ """ + self._check_open_getconn() + + with self._lock: + assert not self._pool_full_event + self._pool_full_event = threading.Event() + + logger.info("waiting for pool %r initialization", self.name) + self.run_task(AddConnection(self)) + if not self._pool_full_event.wait(timeout): + self.close() # stop all the threads + raise PoolTimeout(f"pool initialization incomplete after {timeout} sec") + + with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + def _get_ready_connection( + self, timeout: Optional[float] + ) -> Optional[Connection[Any]]: + conn: Optional[Connection[Any]] = None + if self.max_size == 0 or self._nconns < self.max_size: + # Create a new connection for the client + try: + conn = self._connect(timeout=timeout) + except ConnectionTimeout as ex: + raise PoolTimeout(str(ex)) from None + self._nconns += 1 + + elif self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has already" + f" {len(self._waiting)} requests waiting" + ) + return conn + + def _maybe_close_connection(self, conn: Connection[Any]) -> bool: + with self._lock: + if not self._closed and self._waiting: + return False + + conn._pool = None + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + conn.close() + self._nconns -= 1 + return True + + def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + """Change the size of the pool during runtime. + + Only *max_size* can be changed; *min_size* must remain 0. + """ + min_size, max_size = self._check_size(min_size, max_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + with self._lock: + self._min_size = min_size + self._max_size = max_size + + def check(self) -> None: + """No-op, as the pool doesn't have connections in its state.""" + pass + + def _add_to_pool(self, conn: Connection[Any]) -> None: + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if pos.set(conn): + break + else: + # No client waiting for a connection: close the connection + conn.close() + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is ready. + if self._pool_full_event: + self._pool_full_event.set() + else: + # The connection created by wait shouldn't decrease the + # count of the number of connection used. 
+ self._nconns -= 1 diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py new file mode 100644 index 0000000..ae9d207 --- /dev/null +++ b/psycopg_pool/psycopg_pool/null_pool_async.py @@ -0,0 +1,122 @@ +""" +psycopg asynchronous null connection pool +""" + +# Copyright (C) 2022 The Psycopg Team + +import asyncio +import logging +from typing import Any, Optional + +from psycopg import AsyncConnection +from psycopg.pq import TransactionStatus + +from .errors import PoolTimeout, TooManyRequests +from ._compat import ConnectionTimeout +from .null_pool import _BaseNullConnectionPool +from .pool_async import AsyncConnectionPool, AddConnection + +logger = logging.getLogger("psycopg.pool") + + +class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): + async def wait(self, timeout: float = 30.0) -> None: + self._check_open_getconn() + + async with self._lock: + assert not self._pool_full_event + self._pool_full_event = asyncio.Event() + + logger.info("waiting for pool %r initialization", self.name) + self.run_task(AddConnection(self)) + try: + await asyncio.wait_for(self._pool_full_event.wait(), timeout) + except asyncio.TimeoutError: + await self.close() # stop all the tasks + raise PoolTimeout( + f"pool initialization incomplete after {timeout} sec" + ) from None + + async with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + async def _get_ready_connection( + self, timeout: Optional[float] + ) -> Optional[AsyncConnection[Any]]: + conn: Optional[AsyncConnection[Any]] = None + if self.max_size == 0 or self._nconns < self.max_size: + # Create a new connection for the client + try: + conn = await self._connect(timeout=timeout) + except ConnectionTimeout as ex: + raise PoolTimeout(str(ex)) from None + self._nconns += 1 + elif self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has already" + f" {len(self._waiting)} requests waiting" + ) + return conn + + async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool: + # Close the connection if no client is waiting for it, or if the pool + # is closed. For extra refcare remove the pool reference from it. + # Maintain the stats. + async with self._lock: + if not self._closed and self._waiting: + return False + + conn._pool = None + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + await conn.close() + self._nconns -= 1 + return True + + async def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + min_size, max_size = self._check_size(min_size, max_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + async with self._lock: + self._min_size = min_size + self._max_size = max_size + + async def check(self) -> None: + pass + + async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None: + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. 
+ async with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if await pos.set(conn): + break + else: + # No client waiting for a connection: close the connection + await conn.close() + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is ready. + if self._pool_full_event: + self._pool_full_event.set() + else: + # The connection created by wait shouldn't decrease the + # count of the number of connection used. + self._nconns -= 1 diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py new file mode 100644 index 0000000..609d95d --- /dev/null +++ b/psycopg_pool/psycopg_pool/pool.py @@ -0,0 +1,839 @@ +""" +psycopg synchronous connection pool +""" + +# Copyright (C) 2021 The Psycopg Team + +import logging +import threading +from abc import ABC, abstractmethod +from time import monotonic +from queue import Queue, Empty +from types import TracebackType +from typing import Any, Callable, Dict, Iterator, List +from typing import Optional, Sequence, Type +from weakref import ref +from contextlib import contextmanager + +from psycopg import errors as e +from psycopg import Connection +from psycopg.pq import TransactionStatus + +from .base import ConnectionAttempt, BasePool +from .sched import Scheduler +from .errors import PoolClosed, PoolTimeout, TooManyRequests +from ._compat import Deque + +logger = logging.getLogger("psycopg.pool") + + +class ConnectionPool(BasePool[Connection[Any]]): + def __init__( + self, + conninfo: str = "", + *, + open: bool = True, + connection_class: Type[Connection[Any]] = Connection, + configure: Optional[Callable[[Connection[Any]], None]] = None, + reset: Optional[Callable[[Connection[Any]], None]] = None, + **kwargs: Any, + ): + self.connection_class = connection_class + self._configure = configure + self._reset = reset + + self._lock = threading.RLock() + self._waiting = Deque["WaitingClient"]() + + # to notify that the pool is full + self._pool_full_event: Optional[threading.Event] = None + + self._sched = Scheduler() + self._sched_runner: Optional[threading.Thread] = None + self._tasks: "Queue[MaintenanceTask]" = Queue() + self._workers: List[threading.Thread] = [] + + super().__init__(conninfo, **kwargs) + + if open: + self.open() + + def __del__(self) -> None: + # If the '_closed' property is not set we probably failed in __init__. + # Don't try anything complicated as probably it won't work. + if getattr(self, "_closed", True): + return + + self._stop_workers() + + def wait(self, timeout: float = 30.0) -> None: + """ + Wait for the pool to be full (with `min_size` connections) after creation. + + Close the pool, and raise `PoolTimeout`, if not ready within *timeout* + sec. + + Calling this method is not mandatory: you can try and use the pool + immediately after its creation. The first client will be served as soon + as a connection is ready. You can use this method if you prefer your + program to terminate in case the environment is not configured + properly, rather than trying to stay up the hardest it can. 
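+
+        For example (the connection string is only illustrative)::
+
+            pool = ConnectionPool("dbname=test user=postgres", min_size=4)
+            pool.wait(timeout=30.0)  # fail fast if the database is not ready
+            with pool.connection() as conn:
+                conn.execute("SELECT 1")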
+ """ + self._check_open_getconn() + + with self._lock: + assert not self._pool_full_event + if len(self._pool) >= self._min_size: + return + self._pool_full_event = threading.Event() + + logger.info("waiting for pool %r initialization", self.name) + if not self._pool_full_event.wait(timeout): + self.close() # stop all the threads + raise PoolTimeout(f"pool initialization incomplete after {timeout} sec") + + with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + @contextmanager + def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]: + """Context manager to obtain a connection from the pool. + + Return the connection immediately if available, otherwise wait up to + *timeout* or `self.timeout` seconds and throw `PoolTimeout` if a + connection is not available in time. + + Upon context exit, return the connection to the pool. Apply the normal + :ref:`connection context behaviour <with-connection>` (commit/rollback + the transaction in case of success/error). If the connection is no more + in working state, replace it with a new one. + """ + conn = self.getconn(timeout=timeout) + t0 = monotonic() + try: + with conn: + yield conn + finally: + t1 = monotonic() + self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0)) + self.putconn(conn) + + def getconn(self, timeout: Optional[float] = None) -> Connection[Any]: + """Obtain a connection from the pool. + + You should preferably use `connection()`. Use this function only if + it is not possible to use the connection as context manager. + + After using this function you *must* call a corresponding `putconn()`: + failing to do so will deplete the pool. A depleted pool is a sad pool: + you don't want a depleted pool. + """ + logger.info("connection requested from %r", self.name) + self._stats[self._REQUESTS_NUM] += 1 + + # Critical section: decide here if there's a connection ready + # or if the client needs to wait. + with self._lock: + self._check_open_getconn() + conn = self._get_ready_connection(timeout) + if not conn: + # No connection available: put the client in the waiting queue + t0 = monotonic() + pos = WaitingClient() + self._waiting.append(pos) + self._stats[self._REQUESTS_QUEUED] += 1 + + # If there is space for the pool to grow, let's do it + self._maybe_grow_pool() + + # If we are in the waiting queue, wait to be assigned a connection + # (outside the critical section, so only the waiting client is locked) + if not conn: + if timeout is None: + timeout = self.timeout + try: + conn = pos.wait(timeout=timeout) + except Exception: + self._stats[self._REQUESTS_ERRORS] += 1 + raise + finally: + t1 = monotonic() + self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0)) + + # Tell the connection it belongs to a pool to avoid closing on __exit__ + # Note that this property shouldn't be set while the connection is in + # the pool, to avoid to create a reference loop. 
+        conn._pool = self
+        logger.info("connection given by %r", self.name)
+        return conn
+
+    def _get_ready_connection(
+        self, timeout: Optional[float]
+    ) -> Optional[Connection[Any]]:
+        """Return a connection, if the client deserves one."""
+        conn: Optional[Connection[Any]] = None
+        if self._pool:
+            # Take a connection ready out of the pool
+            conn = self._pool.popleft()
+            if len(self._pool) < self._nconns_min:
+                self._nconns_min = len(self._pool)
+        elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+            self._stats[self._REQUESTS_ERRORS] += 1
+            raise TooManyRequests(
+                f"the pool {self.name!r} has already"
+                f" {len(self._waiting)} requests waiting"
+            )
+        return conn
+
+    def _maybe_grow_pool(self) -> None:
+        # Allow only one thread at a time to grow the pool (or returning
+        # connections might be starved).
+        if self._nconns >= self._max_size or self._growing:
+            return
+        self._nconns += 1
+        logger.info("growing pool %r to %s", self.name, self._nconns)
+        self._growing = True
+        self.run_task(AddConnection(self, growing=True))
+
+    def putconn(self, conn: Connection[Any]) -> None:
+        """Return a connection to the loving hands of its pool.
+
+        Use this function only paired with a `getconn()`. You don't need to use
+        it if you use the much more comfortable `connection()` context manager.
+        """
+        # Quick check to discard the wrong connection
+        self._check_pool_putconn(conn)
+
+        logger.info("returning connection to %r", self.name)
+
+        if self._maybe_close_connection(conn):
+            return
+
+        # Use a worker to perform any needed maintenance work in a separate thread
+        if self._reset:
+            self.run_task(ReturnConnection(self, conn))
+        else:
+            self._return_connection(conn)
+
+    def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+        """Close a returned connection if necessary.
+
+        Return `!True` if the connection was closed.
+        """
+        # If the pool is closed just close the connection instead of returning
+        # it to the pool. For extra safety, also remove the pool reference from it.
+        if not self._closed:
+            return False
+
+        conn._pool = None
+        conn.close()
+        return True
+
+    def open(self, wait: bool = False, timeout: float = 30.0) -> None:
+        """Open the pool by starting to connect and accepting clients.
+
+        If *wait* is `!False`, return immediately and let the background worker
+        fill the pool if `min_size` > 0. Otherwise wait up to *timeout* seconds
+        for the requested number of connections to be ready (see `wait()` for
+        details).
+
+        It is safe to call `!open()` again on a pool already open (because the
+        method was already called, or because the pool context was entered, or
+        because the pool was initialized with *open* = `!True`) but you cannot
+        currently re-open a closed pool.
+        """
+        with self._lock:
+            self._open()
+
+        if wait:
+            self.wait(timeout=timeout)
+
+    def _open(self) -> None:
+        if not self._closed:
+            return
+
+        self._check_open()
+
+        self._closed = False
+        self._opened = True
+
+        self._start_workers()
+        self._start_initial_tasks()
+
+    def _start_workers(self) -> None:
+        self._sched_runner = threading.Thread(
+            target=self._sched.run,
+            name=f"{self.name}-scheduler",
+            daemon=True,
+        )
+        assert not self._workers
+        for i in range(self.num_workers):
+            t = threading.Thread(
+                target=self.worker,
+                args=(self._tasks,),
+                name=f"{self.name}-worker-{i}",
+                daemon=True,
+            )
+            self._workers.append(t)
+
+        # The object state is complete.
Start the worker threads + self._sched_runner.start() + for t in self._workers: + t.start() + + def _start_initial_tasks(self) -> None: + # populate the pool with initial min_size connections in background + for i in range(self._nconns): + self.run_task(AddConnection(self)) + + # Schedule a task to shrink the pool if connections over min_size have + # remained unused. + self.schedule_task(ShrinkPool(self), self.max_idle) + + def close(self, timeout: float = 5.0) -> None: + """Close the pool and make it unavailable to new clients. + + All the waiting and future clients will fail to acquire a connection + with a `PoolClosed` exception. Currently used connections will not be + closed until returned to the pool. + + Wait *timeout* seconds for threads to terminate their job, if positive. + If the timeout expires the pool is closed anyway, although it may raise + some warnings on exit. + """ + if self._closed: + return + + with self._lock: + self._closed = True + logger.debug("pool %r closed", self.name) + + # Take waiting client and pool connections out of the state + waiting = list(self._waiting) + self._waiting.clear() + connections = list(self._pool) + self._pool.clear() + + # Now that the flag _closed is set, getconn will fail immediately, + # putconn will just close the returned connection. + self._stop_workers(waiting, connections, timeout) + + def _stop_workers( + self, + waiting_clients: Sequence["WaitingClient"] = (), + connections: Sequence[Connection[Any]] = (), + timeout: float = 0.0, + ) -> None: + + # Stop the scheduler + self._sched.enter(0, None) + + # Stop the worker threads + workers, self._workers = self._workers[:], [] + for i in range(len(workers)): + self.run_task(StopWorker(self)) + + # Signal to eventual clients in the queue that business is closed. + for pos in waiting_clients: + pos.fail(PoolClosed(f"the pool {self.name!r} is closed")) + + # Close the connections still in the pool + for conn in connections: + conn.close() + + # Wait for the worker threads to terminate + assert self._sched_runner is not None + sched_runner, self._sched_runner = self._sched_runner, None + if timeout > 0: + for t in [sched_runner] + workers: + if not t.is_alive(): + continue + t.join(timeout) + if t.is_alive(): + logger.warning( + "couldn't stop thread %s in pool %r within %s seconds", + t, + self.name, + timeout, + ) + + def __enter__(self) -> "ConnectionPool": + self.open() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + """Change the size of the pool during runtime.""" + min_size, max_size = self._check_size(min_size, max_size) + + ngrow = max(0, min_size - self._min_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + with self._lock: + self._min_size = min_size + self._max_size = max_size + self._nconns += ngrow + + for i in range(ngrow): + self.run_task(AddConnection(self)) + + def check(self) -> None: + """Verify the state of the connections currently in the pool. + + Test each connection: if it works return it to the pool, otherwise + dispose of it and create a new one. + """ + with self._lock: + conns = list(self._pool) + self._pool.clear() + + # Give a chance to the pool to grow if it has no connection. + # In case there are enough connection, or the pool is already + # growing, this is a no-op. 
+ self._maybe_grow_pool() + + while conns: + conn = conns.pop() + try: + conn.execute("SELECT 1") + if conn.pgconn.transaction_status == TransactionStatus.INTRANS: + conn.rollback() + except Exception: + self._stats[self._CONNECTIONS_LOST] += 1 + logger.warning("discarding broken connection: %s", conn) + self.run_task(AddConnection(self)) + else: + self._add_to_pool(conn) + + def reconnect_failed(self) -> None: + """ + Called when reconnection failed for longer than `reconnect_timeout`. + """ + self._reconnect_failed(self) + + def run_task(self, task: "MaintenanceTask") -> None: + """Run a maintenance task in a worker thread.""" + self._tasks.put_nowait(task) + + def schedule_task(self, task: "MaintenanceTask", delay: float) -> None: + """Run a maintenance task in a worker thread in the future.""" + self._sched.enter(delay, task.tick) + + _WORKER_TIMEOUT = 60.0 + + @classmethod + def worker(cls, q: "Queue[MaintenanceTask]") -> None: + """Runner to execute pending maintenance task. + + The function is designed to run as a separate thread. + + Block on the queue *q*, run a task received. Finish running if a + StopWorker is received. + """ + # Don't make all the workers time out at the same moment + timeout = cls._jitter(cls._WORKER_TIMEOUT, -0.1, 0.1) + while True: + # Use a timeout to make the wait interruptible + try: + task = q.get(timeout=timeout) + except Empty: + continue + + if isinstance(task, StopWorker): + logger.debug( + "terminating working thread %s", + threading.current_thread().name, + ) + return + + # Run the task. Make sure don't die in the attempt. + try: + task.run() + except Exception as ex: + logger.warning( + "task run %s failed: %s: %s", + task, + ex.__class__.__name__, + ex, + ) + + def _connect(self, timeout: Optional[float] = None) -> Connection[Any]: + """Return a new connection configured for the pool.""" + self._stats[self._CONNECTIONS_NUM] += 1 + kwargs = self.kwargs + if timeout: + kwargs = kwargs.copy() + kwargs["connect_timeout"] = max(round(timeout), 1) + t0 = monotonic() + try: + conn: Connection[Any] + conn = self.connection_class.connect(self.conninfo, **kwargs) + except Exception: + self._stats[self._CONNECTIONS_ERRORS] += 1 + raise + else: + t1 = monotonic() + self._stats[self._CONNECTIONS_MS] += int(1000.0 * (t1 - t0)) + + conn._pool = self + + if self._configure: + self._configure(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + sname = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {sname} by configure function" + f" {self._configure}: discarded" + ) + + # Set an expiry date, with some randomness to avoid mass reconnection + self._set_connection_expiry_date(conn) + return conn + + def _add_connection( + self, attempt: Optional[ConnectionAttempt], growing: bool = False + ) -> None: + """Try to connect and add the connection to the pool. + + If failed, reschedule a new attempt in the future for a few times, then + give up, decrease the pool connections number and call + `self.reconnect_failed()`. 
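+
+        With the `ConnectionAttempt` defaults, the retry delay starts at about
+        one second (with a ±10% jitter) and doubles after every failure (1, 2,
+        4, 8... seconds) until `reconnect_timeout` (300 seconds by default) is
+        exhausted.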
+ + """ + now = monotonic() + if not attempt: + attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout) + + try: + conn = self._connect() + except Exception as ex: + logger.warning(f"error connecting in {self.name!r}: {ex}") + if attempt.time_to_give_up(now): + logger.warning( + "reconnection attempt in pool %r failed after %s sec", + self.name, + self.reconnect_timeout, + ) + with self._lock: + self._nconns -= 1 + # If we have given up with a growing attempt, allow a new one. + if growing and self._growing: + self._growing = False + self.reconnect_failed() + else: + attempt.update_delay(now) + self.schedule_task( + AddConnection(self, attempt, growing=growing), + attempt.delay, + ) + return + + logger.info("adding new connection to the pool") + self._add_to_pool(conn) + if growing: + with self._lock: + # Keep on growing if the pool is not full yet, or if there are + # clients waiting and the pool can extend. + if self._nconns < self._min_size or ( + self._nconns < self._max_size and self._waiting + ): + self._nconns += 1 + logger.info("growing pool %r to %s", self.name, self._nconns) + self.run_task(AddConnection(self, growing=True)) + else: + self._growing = False + + def _return_connection(self, conn: Connection[Any]) -> None: + """ + Return a connection to the pool after usage. + """ + self._reset_connection(conn) + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + # Connection no more in working state: create a new one. + self.run_task(AddConnection(self)) + logger.warning("discarding closed connection: %s", conn) + return + + # Check if the connection is past its best before date + if conn._expire_at <= monotonic(): + self.run_task(AddConnection(self)) + logger.info("discarding expired connection") + conn.close() + return + + self._add_to_pool(conn) + + def _add_to_pool(self, conn: Connection[Any]) -> None: + """ + Add a connection to the pool. + + The connection can be a fresh one or one already used in the pool. + + If a client is already waiting for a connection pass it on, otherwise + put it back into the pool + """ + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if pos.set(conn): + break + else: + # No client waiting for a connection: put it back into the pool + self._pool.append(conn) + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is full. + if self._pool_full_event and len(self._pool) >= self._min_size: + self._pool_full_event.set() + + def _reset_connection(self, conn: Connection[Any]) -> None: + """ + Bring a connection to IDLE state or close it. + """ + status = conn.pgconn.transaction_status + if status == TransactionStatus.IDLE: + pass + + elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR): + # Connection returned with an active transaction + logger.warning("rolling back returned connection: %s", conn) + try: + conn.rollback() + except Exception as ex: + logger.warning( + "rollback failed: %s: %s. 
Discarding connection %s", + ex.__class__.__name__, + ex, + conn, + ) + conn.close() + + elif status == TransactionStatus.ACTIVE: + # Connection returned during an operation. Bad... just close it. + logger.warning("closing returned connection: %s", conn) + conn.close() + + if not conn.closed and self._reset: + try: + self._reset(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + sname = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {sname} by reset function" + f" {self._reset}: discarded" + ) + except Exception as ex: + logger.warning(f"error resetting connection: {ex}") + conn.close() + + def _shrink_pool(self) -> None: + to_close: Optional[Connection[Any]] = None + + with self._lock: + # Reset the min number of connections used + nconns_min = self._nconns_min + self._nconns_min = len(self._pool) + + # If the pool can shrink and connections were unused, drop one + if self._nconns > self._min_size and nconns_min > 0: + to_close = self._pool.popleft() + self._nconns -= 1 + self._nconns_min -= 1 + + if to_close: + logger.info( + "shrinking pool %r to %s because %s unused connections" + " in the last %s sec", + self.name, + self._nconns, + nconns_min, + self.max_idle, + ) + to_close.close() + + def _get_measures(self) -> Dict[str, int]: + rv = super()._get_measures() + rv[self._REQUESTS_WAITING] = len(self._waiting) + return rv + + +class WaitingClient: + """A position in a queue for a client waiting for a connection.""" + + __slots__ = ("conn", "error", "_cond") + + def __init__(self) -> None: + self.conn: Optional[Connection[Any]] = None + self.error: Optional[Exception] = None + + # The WaitingClient behaves in a way similar to an Event, but we need + # to notify reliably the flagger that the waiter has "accepted" the + # message and it hasn't timed out yet, otherwise the pool may give a + # connection to a client that has already timed out getconn(), which + # will be lost. + self._cond = threading.Condition() + + def wait(self, timeout: float) -> Connection[Any]: + """Wait for a connection to be set and return it. + + Raise an exception if the wait times out or if fail() is called. + """ + with self._cond: + if not (self.conn or self.error): + if not self._cond.wait(timeout): + self.error = PoolTimeout( + f"couldn't get a connection after {timeout} sec" + ) + + if self.conn: + return self.conn + else: + assert self.error + raise self.error + + def set(self, conn: Connection[Any]) -> bool: + """Signal the client waiting that a connection is ready. + + Return True if the client has "accepted" the connection, False + otherwise (typically because wait() has timed out). + """ + with self._cond: + if self.conn or self.error: + return False + + self.conn = conn + self._cond.notify_all() + return True + + def fail(self, error: Exception) -> bool: + """Signal the client that, alas, they won't have a connection today. + + Return True if the client has "accepted" the error, False otherwise + (typically because wait() has timed out). 
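+
+        Only the first of `set()` and `fail()` takes effect: a later call finds
+        `conn` or `error` already set and returns False.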
+ """ + with self._cond: + if self.conn or self.error: + return False + + self.error = error + self._cond.notify_all() + return True + + +class MaintenanceTask(ABC): + """A task to run asynchronously to maintain the pool state.""" + + def __init__(self, pool: "ConnectionPool"): + self.pool = ref(pool) + + def __repr__(self) -> str: + pool = self.pool() + name = repr(pool.name) if pool else "<pool is gone>" + return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>" + + def run(self) -> None: + """Run the task. + + This usually happens in a worker thread. Call the concrete _run() + implementation, if the pool is still alive. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + logger.debug("task run discarded: %s", self) + return + + logger.debug("task running in %s: %s", threading.current_thread().name, self) + self._run(pool) + + def tick(self) -> None: + """Run the scheduled task + + This function is called by the scheduler thread. Use a worker to + run the task for real in order to free the scheduler immediately. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + logger.debug("task tick discarded: %s", self) + return + + pool.run_task(self) + + @abstractmethod + def _run(self, pool: "ConnectionPool") -> None: + ... + + +class StopWorker(MaintenanceTask): + """Signal the maintenance thread to terminate.""" + + def _run(self, pool: "ConnectionPool") -> None: + pass + + +class AddConnection(MaintenanceTask): + def __init__( + self, + pool: "ConnectionPool", + attempt: Optional["ConnectionAttempt"] = None, + growing: bool = False, + ): + super().__init__(pool) + self.attempt = attempt + self.growing = growing + + def _run(self, pool: "ConnectionPool") -> None: + pool._add_connection(self.attempt, growing=self.growing) + + +class ReturnConnection(MaintenanceTask): + """Clean up and return a connection to the pool.""" + + def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"): + super().__init__(pool) + self.conn = conn + + def _run(self, pool: "ConnectionPool") -> None: + pool._return_connection(self.conn) + + +class ShrinkPool(MaintenanceTask): + """If the pool can shrink, remove one connection. + + Re-schedule periodically and also reset the minimum number of connections + in the pool. + """ + + def _run(self, pool: "ConnectionPool") -> None: + # Reschedule the task now so that in case of any error we don't lose + # the periodic run. 
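+        # The check is then repeated every `pool.max_idle` seconds.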
+ pool.schedule_task(self, pool.max_idle) + pool._shrink_pool() diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py new file mode 100644 index 0000000..0ea6e9a --- /dev/null +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -0,0 +1,784 @@ +""" +psycopg asynchronous connection pool +""" + +# Copyright (C) 2021 The Psycopg Team + +import asyncio +import logging +from abc import ABC, abstractmethod +from time import monotonic +from types import TracebackType +from typing import Any, AsyncIterator, Awaitable, Callable +from typing import Dict, List, Optional, Sequence, Type +from weakref import ref +from contextlib import asynccontextmanager + +from psycopg import errors as e +from psycopg import AsyncConnection +from psycopg.pq import TransactionStatus + +from .base import ConnectionAttempt, BasePool +from .sched import AsyncScheduler +from .errors import PoolClosed, PoolTimeout, TooManyRequests +from ._compat import Task, create_task, Deque + +logger = logging.getLogger("psycopg.pool") + + +class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): + def __init__( + self, + conninfo: str = "", + *, + open: bool = True, + connection_class: Type[AsyncConnection[Any]] = AsyncConnection, + configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None, + reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None, + **kwargs: Any, + ): + self.connection_class = connection_class + self._configure = configure + self._reset = reset + + # asyncio objects, created on open to attach them to the right loop. + self._lock: asyncio.Lock + self._sched: AsyncScheduler + self._tasks: "asyncio.Queue[MaintenanceTask]" + + self._waiting = Deque["AsyncClient"]() + + # to notify that the pool is full + self._pool_full_event: Optional[asyncio.Event] = None + + self._sched_runner: Optional[Task[None]] = None + self._workers: List[Task[None]] = [] + + super().__init__(conninfo, **kwargs) + + if open: + self._open() + + async def wait(self, timeout: float = 30.0) -> None: + self._check_open_getconn() + + async with self._lock: + assert not self._pool_full_event + if len(self._pool) >= self._min_size: + return + self._pool_full_event = asyncio.Event() + + logger.info("waiting for pool %r initialization", self.name) + try: + await asyncio.wait_for(self._pool_full_event.wait(), timeout) + except asyncio.TimeoutError: + await self.close() # stop all the tasks + raise PoolTimeout( + f"pool initialization incomplete after {timeout} sec" + ) from None + + async with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + @asynccontextmanager + async def connection( + self, timeout: Optional[float] = None + ) -> AsyncIterator[AsyncConnection[Any]]: + conn = await self.getconn(timeout=timeout) + t0 = monotonic() + try: + async with conn: + yield conn + finally: + t1 = monotonic() + self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0)) + await self.putconn(conn) + + async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]: + logger.info("connection requested from %r", self.name) + self._stats[self._REQUESTS_NUM] += 1 + + self._check_open_getconn() + + # Critical section: decide here if there's a connection ready + # or if the client needs to wait. 
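+        # Note that _get_ready_connection() may also raise TooManyRequests if
+        # the waiting queue has already reached max_waiting clients.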
+ async with self._lock: + conn = await self._get_ready_connection(timeout) + if not conn: + # No connection available: put the client in the waiting queue + t0 = monotonic() + pos = AsyncClient() + self._waiting.append(pos) + self._stats[self._REQUESTS_QUEUED] += 1 + + # If there is space for the pool to grow, let's do it + self._maybe_grow_pool() + + # If we are in the waiting queue, wait to be assigned a connection + # (outside the critical section, so only the waiting client is locked) + if not conn: + if timeout is None: + timeout = self.timeout + try: + conn = await pos.wait(timeout=timeout) + except Exception: + self._stats[self._REQUESTS_ERRORS] += 1 + raise + finally: + t1 = monotonic() + self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0)) + + # Tell the connection it belongs to a pool to avoid closing on __exit__ + # Note that this property shouldn't be set while the connection is in + # the pool, to avoid to create a reference loop. + conn._pool = self + logger.info("connection given by %r", self.name) + return conn + + async def _get_ready_connection( + self, timeout: Optional[float] + ) -> Optional[AsyncConnection[Any]]: + conn: Optional[AsyncConnection[Any]] = None + if self._pool: + # Take a connection ready out of the pool + conn = self._pool.popleft() + if len(self._pool) < self._nconns_min: + self._nconns_min = len(self._pool) + elif self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has already" + f" {len(self._waiting)} requests waiting" + ) + return conn + + def _maybe_grow_pool(self) -> None: + # Allow only one task at time to grow the pool (or returning + # connections might be starved). + if self._nconns < self._max_size and not self._growing: + self._nconns += 1 + logger.info("growing pool %r to %s", self.name, self._nconns) + self._growing = True + self.run_task(AddConnection(self, growing=True)) + + async def putconn(self, conn: AsyncConnection[Any]) -> None: + self._check_pool_putconn(conn) + + logger.info("returning connection to %r", self.name) + if await self._maybe_close_connection(conn): + return + + # Use a worker to perform eventual maintenance work in a separate task + if self._reset: + self.run_task(ReturnConnection(self, conn)) + else: + await self._return_connection(conn) + + async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool: + # If the pool is closed just close the connection instead of returning + # it to the pool. For extra refcare remove the pool reference from it. + if not self._closed: + return False + + conn._pool = None + await conn.close() + return True + + async def open(self, wait: bool = False, timeout: float = 30.0) -> None: + # Make sure the lock is created after there is an event loop + try: + self._lock + except AttributeError: + self._lock = asyncio.Lock() + + async with self._lock: + self._open() + + if wait: + await self.wait(timeout=timeout) + + def _open(self) -> None: + if not self._closed: + return + + # Throw a RuntimeError if the pool is open outside a running loop. + asyncio.get_running_loop() + + self._check_open() + + # Create these objects now to attach them to the right loop. + # See #219 + self._tasks = asyncio.Queue() + self._sched = AsyncScheduler() + # This has been most likely, but not necessarily, created in `open()`. 
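+        # If the pool was created with open=True, _open() runs from __init__
+        # and open() never had the chance to create the lock, so create it now.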
+ try: + self._lock + except AttributeError: + self._lock = asyncio.Lock() + + self._closed = False + self._opened = True + + self._start_workers() + self._start_initial_tasks() + + def _start_workers(self) -> None: + self._sched_runner = create_task( + self._sched.run(), name=f"{self.name}-scheduler" + ) + for i in range(self.num_workers): + t = create_task( + self.worker(self._tasks), + name=f"{self.name}-worker-{i}", + ) + self._workers.append(t) + + def _start_initial_tasks(self) -> None: + # populate the pool with initial min_size connections in background + for i in range(self._nconns): + self.run_task(AddConnection(self)) + + # Schedule a task to shrink the pool if connections over min_size have + # remained unused. + self.run_task(Schedule(self, ShrinkPool(self), self.max_idle)) + + async def close(self, timeout: float = 5.0) -> None: + if self._closed: + return + + async with self._lock: + self._closed = True + logger.debug("pool %r closed", self.name) + + # Take waiting client and pool connections out of the state + waiting = list(self._waiting) + self._waiting.clear() + connections = list(self._pool) + self._pool.clear() + + # Now that the flag _closed is set, getconn will fail immediately, + # putconn will just close the returned connection. + await self._stop_workers(waiting, connections, timeout) + + async def _stop_workers( + self, + waiting_clients: Sequence["AsyncClient"] = (), + connections: Sequence[AsyncConnection[Any]] = (), + timeout: float = 0.0, + ) -> None: + # Stop the scheduler + await self._sched.enter(0, None) + + # Stop the worker tasks + workers, self._workers = self._workers[:], [] + for w in workers: + self.run_task(StopWorker(self)) + + # Signal to eventual clients in the queue that business is closed. + for pos in waiting_clients: + await pos.fail(PoolClosed(f"the pool {self.name!r} is closed")) + + # Close the connections still in the pool + for conn in connections: + await conn.close() + + # Wait for the worker tasks to terminate + assert self._sched_runner is not None + sched_runner, self._sched_runner = self._sched_runner, None + wait = asyncio.gather(sched_runner, *workers) + try: + if timeout > 0: + await asyncio.wait_for(asyncio.shield(wait), timeout=timeout) + else: + await wait + except asyncio.TimeoutError: + logger.warning( + "couldn't stop pool %r tasks within %s seconds", + self.name, + timeout, + ) + + async def __aenter__(self) -> "AsyncConnectionPool": + await self.open() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + async def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + min_size, max_size = self._check_size(min_size, max_size) + + ngrow = max(0, min_size - self._min_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + async with self._lock: + self._min_size = min_size + self._max_size = max_size + self._nconns += ngrow + + for i in range(ngrow): + self.run_task(AddConnection(self)) + + async def check(self) -> None: + async with self._lock: + conns = list(self._pool) + self._pool.clear() + + # Give a chance to the pool to grow if it has no connection. + # In case there are enough connection, or the pool is already + # growing, this is a no-op. 
+ self._maybe_grow_pool() + + while conns: + conn = conns.pop() + try: + await conn.execute("SELECT 1") + if conn.pgconn.transaction_status == TransactionStatus.INTRANS: + await conn.rollback() + except Exception: + self._stats[self._CONNECTIONS_LOST] += 1 + logger.warning("discarding broken connection: %s", conn) + self.run_task(AddConnection(self)) + else: + await self._add_to_pool(conn) + + def reconnect_failed(self) -> None: + """ + Called when reconnection failed for longer than `reconnect_timeout`. + """ + self._reconnect_failed(self) + + def run_task(self, task: "MaintenanceTask") -> None: + """Run a maintenance task in a worker.""" + self._tasks.put_nowait(task) + + async def schedule_task(self, task: "MaintenanceTask", delay: float) -> None: + """Run a maintenance task in a worker in the future.""" + await self._sched.enter(delay, task.tick) + + @classmethod + async def worker(cls, q: "asyncio.Queue[MaintenanceTask]") -> None: + """Runner to execute pending maintenance task. + + The function is designed to run as a task. + + Block on the queue *q*, run a task received. Finish running if a + StopWorker is received. + """ + while True: + task = await q.get() + + if isinstance(task, StopWorker): + logger.debug("terminating working task") + return + + # Run the task. Make sure don't die in the attempt. + try: + await task.run() + except Exception as ex: + logger.warning( + "task run %s failed: %s: %s", + task, + ex.__class__.__name__, + ex, + ) + + async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]: + self._stats[self._CONNECTIONS_NUM] += 1 + kwargs = self.kwargs + if timeout: + kwargs = kwargs.copy() + kwargs["connect_timeout"] = max(round(timeout), 1) + t0 = monotonic() + try: + conn: AsyncConnection[Any] + conn = await self.connection_class.connect(self.conninfo, **kwargs) + except Exception: + self._stats[self._CONNECTIONS_ERRORS] += 1 + raise + else: + t1 = monotonic() + self._stats[self._CONNECTIONS_MS] += int(1000.0 * (t1 - t0)) + + conn._pool = self + + if self._configure: + await self._configure(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + sname = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {sname} by configure function" + f" {self._configure}: discarded" + ) + + # Set an expiry date, with some randomness to avoid mass reconnection + self._set_connection_expiry_date(conn) + return conn + + async def _add_connection( + self, attempt: Optional[ConnectionAttempt], growing: bool = False + ) -> None: + """Try to connect and add the connection to the pool. + + If failed, reschedule a new attempt in the future for a few times, then + give up, decrease the pool connections number and call + `self.reconnect_failed()`. + + """ + now = monotonic() + if not attempt: + attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout) + + try: + conn = await self._connect() + except Exception as ex: + logger.warning(f"error connecting in {self.name!r}: {ex}") + if attempt.time_to_give_up(now): + logger.warning( + "reconnection attempt in pool %r failed after %s sec", + self.name, + self.reconnect_timeout, + ) + async with self._lock: + self._nconns -= 1 + # If we have given up with a growing attempt, allow a new one. 
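+                    # (reset under the lock, so that _maybe_grow_pool() can
+                    # start a new growing cycle)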
+ if growing and self._growing: + self._growing = False + self.reconnect_failed() + else: + attempt.update_delay(now) + await self.schedule_task( + AddConnection(self, attempt, growing=growing), + attempt.delay, + ) + return + + logger.info("adding new connection to the pool") + await self._add_to_pool(conn) + if growing: + async with self._lock: + # Keep on growing if the pool is not full yet, or if there are + # clients waiting and the pool can extend. + if self._nconns < self._min_size or ( + self._nconns < self._max_size and self._waiting + ): + self._nconns += 1 + logger.info("growing pool %r to %s", self.name, self._nconns) + self.run_task(AddConnection(self, growing=True)) + else: + self._growing = False + + async def _return_connection(self, conn: AsyncConnection[Any]) -> None: + """ + Return a connection to the pool after usage. + """ + await self._reset_connection(conn) + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + # Connection no more in working state: create a new one. + self.run_task(AddConnection(self)) + logger.warning("discarding closed connection: %s", conn) + return + + # Check if the connection is past its best before date + if conn._expire_at <= monotonic(): + self.run_task(AddConnection(self)) + logger.info("discarding expired connection") + await conn.close() + return + + await self._add_to_pool(conn) + + async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None: + """ + Add a connection to the pool. + + The connection can be a fresh one or one already used in the pool. + + If a client is already waiting for a connection pass it on, otherwise + put it back into the pool + """ + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + async with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if await pos.set(conn): + break + else: + # No client waiting for a connection: put it back into the pool + self._pool.append(conn) + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is full. + if self._pool_full_event and len(self._pool) >= self._min_size: + self._pool_full_event.set() + + async def _reset_connection(self, conn: AsyncConnection[Any]) -> None: + """ + Bring a connection to IDLE state or close it. + """ + status = conn.pgconn.transaction_status + if status == TransactionStatus.IDLE: + pass + + elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR): + # Connection returned with an active transaction + logger.warning("rolling back returned connection: %s", conn) + try: + await conn.rollback() + except Exception as ex: + logger.warning( + "rollback failed: %s: %s. Discarding connection %s", + ex.__class__.__name__, + ex, + conn, + ) + await conn.close() + + elif status == TransactionStatus.ACTIVE: + # Connection returned during an operation. Bad... just close it. 
+ logger.warning("closing returned connection: %s", conn) + await conn.close() + + if not conn.closed and self._reset: + try: + await self._reset(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + sname = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {sname} by reset function" + f" {self._reset}: discarded" + ) + except Exception as ex: + logger.warning(f"error resetting connection: {ex}") + await conn.close() + + async def _shrink_pool(self) -> None: + to_close: Optional[AsyncConnection[Any]] = None + + async with self._lock: + # Reset the min number of connections used + nconns_min = self._nconns_min + self._nconns_min = len(self._pool) + + # If the pool can shrink and connections were unused, drop one + if self._nconns > self._min_size and nconns_min > 0: + to_close = self._pool.popleft() + self._nconns -= 1 + self._nconns_min -= 1 + + if to_close: + logger.info( + "shrinking pool %r to %s because %s unused connections" + " in the last %s sec", + self.name, + self._nconns, + nconns_min, + self.max_idle, + ) + await to_close.close() + + def _get_measures(self) -> Dict[str, int]: + rv = super()._get_measures() + rv[self._REQUESTS_WAITING] = len(self._waiting) + return rv + + +class AsyncClient: + """A position in a queue for a client waiting for a connection.""" + + __slots__ = ("conn", "error", "_cond") + + def __init__(self) -> None: + self.conn: Optional[AsyncConnection[Any]] = None + self.error: Optional[Exception] = None + + # The AsyncClient behaves in a way similar to an Event, but we need + # to notify reliably the flagger that the waiter has "accepted" the + # message and it hasn't timed out yet, otherwise the pool may give a + # connection to a client that has already timed out getconn(), which + # will be lost. + self._cond = asyncio.Condition() + + async def wait(self, timeout: float) -> AsyncConnection[Any]: + """Wait for a connection to be set and return it. + + Raise an exception if the wait times out or if fail() is called. + """ + async with self._cond: + if not (self.conn or self.error): + try: + await asyncio.wait_for(self._cond.wait(), timeout) + except asyncio.TimeoutError: + self.error = PoolTimeout( + f"couldn't get a connection after {timeout} sec" + ) + + if self.conn: + return self.conn + else: + assert self.error + raise self.error + + async def set(self, conn: AsyncConnection[Any]) -> bool: + """Signal the client waiting that a connection is ready. + + Return True if the client has "accepted" the connection, False + otherwise (typically because wait() has timed out). + """ + async with self._cond: + if self.conn or self.error: + return False + + self.conn = conn + self._cond.notify_all() + return True + + async def fail(self, error: Exception) -> bool: + """Signal the client that, alas, they won't have a connection today. + + Return True if the client has "accepted" the error, False otherwise + (typically because wait() has timed out). + """ + async with self._cond: + if self.conn or self.error: + return False + + self.error = error + self._cond.notify_all() + return True + + +class MaintenanceTask(ABC): + """A task to run asynchronously to maintain the pool state.""" + + def __init__(self, pool: "AsyncConnectionPool"): + self.pool = ref(pool) + + def __repr__(self) -> str: + pool = self.pool() + name = repr(pool.name) if pool else "<pool is gone>" + return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>" + + async def run(self) -> None: + """Run the task. 
+ + This usually happens in a worker. Call the concrete _run() + implementation, if the pool is still alive. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + logger.debug("task run discarded: %s", self) + return + + await self._run(pool) + + async def tick(self) -> None: + """Run the scheduled task + + This function is called by the scheduler task. Use a worker to + run the task for real in order to free the scheduler immediately. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + logger.debug("task tick discarded: %s", self) + return + + pool.run_task(self) + + @abstractmethod + async def _run(self, pool: "AsyncConnectionPool") -> None: + ... + + +class StopWorker(MaintenanceTask): + """Signal the maintenance worker to terminate.""" + + async def _run(self, pool: "AsyncConnectionPool") -> None: + pass + + +class AddConnection(MaintenanceTask): + def __init__( + self, + pool: "AsyncConnectionPool", + attempt: Optional["ConnectionAttempt"] = None, + growing: bool = False, + ): + super().__init__(pool) + self.attempt = attempt + self.growing = growing + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool._add_connection(self.attempt, growing=self.growing) + + +class ReturnConnection(MaintenanceTask): + """Clean up and return a connection to the pool.""" + + def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"): + super().__init__(pool) + self.conn = conn + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool._return_connection(self.conn) + + +class ShrinkPool(MaintenanceTask): + """If the pool can shrink, remove one connection. + + Re-schedule periodically and also reset the minimum number of connections + in the pool. + """ + + async def _run(self, pool: "AsyncConnectionPool") -> None: + # Reschedule the task now so that in case of any error we don't lose + # the periodic run. + await pool.schedule_task(self, pool.max_idle) + await pool._shrink_pool() + + +class Schedule(MaintenanceTask): + """Schedule a task in the pool scheduler. + + This task is a trampoline to allow to use a sync call (pool.run_task) + to execute an async one (pool.schedule_task). + """ + + def __init__( + self, + pool: "AsyncConnectionPool", + task: MaintenanceTask, + delay: float, + ): + super().__init__(pool) + self.task = task + self.delay = delay + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool.schedule_task(self.task, self.delay) diff --git a/psycopg_pool/psycopg_pool/py.typed b/psycopg_pool/psycopg_pool/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/psycopg_pool/psycopg_pool/py.typed diff --git a/psycopg_pool/psycopg_pool/sched.py b/psycopg_pool/psycopg_pool/sched.py new file mode 100644 index 0000000..ca26007 --- /dev/null +++ b/psycopg_pool/psycopg_pool/sched.py @@ -0,0 +1,177 @@ +""" +A minimal scheduler to schedule tasks run in the future. + +Inspired to the standard library `sched.scheduler`, but designed for +multi-thread usage ground up, not as an afterthought. Tasks can be scheduled in +front of the one currently running and `Scheduler.run()` can be left running +without any task scheduled. + +Tasks are called "Task", not "Event", here, because we actually make use of +`threading.Event` and the two would be confusing. 
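+
+A minimal usage sketch (the names below are only illustrative)::
+
+    sched = Scheduler()
+    runner = threading.Thread(target=sched.run, daemon=True)
+    runner.start()
+
+    sched.enter(10.0, some_callable)  # run some_callable in about 10 seconds
+
+    # later, to stop the runner:
+    sched.enter(0, None)
+    runner.join()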
+""" + +# Copyright (C) 2021 The Psycopg Team + +import asyncio +import logging +import threading +from time import monotonic +from heapq import heappush, heappop +from typing import Any, Callable, List, Optional, NamedTuple + +logger = logging.getLogger(__name__) + + +class Task(NamedTuple): + time: float + action: Optional[Callable[[], Any]] + + def __eq__(self, other: "Task") -> Any: # type: ignore[override] + return self.time == other.time + + def __lt__(self, other: "Task") -> Any: # type: ignore[override] + return self.time < other.time + + def __le__(self, other: "Task") -> Any: # type: ignore[override] + return self.time <= other.time + + def __gt__(self, other: "Task") -> Any: # type: ignore[override] + return self.time > other.time + + def __ge__(self, other: "Task") -> Any: # type: ignore[override] + return self.time >= other.time + + +class Scheduler: + def __init__(self) -> None: + """Initialize a new instance, passing the time and delay functions.""" + self._queue: List[Task] = [] + self._lock = threading.RLock() + self._event = threading.Event() + + EMPTY_QUEUE_TIMEOUT = 600.0 + + def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue delayed in the future. + + Schedule a `!None` to stop the execution. + """ + time = monotonic() + delay + return self.enterabs(time, action) + + def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue at an absolute time. + + Schedule a `!None` to stop the execution. + """ + task = Task(time, action) + with self._lock: + heappush(self._queue, task) + first = self._queue[0] is task + + if first: + self._event.set() + + return task + + def run(self) -> None: + """Execute the events scheduled.""" + q = self._queue + while True: + with self._lock: + now = monotonic() + task = q[0] if q else None + if task: + if task.time <= now: + heappop(q) + else: + delay = task.time - now + task = None + else: + delay = self.EMPTY_QUEUE_TIMEOUT + self._event.clear() + + if task: + if not task.action: + break + try: + task.action() + except Exception as e: + logger.warning( + "scheduled task run %s failed: %s: %s", + task.action, + e.__class__.__name__, + e, + ) + else: + # Block for the expected timeout or until a new task scheduled + self._event.wait(timeout=delay) + + +class AsyncScheduler: + def __init__(self) -> None: + """Initialize a new instance, passing the time and delay functions.""" + self._queue: List[Task] = [] + self._lock = asyncio.Lock() + self._event = asyncio.Event() + + EMPTY_QUEUE_TIMEOUT = 600.0 + + async def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue delayed in the future. + + Schedule a `!None` to stop the execution. + """ + time = monotonic() + delay + return await self.enterabs(time, action) + + async def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue at an absolute time. + + Schedule a `!None` to stop the execution. 
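+
+        If the new task becomes the first to expire, set the internal event so
+        that a loop blocked in `run()` wakes up and recomputes its timeout.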
+ """ + task = Task(time, action) + async with self._lock: + heappush(self._queue, task) + first = self._queue[0] is task + + if first: + self._event.set() + + return task + + async def run(self) -> None: + """Execute the events scheduled.""" + q = self._queue + while True: + async with self._lock: + now = monotonic() + task = q[0] if q else None + if task: + if task.time <= now: + heappop(q) + else: + delay = task.time - now + task = None + else: + delay = self.EMPTY_QUEUE_TIMEOUT + self._event.clear() + + if task: + if not task.action: + break + try: + await task.action() + except Exception as e: + logger.warning( + "scheduled task run %s failed: %s: %s", + task.action, + e.__class__.__name__, + e, + ) + else: + # Block for the expected timeout or until a new task scheduled + try: + await asyncio.wait_for(self._event.wait(), delay) + except asyncio.TimeoutError: + pass diff --git a/psycopg_pool/psycopg_pool/version.py b/psycopg_pool/psycopg_pool/version.py new file mode 100644 index 0000000..fc99bbd --- /dev/null +++ b/psycopg_pool/psycopg_pool/version.py @@ -0,0 +1,13 @@ +""" +psycopg pool version file. +""" + +# Copyright (C) 2021 The Psycopg Team + +# Use a versioning scheme as defined in +# https://www.python.org/dev/peps/pep-0440/ + +# STOP AND READ! if you change: +__version__ = "3.1.5" +# also change: +# - `docs/news_pool.rst` to declare this version current or unreleased diff --git a/psycopg_pool/setup.cfg b/psycopg_pool/setup.cfg new file mode 100644 index 0000000..1a3274e --- /dev/null +++ b/psycopg_pool/setup.cfg @@ -0,0 +1,45 @@ +[metadata] +name = psycopg-pool +description = Connection Pool for Psycopg +url = https://psycopg.org/psycopg3/ +author = Daniele Varrazzo +author_email = daniele.varrazzo@gmail.com +license = GNU Lesser General Public License v3 (LGPLv3) + +project_urls = + Homepage = https://psycopg.org/ + Code = https://github.com/psycopg/psycopg + Issue Tracker = https://github.com/psycopg/psycopg/issues + Download = https://pypi.org/project/psycopg-pool/ + +classifiers = + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) + Operating System :: MacOS :: MacOS X + Operating System :: Microsoft :: Windows + Operating System :: POSIX + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Topic :: Database + Topic :: Database :: Front-Ends + Topic :: Software Development + Topic :: Software Development :: Libraries :: Python Modules + +long_description = file: README.rst +long_description_content_type = text/x-rst +license_files = LICENSE.txt + +[options] +python_requires = >= 3.7 +packages = find: +zip_safe = False +install_requires = + typing-extensions >= 3.10 + +[options.package_data] +psycopg_pool = py.typed diff --git a/psycopg_pool/setup.py b/psycopg_pool/setup.py new file mode 100644 index 0000000..771847d --- /dev/null +++ b/psycopg_pool/setup.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +PostgreSQL database adapter for Python - Connection Pool +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +from setuptools import setup + +# Move to the directory of setup.py: executing this file from another location +# (e.g. 
from the project root) will fail +here = os.path.abspath(os.path.dirname(__file__)) +if os.path.abspath(os.getcwd()) != here: + os.chdir(here) + +with open("psycopg_pool/version.py") as f: + data = f.read() + m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data) + if not m: + raise Exception(f"cannot find version in {f.name}") + version = m.group(1) + + +setup(version=version) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..14f3c9e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,55 @@ +[build-system] +requires = ["setuptools>=49.2.0", "wheel>=0.37"] +build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +filterwarnings = [ + "error", +] +testpaths=[ + "tests", +] +# Note: On Travis they these options seem to leak objects +# log_format = "%(asctime)s.%(msecs)03d %(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s" +# log_level = "DEBUG" + +[tool.coverage.run] +source = [ + "psycopg/psycopg", + "psycopg_pool/psycopg_pool", +] +[tool.coverage.report] +exclude_lines = [ + "if TYPE_CHECKING:", + '\.\.\.$', +] + +[tool.mypy] +files = [ + "psycopg/psycopg", + "psycopg_pool/psycopg_pool", + "psycopg_c/psycopg_c", + "tests", +] +warn_unused_ignores = true +show_error_codes = true +disable_bytearray_promotion = true +disable_memoryview_promotion = true +strict = true + +[[tool.mypy.overrides]] +module = [ + "shapely.*", +] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "uvloop" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "tests.*" +check_untyped_defs = true +disallow_untyped_defs = false +disallow_untyped_calls = false diff --git a/tests/README.rst b/tests/README.rst new file mode 100644 index 0000000..63c7238 --- /dev/null +++ b/tests/README.rst @@ -0,0 +1,94 @@ +psycopg test suite +=================== + +Quick version +------------- + +To run tests on the current code you can install the `test` extra of the +package, specify a connection string in the `PSYCOPG_TEST_DSN` env var to +connect to a test database, and run ``pytest``:: + + $ pip install -e "psycopg[test]" + $ export PSYCOPG_TEST_DSN="host=localhost dbname=psycopg_test" + $ pytest + + +Test options +------------ + +- The tests output header shows additional psycopg related information, + on top of the one normally displayed by ``pytest`` and the extensions used:: + + $ pytest + ========================= test session starts ========================= + platform linux -- Python 3.8.5, pytest-6.0.2, py-1.10.0, pluggy-0.13.1 + Using --randomly-seed=2416596601 + libpq available: 130002 + libpq wrapper implementation: c + + +- By default the tests run using the ``pq`` implementation that psycopg would + choose (the C module if installed, else the Python one). In order to test a + different implementation, use the normal `pq module selection mechanism`__ + of the ``PSYCOPG_IMPL`` env var:: + + $ PSYCOPG_IMPL=python pytest + ========================= test session starts ========================= + [...] + libpq available: 130002 + libpq wrapper implementation: python + + .. __: https://www.psycopg.org/psycopg/docs/api/pq.html#pq-module-implementations + + +- Slow tests have a ``slow`` marker which can be selected to reduce test + runtime to a few seconds only. Please add a ``@pytest.mark.slow`` marker to + any test needing an arbitrary wait. At the time of writing:: + + $ pytest + ========================= test session starts ========================= + [...] 
+ ======= 1983 passed, 3 skipped, 110 xfailed in 78.21s (0:01:18) ======= + + $ pytest -m "not slow" + ========================= test session starts ========================= + [...] + ==== 1877 passed, 2 skipped, 169 deselected, 48 xfailed in 13.47s ===== + +- ``pytest`` option ``--pq-trace={TRACEFILE,STDERR}`` can be used to capture + libpq trace. When using ``stderr``, the output will only be shown for + failing or in-error tests, unless ``-s/--capture=no`` option is used. + +- ``pytest`` option ``--pq-debug`` can be used to log access to libpq's + ``PGconn`` functions. + + +Testing in docker +----------------- + +Useful to test different Python versions without installing them. Can be used +to replicate GitHub actions failures, specifying the ``--randomly-seed`` used +in the test run. The following ``PG*`` env vars are an example to adjust the +test dsn in order to connect to a database running on the docker host: specify +a set of env vars working for your setup:: + + $ docker run -ti --rm --volume `pwd`:/src --workdir /src \ + -e PSYCOPG_TEST_DSN -e PGHOST=172.17.0.1 -e PGUSER=`whoami` \ + python:3.7 bash + + # pip install -e "./psycopg[test]" ./psycopg_pool ./psycopg_c + # pytest + + +Testing with CockroachDB +======================== + +You can run CRDB in a docker container using:: + + docker run -p 26257:26257 --name crdb --rm \ + cockroachdb/cockroach:v22.1.3 start-single-node --insecure + +And use the following connection string to run the tests:: + + export PSYCOPG_TEST_DSN="host=localhost port=26257 user=root dbname=defaultdb" + pytest ... diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/adapters_example.py b/tests/adapters_example.py new file mode 100644 index 0000000..a184e6a --- /dev/null +++ b/tests/adapters_example.py @@ -0,0 +1,45 @@ +from typing import Optional + +from psycopg import pq +from psycopg.abc import Dumper, Loader, AdaptContext, PyFormat, Buffer + + +def f() -> None: + d: Dumper = MyStrDumper(str, None) + assert d.dump("abc") == b"abcabc" + assert d.quote("abc") == b"'abcabc'" + + lo: Loader = MyTextLoader(0, None) + assert lo.load(b"abc") == "abcabc" + + +class MyStrDumper: + format = pq.Format.TEXT + oid = 25 # text + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + self._cls = cls + + def dump(self, obj: str) -> bytes: + return (obj * 2).encode() + + def quote(self, obj: str) -> bytes: + value = self.dump(obj) + esc = pq.Escaping() + return b"'%s'" % esc.escape_string(value.replace(b"h", b"q")) + + def get_key(self, obj: str, format: PyFormat) -> type: + return self._cls + + def upgrade(self, obj: str, format: PyFormat) -> "MyStrDumper": + return self + + +class MyTextLoader: + format = pq.Format.TEXT + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + pass + + def load(self, data: Buffer) -> str: + return (bytes(data) * 2).decode() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..15bcf40 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,92 @@ +import sys +import asyncio +import selectors +from typing import List + +pytest_plugins = ( + "tests.fix_db", + "tests.fix_pq", + "tests.fix_mypy", + "tests.fix_faker", + "tests.fix_proxy", + "tests.fix_psycopg", + "tests.fix_crdb", + "tests.pool.fix_pool", +) + + +def pytest_configure(config): + markers = [ + "slow: this test is kinda slow (skip with -m 'not slow')", + "flakey(reason): this test may fail unpredictably')", + # 
There are troubles on travis with these kind of tests and I cannot + # catch the exception for my life. + "subprocess: the test import psycopg after subprocess", + "timing: the test is timing based and can fail on cheese hardware", + "dns: the test requires dnspython to run", + "postgis: the test requires the PostGIS extension to run", + ] + + for marker in markers: + config.addinivalue_line("markers", marker) + + +def pytest_addoption(parser): + parser.addoption( + "--loop", + choices=["default", "uvloop"], + default="default", + help="The asyncio loop to use for async tests.", + ) + + +def pytest_report_header(config): + rv = [] + + rv.append(f"default selector: {selectors.DefaultSelector.__name__}") + loop = config.getoption("--loop") + if loop != "default": + rv.append(f"asyncio loop: {loop}") + + return rv + + +def pytest_sessionstart(session): + # Detect if there was a segfault in the previous run. + # + # In case of segfault, pytest doesn't get a chance to write failed tests + # in the cache. As a consequence, retries would find no test failed and + # assume that all tests passed in the previous run, making the whole test pass. + cache = session.config.cache + if cache.get("segfault", False): + session.warn(Warning("Previous run resulted in segfault! Not running any test")) + session.warn(Warning("(delete '.pytest_cache/v/segfault' to clear this state)")) + raise session.Failed + cache.set("segfault", True) + + # Configure the async loop. + loop = session.config.getoption("--loop") + if loop == "uvloop": + import uvloop + + uvloop.install() + else: + assert loop == "default" + + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + +allow_fail_messages: List[str] = [] + + +def pytest_sessionfinish(session, exitstatus): + # Mark the test run successful (in the sense -weak- that we didn't segfault). + session.config.cache.set("segfault", False) + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + if allow_fail_messages: + terminalreporter.section("failed tests ignored") + for msg in allow_fail_messages: + terminalreporter.line(msg) diff --git a/tests/constraints.txt b/tests/constraints.txt new file mode 100644 index 0000000..ef03ba1 --- /dev/null +++ b/tests/constraints.txt @@ -0,0 +1,32 @@ +# This is a constraint file forcing the minimum allowed version to be +# installed. 
+# +# https://pip.pypa.io/en/stable/user_guide/#constraints-files + +# From install_requires +backports.zoneinfo == 0.2.0 +typing-extensions == 4.1.0 + +# From the 'test' extra +mypy == 0.981 +pproxy == 2.7.0 +pytest == 6.2.5 +pytest-asyncio == 0.17.0 +pytest-cov == 3.0.0 +pytest-randomly == 3.10.0 + +# From the 'dev' extra +black == 22.3.0 +dnspython == 2.1.0 +flake8 == 4.0.0 +mypy == 0.981 +types-setuptools == 57.4.0 +wheel == 0.37 + +# From the 'docs' extra +Sphinx == 4.2.0 +furo == 2021.11.23 +sphinx-autobuild == 2021.3.14 +sphinx-autodoc-typehints == 1.12.0 +dnspython == 2.1.0 +shapely == 1.7.0 diff --git a/tests/crdb/__init__.py b/tests/crdb/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/crdb/__init__.py diff --git a/tests/crdb/test_adapt.py b/tests/crdb/test_adapt.py new file mode 100644 index 0000000..ce5bacf --- /dev/null +++ b/tests/crdb/test_adapt.py @@ -0,0 +1,78 @@ +from copy import deepcopy + +import pytest + +from psycopg.crdb import adapters, CrdbConnection + +from psycopg.adapt import PyFormat, Transformer +from psycopg.types.array import ListDumper +from psycopg.postgres import types as builtins + +from ..test_adapt import MyStr, make_dumper, make_bin_dumper +from ..test_adapt import make_loader, make_bin_loader + +pytestmark = pytest.mark.crdb + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_return_untyped(conn, fmt_in): + # Analyze and check for changes using strings in untyped/typed contexts + cur = conn.cursor() + # Currently string are passed as text oid to CockroachDB, unlike Postgres, + # to which strings are passed as unknown. This is because CRDB doesn't + # allow the unknown oid to be emitted; execute("SELECT %s", ["str"]) raises + # an error. However, unlike PostgreSQL, text can be cast to any other type. 
+ cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10]) + assert cur.fetchone() == ("hello", 10) + + cur.execute("create table testjson(data jsonb)") + cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"]) + assert cur.execute("select data from testjson").fetchone() == ({},) + + +def test_str_list_dumper_text(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.TEXT) + assert isinstance(dstr, ListDumper) + assert dstr.oid == builtins["text"].array_oid + assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid + + +@pytest.fixture +def crdb_adapters(): + """Restore the crdb adapters after a test has changed them.""" + dumpers = deepcopy(adapters._dumpers) + dumpers_by_oid = deepcopy(adapters._dumpers_by_oid) + loaders = deepcopy(adapters._loaders) + types = list(adapters.types) + + yield None + + adapters._dumpers = dumpers + adapters._dumpers_by_oid = dumpers_by_oid + adapters._loaders = loaders + adapters.types.clear() + for t in types: + adapters.types.add(t) + + +def test_dump_global_ctx(dsn, crdb_adapters, pgconn): + adapters.register_dumper(MyStr, make_bin_dumper("gb")) + adapters.register_dumper(MyStr, make_dumper("gt")) + with CrdbConnection.connect(dsn) as conn: + cur = conn.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + cur = conn.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellogb",) + cur = conn.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + + +def test_load_global_ctx(dsn, crdb_adapters): + adapters.register_loader("text", make_loader("gt")) + adapters.register_loader("text", make_bin_loader("gb")) + with CrdbConnection.connect(dsn) as conn: + cur = conn.cursor(binary=False).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogt",) + cur = conn.cursor(binary=True).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogb",) diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py new file mode 100644 index 0000000..b2a69ef --- /dev/null +++ b/tests/crdb/test_connection.py @@ -0,0 +1,86 @@ +import time +import threading + +import psycopg.crdb +from psycopg import errors as e +from psycopg.crdb import CrdbConnection + +import pytest + +pytestmark = pytest.mark.crdb + + +def test_is_crdb(conn): + assert CrdbConnection.is_crdb(conn) + assert CrdbConnection.is_crdb(conn.pgconn) + + +def test_connect(dsn): + with CrdbConnection.connect(dsn) as conn: + assert isinstance(conn, CrdbConnection) + + with psycopg.crdb.connect(dsn) as conn: + assert isinstance(conn, CrdbConnection) + + +def test_xid(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.xid(1, "gtrid", "bqual") + + +def test_tpc_begin(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.tpc_begin("foo") + + +def test_tpc_recover(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.tpc_recover() + + +@pytest.mark.slow +def test_broken_connection(conn): + cur = conn.cursor() + (session_id,) = cur.execute("select session_id from [show session_id]").fetchone() + with pytest.raises(psycopg.DatabaseError): + cur.execute("cancel session %s", [session_id]) + assert conn.closed + + +@pytest.mark.slow +def test_broken(conn): + (session_id,) = conn.execute("show session_id").fetchone() + with pytest.raises(psycopg.OperationalError): + conn.execute("cancel session %s", [session_id]) + + assert 
conn.closed + assert conn.broken + conn.close() + assert conn.closed + assert conn.broken + + +@pytest.mark.slow +def test_identify_closure(conn_cls, dsn): + with conn_cls.connect(dsn, autocommit=True) as conn: + with conn_cls.connect(dsn, autocommit=True) as conn2: + (session_id,) = conn.execute("show session_id").fetchone() + + def closer(): + time.sleep(0.2) + conn2.execute("cancel session %s", [session_id]) + + t = threading.Thread(target=closer) + t.start() + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + conn.execute("select pg_sleep(3.0)") + dt = time.time() - t0 + # CRDB seems to take not less than 1s + assert 0.2 < dt < 2 + finally: + t.join() diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py new file mode 100644 index 0000000..b568e42 --- /dev/null +++ b/tests/crdb/test_connection_async.py @@ -0,0 +1,85 @@ +import time +import asyncio + +import psycopg.crdb +from psycopg import errors as e +from psycopg.crdb import AsyncCrdbConnection +from psycopg._compat import create_task + +import pytest + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +async def test_is_crdb(aconn): + assert AsyncCrdbConnection.is_crdb(aconn) + assert AsyncCrdbConnection.is_crdb(aconn.pgconn) + + +async def test_connect(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection) + + +async def test_xid(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.xid(1, "gtrid", "bqual") + + +async def test_tpc_begin(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + await conn.tpc_begin("foo") + + +async def test_tpc_recover(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + await conn.tpc_recover() + + +@pytest.mark.slow +async def test_broken_connection(aconn): + cur = aconn.cursor() + await cur.execute("select session_id from [show session_id]") + (session_id,) = await cur.fetchone() + with pytest.raises(psycopg.DatabaseError): + await cur.execute("cancel session %s", [session_id]) + assert aconn.closed + + +@pytest.mark.slow +async def test_broken(aconn): + cur = await aconn.execute("show session_id") + (session_id,) = await cur.fetchone() + with pytest.raises(psycopg.OperationalError): + await aconn.execute("cancel session %s", [session_id]) + + assert aconn.closed + assert aconn.broken + await aconn.close() + assert aconn.closed + assert aconn.broken + + +@pytest.mark.slow +async def test_identify_closure(aconn_cls, dsn): + async with await aconn_cls.connect(dsn) as conn: + async with await aconn_cls.connect(dsn) as conn2: + cur = await conn.execute("show session_id") + (session_id,) = await cur.fetchone() + + async def closer(): + await asyncio.sleep(0.2) + await conn2.execute("cancel session %s", [session_id]) + + t = create_task(closer()) + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + await conn.execute("select pg_sleep(3.0)") + dt = time.time() - t0 + assert 0.2 < dt < 2 + finally: + await asyncio.gather(t) diff --git a/tests/crdb/test_conninfo.py b/tests/crdb/test_conninfo.py new file mode 100644 index 0000000..274a0c0 --- /dev/null +++ b/tests/crdb/test_conninfo.py @@ -0,0 +1,21 @@ +import pytest + +pytestmark = pytest.mark.crdb + + +def test_vendor(conn): + assert conn.info.vendor == "CockroachDB" + + +def test_server_version(conn): + assert 
conn.info.server_version > 200000 + + +@pytest.mark.crdb("< 22") +def test_backend_pid_pre_22(conn): + assert conn.info.backend_pid == 0 + + +@pytest.mark.crdb(">= 22") +def test_backend_pid(conn): + assert conn.info.backend_pid > 0 diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py new file mode 100644 index 0000000..b7d26aa --- /dev/null +++ b/tests/crdb/test_copy.py @@ -0,0 +1,233 @@ +import pytest +import string +from random import randrange, choice + +from psycopg import sql, errors as e +from psycopg.pq import Format +from psycopg.adapt import PyFormat +from psycopg.types.numeric import Int4 + +from ..utils import eur, gc_collect, gc_count +from ..test_copy import sample_text, sample_binary # noqa +from ..test_copy import ensure_table, sample_records +from ..test_copy import sample_tabledef as sample_tabledef_pg + +# CRDB int/serial are int8 +sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4") + +pytestmark = pytest.mark.crdb + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +def test_copy_in_buffers(conn, format, buffer): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_copy_in_buffers_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + copy.write(sample_text) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_str(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text.decode()) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559") +def test_copy_in_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + with cur.copy("copy copy_in from stdin with binary") as copy: + copy.write(sample_text.decode()) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin {copyopt(format)}"): + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +def test_copy_big_size_record(conn): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data text") + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write_row([data]) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +@pytest.mark.slow +def test_copy_big_size_block(conn): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data text") + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write(copy_data) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +def 
test_copy_in_buffers_with_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + copy.write(sample_text) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_set_types(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_binary(conn, format): + cur = conn.cursor() + ensure_table(cur, "col1 serial primary key, col2 int4, data text") + + with cur.copy(f"copy copy_in (col2, data) from stdin {copyopt(format)}") as copy: + for row in sample_records: + copy.write_row((None, row[2])) + + data = cur.execute("select col2, data from copy_in order by 2").fetchall() + assert data == [(None, "hello"), (None, "world")] + + +@pytest.mark.crdb_skip("copy canceled") +def test_copy_in_buffers_with_py_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_allchars(conn): + cur = conn.cursor() + ensure_table(cur, "col1 int primary key, col2 int, data text") + + with cur.copy("copy copy_in from stdin") as copy: + for i in range(1, 256): + copy.write_row((i, None, chr(i))) + copy.write_row((ord(eur), None, eur)) + + data = cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ).fetchall() + assert data == [(True, True, 1, 256)] + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.crdb_skip("copy array") +def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + def work(): + with conn_cls.connect(dsn) as conn: + with conn.cursor(binary=fmt) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin {}").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL("with binary" if fmt else ""), + ) + with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + copy.write_row(row) + + cur.execute(faker.select_stmt) + recs = cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + 
gc_collect() + n = [] + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +def copyopt(format): + return "with binary" if format == Format.BINARY else "" diff --git a/tests/crdb/test_copy_async.py b/tests/crdb/test_copy_async.py new file mode 100644 index 0000000..d5fbf50 --- /dev/null +++ b/tests/crdb/test_copy_async.py @@ -0,0 +1,235 @@ +import pytest +import string +from random import randrange, choice + +from psycopg.pq import Format +from psycopg import sql, errors as e +from psycopg.adapt import PyFormat +from psycopg.types.numeric import Int4 + +from ..utils import eur, gc_collect, gc_count +from ..test_copy import sample_text, sample_binary # noqa +from ..test_copy import sample_records +from ..test_copy_async import ensure_table +from .test_copy import sample_tabledef, copyopt + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +async def test_copy_in_buffers(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_buffers_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_str(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text.decode()) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559") +async def test_copy_in_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + async with cur.copy("copy copy_in from stdin with binary") as copy: + await copy.write(sample_text.decode()) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"): + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +async def test_copy_big_size_record(aconn): + cur = aconn.cursor() + await ensure_table(cur, "id serial primary key, data text") + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write_row([data]) + + await cur.execute("select data from copy_in limit 1") + assert (await cur.fetchone())[0] == data + + +@pytest.mark.slow +async def test_copy_big_size_block(aconn): + cur = aconn.cursor() + await ensure_table(cur, "id serial primary key, data text") + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" + async with cur.copy("copy 
copy_in (data) from stdin") as copy: + await copy.write(copy_data) + + await cur.execute("select data from copy_in limit 1") + assert (await cur.fetchone())[0] == data + + +async def test_copy_in_buffers_with_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_set_types(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_binary(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, "col1 serial primary key, col2 int4, data text") + + async with cur.copy( + f"copy copy_in (col2, data) from stdin {copyopt(format)}" + ) as copy: + for row in sample_records: + await copy.write_row((None, row[2])) + + await cur.execute("select col2, data from copy_in order by 2") + data = await cur.fetchall() + assert data == [(None, "hello"), (None, "world")] + + +@pytest.mark.crdb_skip("copy canceled") +async def test_copy_in_buffers_with_py_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_allchars(aconn): + cur = aconn.cursor() + await ensure_table(cur, "col1 int primary key, col2 int, data text") + + async with cur.copy("copy copy_in from stdin") as copy: + for i in range(1, 256): + await copy.write_row((i, None, chr(i))) + await copy.write_row((ord(eur), None, eur)) + + await cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ) + data = await cur.fetchall() + assert data == [(True, True, 1, 256)] + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.crdb_skip("copy array") +async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await 
cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin {}").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL("with binary" if fmt else ""), + ) + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + await copy.write_row(row) + + await cur.execute(faker.select_stmt) + recs = await cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" diff --git a/tests/crdb/test_cursor.py b/tests/crdb/test_cursor.py new file mode 100644 index 0000000..991b084 --- /dev/null +++ b/tests/crdb/test_cursor.py @@ -0,0 +1,65 @@ +import json +import threading +from uuid import uuid4 +from queue import Queue +from typing import Any + +import pytest +from psycopg import pq, errors as e +from psycopg.rows import namedtuple_row + +pytestmark = pytest.mark.crdb + + +@pytest.fixture +def testfeed(svcconn): + name = f"test_feed_{str(uuid4()).replace('-', '')}" + svcconn.execute("set cluster setting kv.rangefeed.enabled to true") + svcconn.execute(f"create table {name} (id serial primary key, data text)") + yield name + svcconn.execute(f"drop table {name}") + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out): + conn.autocommit = True + q: "Queue[Any]" = Queue() + + def worker(): + try: + with conn_cls.connect(dsn, autocommit=True) as conn: + cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row) + try: + for row in cur.stream(f"experimental changefeed for {testfeed}"): + q.put(row) + except e.QueryCanceled: + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + q.put(None) + except Exception as ex: + q.put(ex) + + t = threading.Thread(target=worker) + t.start() + + cur = conn.cursor() + cur.execute(f"insert into {testfeed} (data) values ('hello') returning id") + (key,) = cur.fetchone() + row = q.get(timeout=1) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] == {"id": key, "data": "hello"} + + cur.execute(f"delete from {testfeed} where id = %s", [key]) + row = q.get(timeout=1) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] is None + + cur.execute("select query_id from [show statements] where query !~ 'show'") + (qid,) = cur.fetchone() + cur.execute("cancel query %s", [qid]) + assert cur.statusmessage == "CANCEL QUERIES 1" + + assert q.get(timeout=1) is None + t.join() diff --git a/tests/crdb/test_cursor_async.py b/tests/crdb/test_cursor_async.py new file mode 100644 index 0000000..229295d --- /dev/null +++ b/tests/crdb/test_cursor_async.py @@ -0,0 +1,61 @@ +import json +import asyncio +from typing import Any +from asyncio.queues import Queue + +import pytest +from psycopg import pq, errors as e +from psycopg.rows import namedtuple_row +from psycopg._compat import create_task + +from .test_cursor import testfeed + +testfeed # fixture + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", pq.Format) +async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out): + await aconn.set_autocommit(True) + q: "Queue[Any]" = Queue() + + async def worker(): + try: + async with await aconn_cls.connect(dsn, 
autocommit=True) as conn: + cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row) + try: + async for row in cur.stream( + f"experimental changefeed for {testfeed}" + ): + q.put_nowait(row) + except e.QueryCanceled: + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + q.put_nowait(None) + except Exception as ex: + q.put_nowait(ex) + + t = create_task(worker()) + + cur = aconn.cursor() + await cur.execute(f"insert into {testfeed} (data) values ('hello') returning id") + (key,) = await cur.fetchone() + row = await asyncio.wait_for(q.get(), 1.0) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] == {"id": key, "data": "hello"} + + await cur.execute(f"delete from {testfeed} where id = %s", [key]) + row = await asyncio.wait_for(q.get(), 1.0) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] is None + + await cur.execute("select query_id from [show statements] where query !~ 'show'") + (qid,) = await cur.fetchone() + await cur.execute("cancel query %s", [qid]) + assert cur.statusmessage == "CANCEL QUERIES 1" + + assert await asyncio.wait_for(q.get(), 1.0) is None + await asyncio.gather(t) diff --git a/tests/crdb/test_no_crdb.py b/tests/crdb/test_no_crdb.py new file mode 100644 index 0000000..df43f3b --- /dev/null +++ b/tests/crdb/test_no_crdb.py @@ -0,0 +1,34 @@ +from psycopg.pq import TransactionStatus +from psycopg.crdb import CrdbConnection + +import pytest + +pytestmark = pytest.mark.crdb("skip") + + +def test_is_crdb(conn): + assert not CrdbConnection.is_crdb(conn) + assert not CrdbConnection.is_crdb(conn.pgconn) + + +def test_tpc_on_pg_connection(conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + conn.tpc_commit() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 diff --git a/tests/crdb/test_typing.py b/tests/crdb/test_typing.py new file mode 100644 index 0000000..2cff0a7 --- /dev/null +++ b/tests/crdb/test_typing.py @@ -0,0 +1,49 @@ +import pytest + +from ..test_typing import _test_reveal + + +@pytest.mark.parametrize( + "conn, type", + [ + ( + "psycopg.crdb.connect()", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.connect(row_factory=rows.dict_row)", + "psycopg.crdb.CrdbConnection[Dict[str, Any]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect()", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)", + "psycopg.crdb.CrdbConnection[Dict[str, Any]]", + ), + ( + "await psycopg.crdb.AsyncCrdbConnection.connect()", + "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]", + ), + ( + "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)", + "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]", + ), + ], +) +def test_connection_type(conn, type, mypy): + stmts = 
f"obj = {conn}" + _test_reveal_crdb(stmts, type, mypy) + + +def _test_reveal_crdb(stmts, type, mypy): + stmts = f"""\ +import psycopg.crdb +{stmts} +""" + _test_reveal(stmts, type, mypy) diff --git a/tests/dbapi20.py b/tests/dbapi20.py new file mode 100644 index 0000000..c873a4e --- /dev/null +++ b/tests/dbapi20.py @@ -0,0 +1,870 @@ +#!/usr/bin/env python +# flake8: noqa +# fmt: off +''' Python DB API 2.0 driver compliance unit test suite. + + This software is Public Domain and may be used without restrictions. + + "Now we have booze and barflies entering the discussion, plus rumours of + DBAs on drugs... and I won't tell you what flashes through my mind each + time I read the subject line with 'Anal Compliance' in it. All around + this is turning out to be a thoroughly unwholesome unit test." + + -- Ian Bicking +''' + +__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $' +__version__ = '$Revision: 1.12 $'[11:-2] +__author__ = 'Stuart Bishop <stuart@stuartbishop.net>' + +import unittest +import time +import sys +from typing import Any, Dict + + +# Revision 1.12 2009/02/06 03:35:11 kf7xm +# Tested okay with Python 3.0, includes last minute patches from Mark H. +# +# Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole +# Include latest changes from main branch +# Updates for py3k +# +# Revision 1.11 2005/01/02 02:41:01 zenzen +# Update author email address +# +# Revision 1.10 2003/10/09 03:14:14 zenzen +# Add test for DB API 2.0 optional extension, where database exceptions +# are exposed as attributes on the Connection object. +# +# Revision 1.9 2003/08/13 01:16:36 zenzen +# Minor tweak from Stefan Fleiter +# +# Revision 1.8 2003/04/10 00:13:25 zenzen +# Changes, as per suggestions by M.-A. Lemburg +# - Add a table prefix, to ensure namespace collisions can always be avoided +# +# Revision 1.7 2003/02/26 23:33:37 zenzen +# Break out DDL into helper functions, as per request by David Rushby +# +# Revision 1.6 2003/02/21 03:04:33 zenzen +# Stuff from Henrik Ekelund: +# added test_None +# added test_nextset & hooks +# +# Revision 1.5 2003/02/17 22:08:43 zenzen +# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize +# defaults to 1 & generic cursor.callproc test added +# +# Revision 1.4 2003/02/15 00:16:33 zenzen +# Changes, as per suggestions and bug reports by M.-A. Lemburg, +# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar +# - Class renamed +# - Now a subclass of TestCase, to avoid requiring the driver stub +# to use multiple inheritance +# - Reversed the polarity of buggy test in test_description +# - Test exception hierarchy correctly +# - self.populate is now self._populate(), so if a driver stub +# overrides self.ddl1 this change propagates +# - VARCHAR columns now have a width, which will hopefully make the +# DDL even more portible (this will be reversed if it causes more problems) +# - cursor.rowcount being checked after various execute and fetchXXX methods +# - Check for fetchall and fetchmany returning empty lists after results +# are exhausted (already checking for empty lists if select retrieved +# nothing +# - Fix bugs in test_setoutputsize_basic and test_setinputsizes +# + +class DatabaseAPI20Test(unittest.TestCase): + ''' Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. 
It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. + + The 'Optional Extensions' are not yet being tested. + + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: + + from . import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] + + Don't 'from .dbapi20 import DatabaseAPI20Test', or you will + confuse the unit tester - just 'from . import dbapi20'. + ''' + + # The self.driver module. This should be the module where the 'connect' + # method is to be found + driver: Any = None + connect_args = () # List of arguments to pass to connect + connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect + table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + + ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix + ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix + xddl1 = 'drop table %sbooze' % table_prefix + xddl2 = 'drop table %sbarflys' % table_prefix + + lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + + # Some drivers may need to override these helpers, for example adding + # a 'commit' after the execute. + def executeDDL1(self,cursor): + cursor.execute(self.ddl1) + + def executeDDL2(self,cursor): + cursor.execute(self.ddl2) + + def setUp(self): + ''' self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + ''' + pass + + def tearDown(self): + ''' self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + ''' + con = self._connect() + try: + cur = con.cursor() + for ddl in (self.xddl1,self.xddl2): + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. + pass + finally: + con.close() + + def _connect(self): + try: + return self.driver.connect( + *self.connect_args,**self.connect_kw_args + ) + except AttributeError: + self.fail("No connect method found in self.driver module") + + def test_connect(self): + con = self._connect() + con.close() + + def test_apilevel(self): + try: + # Must exist + apilevel = self.driver.apilevel + # Must equal 2.0 + self.assertEqual(apilevel,'2.0') + except AttributeError: + self.fail("Driver doesn't define apilevel") + + def test_threadsafety(self): + try: + # Must exist + threadsafety = self.driver.threadsafety + # Must be a valid value + self.failUnless(threadsafety in (0,1,2,3)) + except AttributeError: + self.fail("Driver doesn't define threadsafety") + + def test_paramstyle(self): + try: + # Must exist + paramstyle = self.driver.paramstyle + # Must be a valid value + self.failUnless(paramstyle in ( + 'qmark','numeric','named','format','pyformat' + )) + except AttributeError: + self.fail("Driver doesn't define paramstyle") + + def test_Exceptions(self): + # Make sure required exceptions exist, and are in the + # defined hierarchy. 
+ if sys.version[0] == '3': #under Python 3 StardardError no longer exists + self.failUnless(issubclass(self.driver.Warning,Exception)) + self.failUnless(issubclass(self.driver.Error,Exception)) + else: + self.failUnless(issubclass(self.driver.Warning,StandardError)) # type: ignore[name-defined] + self.failUnless(issubclass(self.driver.Error,StandardError)) # type: ignore[name-defined] + + self.failUnless( + issubclass(self.driver.InterfaceError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.DatabaseError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.OperationalError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.IntegrityError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.InternalError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.ProgrammingError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.NotSupportedError,self.driver.Error) + ) + + def test_ExceptionsAsConnectionAttributes(self): + # OPTIONAL EXTENSION + # Test for the optional DB API 2.0 extension, where the exceptions + # are exposed as attributes on the Connection object + # I figure this optional extension will be implemented by any + # driver author who is using this test suite, so it is enabled + # by default. + con = self._connect() + drv = self.driver + self.failUnless(con.Warning is drv.Warning) + self.failUnless(con.Error is drv.Error) + self.failUnless(con.InterfaceError is drv.InterfaceError) + self.failUnless(con.DatabaseError is drv.DatabaseError) + self.failUnless(con.OperationalError is drv.OperationalError) + self.failUnless(con.IntegrityError is drv.IntegrityError) + self.failUnless(con.InternalError is drv.InternalError) + self.failUnless(con.ProgrammingError is drv.ProgrammingError) + self.failUnless(con.NotSupportedError is drv.NotSupportedError) + con.close() + + + def test_commit(self): + con = self._connect() + try: + # Commit must work, even if it doesn't do anything + con.commit() + finally: + con.close() + + def test_rollback(self): + con = self._connect() + # If rollback is defined, it should either work or throw + # the documented exception + if hasattr(con,'rollback'): + try: + con.rollback() + except self.driver.NotSupportedError: + pass + con.close() + + def test_cursor(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + def test_cursor_isolation(self): + con = self._connect() + try: + # Make sure cursors created from the same connection have + # the documented transaction isolation level + cur1 = con.cursor() + cur2 = con.cursor() + self.executeDDL1(cur1) + cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + cur2.execute("select name from %sbooze" % self.table_prefix) + booze = cur2.fetchall() + self.assertEqual(len(booze),1) + self.assertEqual(len(booze[0]),1) + self.assertEqual(booze[0][0],'Victoria Bitter') + finally: + con.close() + + def test_description(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.description,None, + 'cursor.description should be none after executing a ' + 'statement that can return no rows (such as DDL)' + ) + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(len(cur.description),1, + 'cursor.description describes too many columns' + ) + self.assertEqual(len(cur.description[0]),7, + 'cursor.description[x] tuples must have 7 elements' + ) + 
self.assertEqual(cur.description[0][0].lower(),'name', + 'cursor.description[x][0] must return column name' + ) + self.assertEqual(cur.description[0][1],self.driver.STRING, + 'cursor.description[x][1] must return column type. Got %r' + % cur.description[0][1] + ) + + # Make sure self.description gets reset + self.executeDDL2(cur) + self.assertEqual(cur.description,None, + 'cursor.description not being set to None when executing ' + 'no-result statements (eg. DDL)' + ) + finally: + con.close() + + def test_rowcount(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.rowcount,-1, + 'cursor.rowcount should be -1 after executing no-result ' + 'statements' + ) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.failUnless(cur.rowcount in (-1,1), + 'cursor.rowcount should == number or rows inserted, or ' + 'set to -1 after executing an insert statement' + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.failUnless(cur.rowcount in (-1,1), + 'cursor.rowcount should == number of rows returned, or ' + 'set to -1 after executing a select statement' + ) + self.executeDDL2(cur) + self.assertEqual(cur.rowcount,-1, + 'cursor.rowcount not being reset to -1 after executing ' + 'no-result statements' + ) + finally: + con.close() + + lower_func = 'lower' + def test_callproc(self): + con = self._connect() + try: + cur = con.cursor() + if self.lower_func and hasattr(cur,'callproc'): + r = cur.callproc(self.lower_func,('FOO',)) + self.assertEqual(len(r),1) + self.assertEqual(r[0],'FOO') + r = cur.fetchall() + self.assertEqual(len(r),1,'callproc produced no result set') + self.assertEqual(len(r[0]),1, + 'callproc produced invalid result set' + ) + self.assertEqual(r[0][0],'foo', + 'callproc produced invalid results' + ) + finally: + con.close() + + def test_close(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + # cursor.execute should raise an Error if called after connection + # closed + self.assertRaises(self.driver.Error,self.executeDDL1,cur) + + # connection.commit should raise an Error if called after connection' + # closed.' + self.assertRaises(self.driver.Error,con.commit) + + # connection.close should raise an Error if called more than once + # Issue discussed on DB-SIG: consensus seem that close() should not + # raised if called on closed objects. Issue reported back to Stuart. 
+ # self.assertRaises(self.driver.Error,con.close) + + def test_execute(self): + con = self._connect() + try: + cur = con.cursor() + self._paraminsert(cur) + finally: + con.close() + + def _paraminsert(self,cur): + self.executeDDL1(cur) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.failUnless(cur.rowcount in (-1,1)) + + if self.driver.paramstyle == 'qmark': + cur.execute( + 'insert into %sbooze values (?)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'numeric': + cur.execute( + 'insert into %sbooze values (:1)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'named': + cur.execute( + 'insert into %sbooze values (:beer)' % self.table_prefix, + {'beer':"Cooper's"} + ) + elif self.driver.paramstyle == 'format': + cur.execute( + 'insert into %sbooze values (%%s)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'pyformat': + cur.execute( + 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, + {'beer':"Cooper's"} + ) + else: + self.fail('Invalid paramstyle') + self.failUnless(cur.rowcount in (-1,1)) + + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Cooper's", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + self.assertEqual(beers[1],"Victoria Bitter", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + + def test_executemany(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + largs = [ ("Cooper's",) , ("Boag's",) ] + margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] + if self.driver.paramstyle == 'qmark': + cur.executemany( + 'insert into %sbooze values (?)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'numeric': + cur.executemany( + 'insert into %sbooze values (:1)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'named': + cur.executemany( + 'insert into %sbooze values (:beer)' % self.table_prefix, + margs + ) + elif self.driver.paramstyle == 'format': + cur.executemany( + 'insert into %sbooze values (%%s)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'pyformat': + cur.executemany( + 'insert into %sbooze values (%%(beer)s)' % ( + self.table_prefix + ), + margs + ) + else: + self.fail('Unknown paramstyle') + self.failUnless(cur.rowcount in (-1,2), + 'insert using cursor.executemany set cursor.rowcount to ' + 'incorrect value %r' % cur.rowcount + ) + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2, + 'cursor.fetchall retrieved incorrect number of rows' + ) + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') + self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + finally: + con.close() + + def test_fetchone(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchone should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error,cur.fetchone) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannot return rows + self.executeDDL1(cur) + self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + 
self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if a query retrieves ' + 'no rows' + ) + self.failUnless(cur.rowcount in (-1,0)) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannot return rows + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchone() + self.assertEqual(len(r),1, + 'cursor.fetchone should have retrieved a single row' + ) + self.assertEqual(r[0],'Victoria Bitter', + 'cursor.fetchone retrieved incorrect data' + ) + self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if no more rows available' + ) + self.failUnless(cur.rowcount in (-1,1)) + finally: + con.close() + + samples = [ + 'Carlton Cold', + 'Carlton Draft', + 'Mountain Goat', + 'Redback', + 'Victoria Bitter', + 'XXXX' + ] + + def _populate(self): + ''' Return a list of sql commands to setup the DB for the fetch + tests. + ''' + populate = [ + "insert into %sbooze values ('%s')" % (self.table_prefix,s) + for s in self.samples + ] + return populate + + def test_fetchmany(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchmany should raise an Error if called without + #issuing a query + self.assertRaises(self.driver.Error,cur.fetchmany,4) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchmany() + self.assertEqual(len(r),1, + 'cursor.fetchmany retrieved incorrect number of rows, ' + 'default of arraysize is one.' + ) + cur.arraysize=10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual(len(r),3, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual(len(r),2, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual(len(r),0, + 'cursor.fetchmany should return an empty sequence after ' + 'results are exhausted' + ) + self.failUnless(cur.rowcount in (-1,6)) + + # Same as above, using cursor.arraysize + cur.arraysize=4 + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual(len(r),4, + 'cursor.arraysize not being honoured by fetchmany' + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r),2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r),0) + self.failUnless(cur.rowcount in (-1,6)) + + cur.arraysize=6 + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + self.failUnless(cur.rowcount in (-1,6)) + self.assertEqual(len(rows),6) + self.assertEqual(len(rows),6) + rows = [r[0] for r in rows] + rows.sort() + + # Make sure we get the right data back out + for i in range(0,6): + self.assertEqual(rows[i],self.samples[i], + 'incorrect data retrieved by cursor.fetchmany' + ) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual(len(rows),0, + 'cursor.fetchmany should return an empty sequence if ' + 'called after the whole result set has been fetched' + ) + self.failUnless(cur.rowcount in (-1,6)) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual(len(r),0, + 'cursor.fetchmany should return an empty 
sequence if ' + 'query retrieved no rows' + ) + self.failUnless(cur.rowcount in (-1,0)) + + finally: + con.close() + + def test_fetchall(self): + con = self._connect() + try: + cur = con.cursor() + # cursor.fetchall should raise an Error if called + # without executing a query that may return rows (such + # as a select) + self.assertRaises(self.driver.Error, cur.fetchall) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + # cursor.fetchall should raise an Error if called + # after executing a a statement that cannot return rows + self.assertRaises(self.driver.Error,cur.fetchall) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchall() + self.failUnless(cur.rowcount in (-1,len(self.samples))) + self.assertEqual(len(rows),len(self.samples), + 'cursor.fetchall did not retrieve all rows' + ) + rows = [r[0] for r in rows] + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'cursor.fetchall retrieved incorrect rows' + ) + rows = cur.fetchall() + self.assertEqual( + len(rows),0, + 'cursor.fetchall should return an empty list if called ' + 'after the whole result set has been fetched' + ) + self.failUnless(cur.rowcount in (-1,len(self.samples))) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + rows = cur.fetchall() + self.failUnless(cur.rowcount in (-1,0)) + self.assertEqual(len(rows),0, + 'cursor.fetchall should return an empty list if ' + 'a select query returns no rows' + ) + + finally: + con.close() + + def test_mixedfetch(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows1 = cur.fetchone() + rows23 = cur.fetchmany(2) + rows4 = cur.fetchone() + rows56 = cur.fetchall() + self.failUnless(cur.rowcount in (-1,6)) + self.assertEqual(len(rows23),2, + 'fetchmany returned incorrect number of rows' + ) + self.assertEqual(len(rows56),2, + 'fetchall returned incorrect number of rows' + ) + + rows = [rows1[0]] + rows.extend([rows23[0][0],rows23[1][0]]) + rows.append(rows4[0]) + rows.extend([rows56[0][0],rows56[1][0]]) + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'incorrect data retrieved or inserted' + ) + finally: + con.close() + + def help_nextset_setUp(self,cur): + ''' Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + ''' + raise NotImplementedError('Helper not implemented') + #sql=""" + # create procedure deleteme as + # begin + # select count(*) from booze + # select name from booze + # end + #""" + #cur.execute(sql) + + def help_nextset_tearDown(self,cur): + 'If cleaning up is needed after nextSetTest' + raise NotImplementedError('Helper not implemented') + #cur.execute("drop procedure deleteme") + + def test_nextset(self): + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur,'nextset'): + return + + try: + self.executeDDL1(cur) + sql=self._populate() + for sql in self._populate(): + cur.execute(sql) + + self.help_nextset_setUp(cur) + + cur.callproc('deleteme') + numberofrows=cur.fetchone() + assert numberofrows[0]== len(self.samples) + assert cur.nextset() + names=cur.fetchall() + assert len(names) == len(self.samples) + s=cur.nextset() + assert s is None, 'No more return sets, should return None' + finally: + self.help_nextset_tearDown(cur) + + 
finally: + con.close() + + def test_arraysize(self): + # Not much here - rest of the tests for this are in test_fetchmany + con = self._connect() + try: + cur = con.cursor() + self.failUnless(hasattr(cur,'arraysize'), + 'cursor.arraysize must be defined' + ) + finally: + con.close() + + def test_setinputsizes(self): + con = self._connect() + try: + cur = con.cursor() + cur.setinputsizes( (25,) ) + self._paraminsert(cur) # Make sure cursor still works + finally: + con.close() + + def test_setoutputsize_basic(self): + # Basic test is to make sure setoutputsize doesn't blow up + con = self._connect() + try: + cur = con.cursor() + cur.setoutputsize(1000) + cur.setoutputsize(2000,0) + self._paraminsert(cur) # Make sure the cursor still works + finally: + con.close() + + def test_setoutputsize(self): + # Real test for setoutputsize is driver dependent + raise NotImplementedError('Driver needed to override this test') + + def test_None(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchall() + self.assertEqual(len(r),1) + self.assertEqual(len(r[0]),1) + self.assertEqual(r[0][0],None,'NULL value not returned as None') + finally: + con.close() + + def test_Date(self): + d1 = self.driver.Date(2002,12,25) + d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(d1),str(d2)) + + def test_Time(self): + t1 = self.driver.Time(13,45,30) + t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Timestamp(self): + t1 = self.driver.Timestamp(2002,12,25,13,45,30) + t2 = self.driver.TimestampFromTicks( + time.mktime((2002,12,25,13,45,30,0,0,0)) + ) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Binary(self): + b = self.driver.Binary(b'Something') + b = self.driver.Binary(b'') + + def test_STRING(self): + self.failUnless(hasattr(self.driver,'STRING'), + 'module.STRING must be defined' + ) + + def test_BINARY(self): + self.failUnless(hasattr(self.driver,'BINARY'), + 'module.BINARY must be defined.' + ) + + def test_NUMBER(self): + self.failUnless(hasattr(self.driver,'NUMBER'), + 'module.NUMBER must be defined.' + ) + + def test_DATETIME(self): + self.failUnless(hasattr(self.driver,'DATETIME'), + 'module.DATETIME must be defined.' + ) + + def test_ROWID(self): + self.failUnless(hasattr(self.driver,'ROWID'), + 'module.ROWID must be defined.' + ) +# fmt: on diff --git a/tests/dbapi20_tpc.py b/tests/dbapi20_tpc.py new file mode 100644 index 0000000..7254294 --- /dev/null +++ b/tests/dbapi20_tpc.py @@ -0,0 +1,151 @@ +# flake8: noqa +# fmt: off + +""" Python DB API 2.0 driver Two Phase Commit compliance test suite. 
+ +""" + +import unittest +from typing import Any + + +class TwoPhaseCommitTests(unittest.TestCase): + + driver: Any = None + + def connect(self): + """Make a database connection.""" + raise NotImplementedError + + _last_id = 0 + _global_id_prefix = "dbapi20_tpc:" + + def make_xid(self, con): + id = TwoPhaseCommitTests._last_id + TwoPhaseCommitTests._last_id += 1 + return con.xid(42, f"{self._global_id_prefix}{id}", "qualifier") + + def test_xid(self): + con = self.connect() + try: + try: + xid = con.xid(42, "global", "bqual") + except self.driver.NotSupportedError: + self.fail("Driver does not support transaction IDs.") + + self.assertEquals(xid[0], 42) + self.assertEquals(xid[1], "global") + self.assertEquals(xid[2], "bqual") + + # Try some extremes for the transaction ID: + xid = con.xid(0, "", "") + self.assertEquals(tuple(xid), (0, "", "")) + xid = con.xid(0x7fffffff, "a" * 64, "b" * 64) + self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64)) + finally: + con.close() + + def test_tpc_begin(self): + con = self.connect() + try: + xid = self.make_xid(con) + try: + con.tpc_begin(xid) + except self.driver.NotSupportedError: + self.fail("Driver does not support tpc_begin()") + finally: + con.close() + + def test_tpc_commit_without_prepare(self): + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + cursor = con.cursor() + cursor.execute("SELECT 1") + con.tpc_commit() + finally: + con.close() + + def test_tpc_rollback_without_prepare(self): + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + cursor = con.cursor() + cursor.execute("SELECT 1") + con.tpc_rollback() + finally: + con.close() + + def test_tpc_commit_with_prepare(self): + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + cursor = con.cursor() + cursor.execute("SELECT 1") + con.tpc_prepare() + con.tpc_commit() + finally: + con.close() + + def test_tpc_rollback_with_prepare(self): + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + cursor = con.cursor() + cursor.execute("SELECT 1") + con.tpc_prepare() + con.tpc_rollback() + finally: + con.close() + + def test_tpc_begin_in_transaction_fails(self): + con = self.connect() + try: + xid = self.make_xid(con) + + cursor = con.cursor() + cursor.execute("SELECT 1") + self.assertRaises(self.driver.ProgrammingError, + con.tpc_begin, xid) + finally: + con.close() + + def test_tpc_begin_in_tpc_transaction_fails(self): + con = self.connect() + try: + xid = self.make_xid(con) + + cursor = con.cursor() + cursor.execute("SELECT 1") + self.assertRaises(self.driver.ProgrammingError, + con.tpc_begin, xid) + finally: + con.close() + + def test_commit_in_tpc_fails(self): + # calling commit() within a TPC transaction fails with + # ProgrammingError. + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + + self.assertRaises(self.driver.ProgrammingError, con.commit) + finally: + con.close() + + def test_rollback_in_tpc_fails(self): + # calling rollback() within a TPC transaction fails with + # ProgrammingError. 
+ con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + + self.assertRaises(self.driver.ProgrammingError, con.rollback) + finally: + con.close() diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py new file mode 100644 index 0000000..88ab504 --- /dev/null +++ b/tests/fix_crdb.py @@ -0,0 +1,131 @@ +from typing import Optional + +import pytest + +from .utils import VersionCheck +from psycopg.crdb import CrdbConnection + + +def pytest_configure(config): + # register libpq marker + config.addinivalue_line( + "markers", + "crdb(version_expr, reason=detail): run/skip the test with matching CockroachDB" + " (e.g. '>= 21.2.10', '< 22.1', 'skip < 22')", + ) + config.addinivalue_line( + "markers", + "crdb_skip(reason): skip the test for known CockroachDB reasons", + ) + + +def check_crdb_version(got, mark): + if mark.name == "crdb": + assert len(mark.args) <= 1 + assert not (set(mark.kwargs) - {"reason"}) + spec = mark.args[0] if mark.args else "only" + reason = mark.kwargs.get("reason") + elif mark.name == "crdb_skip": + assert len(mark.args) == 1 + assert not mark.kwargs + reason = mark.args[0] + assert reason in _crdb_reasons, reason + spec = _crdb_reason_version.get(reason, "skip") + else: + assert False, mark.name + + pred = VersionCheck.parse(spec) + pred.whose = "CockroachDB" + + msg = pred.get_skip_message(got) + if not msg: + return None + + reason = crdb_skip_message(reason) + if reason: + msg = f"{msg}: {reason}" + + return msg + + +# Utility functions which can be imported in the test suite + +is_crdb = CrdbConnection.is_crdb + + +def crdb_skip_message(reason: Optional[str]) -> str: + msg = "" + if reason: + msg = reason + if _crdb_reasons.get(reason): + url = ( + "https://github.com/cockroachdb/cockroach/" + f"issues/{_crdb_reasons[reason]}" + ) + msg = f"{msg} ({url})" + + return msg + + +def skip_crdb(*args, reason=None): + return pytest.param(*args, marks=pytest.mark.crdb("skip", reason=reason)) + + +def crdb_encoding(*args): + """Mark tests that fail on CockroachDB because of missing encodings""" + return skip_crdb(*args, reason="encoding") + + +def crdb_time_precision(*args): + """Mark tests that fail on CockroachDB because time doesn't support precision""" + return skip_crdb(*args, reason="time precision") + + +def crdb_scs_off(*args): + return skip_crdb(*args, reason="standard_conforming_strings=off") + + +# mapping from reason description to ticket number +_crdb_reasons = { + "2-phase commit": 22329, + "backend pid": 35897, + "batch statements": 44803, + "begin_read_only": 87012, + "binary decimal": 82492, + "cancel": 41335, + "cast adds tz": 51692, + "cidr": 18846, + "composite": 27792, + "copy array": 82792, + "copy canceled": 81559, + "copy": 41608, + "cursor invalid name": 84261, + "cursor with hold": 77101, + "deferrable": 48307, + "do": 17511, + "encoding": 35882, + "geometric types": 21286, + "hstore": 41284, + "infinity date": 41564, + "interval style": 35807, + "json array": 23468, + "large objects": 243, + "negative interval": 81577, + "nested array": 32552, + "no col query": None, + "notify": 41522, + "password_encryption": 42519, + "pg_terminate_backend": 35897, + "range": 41282, + "scroll cursor": 77102, + "server-side cursor": 41412, + "severity_nonlocalized": 81794, + "stored procedure": 1751, +} + +_crdb_reason_version = { + "backend pid": "skip < 22", + "cancel": "skip < 22", + "server-side cursor": "skip < 22.1.3", + "severity_nonlocalized": "skip < 22.1.3", +} diff --git a/tests/fix_db.py b/tests/fix_db.py new file mode 100644 index 
0000000..3a37aa1 --- /dev/null +++ b/tests/fix_db.py @@ -0,0 +1,358 @@ +import io +import os +import sys +import pytest +import logging +from contextlib import contextmanager +from typing import Optional + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg._compat import cache +from psycopg.pq._debug import PGconnDebug + +from .utils import check_postgres_version + +# Set by warm_up_database() the first time the dsn fixture is used +pg_version: int +crdb_version: Optional[int] + + +def pytest_addoption(parser): + parser.addoption( + "--test-dsn", + metavar="DSN", + default=os.environ.get("PSYCOPG_TEST_DSN"), + help=( + "Connection string to run database tests requiring a connection" + " [you can also use the PSYCOPG_TEST_DSN env var]." + ), + ) + parser.addoption( + "--pq-trace", + metavar="{TRACEFILE,STDERR}", + default=None, + help="Generate a libpq trace to TRACEFILE or STDERR.", + ) + parser.addoption( + "--pq-debug", + action="store_true", + default=False, + help="Log PGconn access. (Requires PSYCOPG_IMPL=python.)", + ) + + +def pytest_report_header(config): + dsn = config.getoption("--test-dsn") + if dsn is None: + return [] + + try: + with psycopg.connect(dsn, connect_timeout=10) as conn: + server_version = conn.execute("select version()").fetchall()[0][0] + except Exception as ex: + server_version = f"unknown ({ex})" + + return [ + f"Server version: {server_version}", + ] + + +def pytest_collection_modifyitems(items): + for item in items: + for name in item.fixturenames: + if name in ("pipeline", "apipeline"): + item.add_marker(pytest.mark.pipeline) + break + + +def pytest_runtest_setup(item): + for m in item.iter_markers(name="pipeline"): + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + + +def pytest_configure(config): + # register pg marker + markers = [ + "pg(version_expr): run the test only with matching server version" + " (e.g. '>= 10', '< 9.6')", + "pipeline: the test runs with connection in pipeline mode", + ] + for marker in markers: + config.addinivalue_line("markers", marker) + + +@pytest.fixture(scope="session") +def session_dsn(request): + """ + Return the dsn used to connect to the `--test-dsn` database (session-wide). + """ + dsn = request.config.getoption("--test-dsn") + if dsn is None: + pytest.skip("skipping test as no --test-dsn") + + warm_up_database(dsn) + return dsn + + +@pytest.fixture +def dsn(session_dsn, request): + """Return the dsn used to connect to the `--test-dsn` database.""" + check_connection_version(request.node) + return session_dsn + + +@pytest.fixture(scope="session") +def tracefile(request): + """Open and yield a file for libpq client/server communication traces if + --pq-tracefile option is set. + """ + tracefile = request.config.getoption("--pq-trace") + if not tracefile: + yield None + return + + if tracefile.lower() == "stderr": + try: + sys.stderr.fileno() + except io.UnsupportedOperation: + raise pytest.UsageError( + "cannot use stderr for --pq-trace (in-memory file?)" + ) from None + + yield sys.stderr + return + + with open(tracefile, "w") as f: + yield f + + +@contextmanager +def maybe_trace(pgconn, tracefile, function): + """Handle libpq client/server communication traces for a single test + function. 
+ """ + if tracefile is None: + yield None + return + + if tracefile != sys.stderr: + title = f" {function.__module__}::{function.__qualname__} ".center(80, "=") + tracefile.write(title + "\n") + tracefile.flush() + + pgconn.trace(tracefile.fileno()) + try: + pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + except psycopg.NotSupportedError: + pass + try: + yield None + finally: + pgconn.untrace() + + +@pytest.fixture(autouse=True) +def pgconn_debug(request): + if not request.config.getoption("--pq-debug"): + return + if pq.__impl__ != "python": + raise pytest.UsageError("set PSYCOPG_IMPL=python to use --pq-debug") + logging.basicConfig(level=logging.INFO, format="%(message)s") + logger = logging.getLogger("psycopg.debug") + logger.setLevel(logging.INFO) + pq.PGconn = PGconnDebug + + +@pytest.fixture +def pgconn(dsn, request, tracefile): + """Return a PGconn connection open to `--test-dsn`.""" + check_connection_version(request.node) + + conn = pq.PGconn.connect(dsn.encode()) + if conn.status != pq.ConnStatus.OK: + pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}") + + with maybe_trace(conn, tracefile, request.function): + yield conn + + conn.finish() + + +@pytest.fixture +def conn(conn_cls, dsn, request, tracefile): + """Return a `Connection` connected to the ``--test-dsn`` database.""" + check_connection_version(request.node) + + conn = conn_cls.connect(dsn) + with maybe_trace(conn.pgconn, tracefile, request.function): + yield conn + conn.close() + + +@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"]) +def pipeline(request, conn): + if request.param: + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + with conn.pipeline() as p: + yield p + return + else: + yield None + + +@pytest.fixture +async def aconn(dsn, aconn_cls, request, tracefile): + """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" + check_connection_version(request.node) + + conn = await aconn_cls.connect(dsn) + with maybe_trace(conn.pgconn, tracefile, request.function): + yield conn + await conn.close() + + +@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"]) +async def apipeline(request, aconn): + if request.param: + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + async with aconn.pipeline() as p: + yield p + return + else: + yield None + + +@pytest.fixture(scope="session") +def conn_cls(session_dsn): + cls = psycopg.Connection + if crdb_version: + from psycopg.crdb import CrdbConnection + + cls = CrdbConnection + + return cls + + +@pytest.fixture(scope="session") +def aconn_cls(session_dsn): + cls = psycopg.AsyncConnection + if crdb_version: + from psycopg.crdb import AsyncCrdbConnection + + cls = AsyncCrdbConnection + + return cls + + +@pytest.fixture(scope="session") +def svcconn(conn_cls, session_dsn): + """ + Return a session `Connection` connected to the ``--test-dsn`` database. 
+ """ + conn = conn_cls.connect(session_dsn, autocommit=True) + yield conn + conn.close() + + +@pytest.fixture +def commands(conn, monkeypatch): + """The list of commands issued internally by the test connection.""" + yield patch_exec(conn, monkeypatch) + + +@pytest.fixture +def acommands(aconn, monkeypatch): + """The list of commands issued internally by the test async connection.""" + yield patch_exec(aconn, monkeypatch) + + +def patch_exec(conn, monkeypatch): + """Helper to implement the commands fixture both sync and async.""" + _orig_exec_command = conn._exec_command + L = ListPopAll() + + def _exec_command(command, *args, **kwargs): + cmdcopy = command + if isinstance(cmdcopy, bytes): + cmdcopy = cmdcopy.decode(conn.info.encoding) + elif isinstance(cmdcopy, sql.Composable): + cmdcopy = cmdcopy.as_string(conn) + + L.append(cmdcopy) + return _orig_exec_command(command, *args, **kwargs) + + monkeypatch.setattr(conn, "_exec_command", _exec_command) + return L + + +class ListPopAll(list): # type: ignore[type-arg] + """A list, with a popall() method.""" + + def popall(self): + out = self[:] + del self[:] + return out + + +def check_connection_version(node): + try: + pg_version + except NameError: + # First connection creation failed. Let the tests fail. + pytest.fail("server version not available") + + for mark in node.iter_markers(): + if mark.name == "pg": + assert len(mark.args) == 1 + msg = check_postgres_version(pg_version, mark.args[0]) + if msg: + pytest.skip(msg) + + elif mark.name in ("crdb", "crdb_skip"): + from .fix_crdb import check_crdb_version + + msg = check_crdb_version(crdb_version, mark) + if msg: + pytest.skip(msg) + + +@pytest.fixture +def hstore(svcconn): + try: + with svcconn.transaction(): + svcconn.execute("create extension if not exists hstore") + except psycopg.Error as e: + pytest.skip(str(e)) + + +@cache +def warm_up_database(dsn: str) -> None: + """Connect to the database before returning a connection. + + In the CI sometimes, the first test fails with a timeout, probably because + the server hasn't started yet. Absorb the delay before the test. + + In case of error, abort the test run entirely, to avoid failing downstream + hundreds of times. + """ + global pg_version, crdb_version + + try: + with psycopg.connect(dsn, connect_timeout=10) as conn: + conn.execute("select 1") + + pg_version = conn.info.server_version + + crdb_version = None + param = conn.info.parameter_status("crdb_version") + if param: + from psycopg.crdb import CrdbConnectionInfo + + crdb_version = CrdbConnectionInfo.parse_crdb_version(param) + except Exception as exc: + pytest.exit(f"failed to connect to the test database: {exc}") diff --git a/tests/fix_faker.py b/tests/fix_faker.py new file mode 100644 index 0000000..5289d8f --- /dev/null +++ b/tests/fix_faker.py @@ -0,0 +1,868 @@ +import datetime as dt +import importlib +import ipaddress +from math import isnan +from uuid import UUID +from random import choice, random, randrange +from typing import Any, List, Set, Tuple, Union +from decimal import Decimal +from contextlib import contextmanager, asynccontextmanager + +import pytest + +import psycopg +from psycopg import sql +from psycopg.adapt import PyFormat +from psycopg._compat import Deque +from psycopg.types.range import Range +from psycopg.types.json import Json, Jsonb +from psycopg.types.numeric import Int4, Int8 +from psycopg.types.multirange import Multirange + + +@pytest.fixture +def faker(conn): + return Faker(conn) + + +class Faker: + """ + An object to generate random records. 
+ """ + + json_max_level = 3 + json_max_length = 10 + str_max_length = 100 + list_max_length = 20 + tuple_max_length = 15 + + def __init__(self, connection): + self.conn = connection + self.format = PyFormat.BINARY + self.records = [] + + self._schema = None + self._types = None + self._types_names = None + self._makers = {} + self.table_name = sql.Identifier("fake_table") + + @property + def schema(self): + if not self._schema: + self.schema = self.choose_schema() + return self._schema + + @schema.setter + def schema(self, schema): + self._schema = schema + self._types_names = None + + @property + def fields_names(self): + return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))] + + @property + def types(self): + if not self._types: + + def key(cls: type) -> str: + return cls.__name__ + + self._types = sorted(self.get_supported_types(), key=key) + return self._types + + @property + def types_names_sql(self): + if self._types_names: + return self._types_names + + record = self.make_record(nulls=0) + tx = psycopg.adapt.Transformer(self.conn) + types = [ + self._get_type_name(tx, schema, value) + for schema, value in zip(self.schema, record) + ] + self._types_names = types + return types + + @property + def types_names(self): + types = [t.as_string(self.conn).replace('"', "") for t in self.types_names_sql] + return types + + def _get_type_name(self, tx, schema, value): + # Special case it as it is passed as unknown so is returned as text + if schema == (list, str): + return sql.SQL("text[]") + + registry = self.conn.adapters.types + dumper = tx.get_dumper(value, self.format) + dumper.dump(value) # load the oid if it's dynamic (e.g. array) + info = registry.get(dumper.oid) or registry.get("text") + if dumper.oid == info.array_oid: + return sql.SQL("{}[]").format(sql.Identifier(info.name)) + else: + return sql.Identifier(info.name) + + @property + def drop_stmt(self): + return sql.SQL("drop table if exists {}").format(self.table_name) + + @property + def create_stmt(self): + field_values = [] + for name, type in zip(self.fields_names, self.types_names_sql): + field_values.append(sql.SQL("{} {}").format(name, type)) + + fields = sql.SQL(", ").join(field_values) + return sql.SQL("create table {table} (id serial primary key, {fields})").format( + table=self.table_name, fields=fields + ) + + @property + def insert_stmt(self): + phs = [sql.Placeholder(format=self.format) for i in range(len(self.schema))] + return sql.SQL("insert into {} ({}) values ({})").format( + self.table_name, + sql.SQL(", ").join(self.fields_names), + sql.SQL(", ").join(phs), + ) + + @property + def select_stmt(self): + fields = sql.SQL(", ").join(self.fields_names) + return sql.SQL("select {} from {} order by id").format(fields, self.table_name) + + @contextmanager + def find_insert_problem(self, conn): + """Context manager to help finding a problematic value.""" + try: + with conn.transaction(): + yield + except psycopg.DatabaseError: + cur = conn.cursor() + # Repeat insert one field at time, until finding the wrong one + cur.execute(self.drop_stmt) + cur.execute(self.create_stmt) + for i, rec in enumerate(self.records): + for j, val in enumerate(rec): + try: + cur.execute(self._insert_field_stmt(j), (val,)) + except psycopg.DatabaseError as e: + r = repr(val) + if len(r) > 200: + r = f"{r[:200]}... 
({len(r)} chars)" + raise Exception( + f"value {r!r} at record {i} column0 {j} failed insert: {e}" + ) from None + + # just in case, but hopefully we should have triggered the problem + raise + + @asynccontextmanager + async def find_insert_problem_async(self, aconn): + try: + async with aconn.transaction(): + yield + except psycopg.DatabaseError: + acur = aconn.cursor() + # Repeat insert one field at time, until finding the wrong one + await acur.execute(self.drop_stmt) + await acur.execute(self.create_stmt) + for i, rec in enumerate(self.records): + for j, val in enumerate(rec): + try: + await acur.execute(self._insert_field_stmt(j), (val,)) + except psycopg.DatabaseError as e: + r = repr(val) + if len(r) > 200: + r = f"{r[:200]}... ({len(r)} chars)" + raise Exception( + f"value {r!r} at record {i} column0 {j} failed insert: {e}" + ) from None + + # just in case, but hopefully we should have triggered the problem + raise + + def _insert_field_stmt(self, i): + ph = sql.Placeholder(format=self.format) + return sql.SQL("insert into {} ({}) values ({})").format( + self.table_name, self.fields_names[i], ph + ) + + def choose_schema(self, ncols=20): + schema: List[Union[Tuple[type, ...], type]] = [] + while len(schema) < ncols: + s = self.make_schema(choice(self.types)) + if s is not None: + schema.append(s) + self.schema = schema + return schema + + def make_records(self, nrecords): + self.records = [self.make_record(nulls=0.05) for i in range(nrecords)] + + def make_record(self, nulls=0): + if not nulls: + return tuple(self.example(spec) for spec in self.schema) + else: + return tuple( + self.make(spec) if random() > nulls else None for spec in self.schema + ) + + def assert_record(self, got, want): + for spec, g, w in zip(self.schema, got, want): + if g is None and w is None: + continue + m = self.get_matcher(spec) + m(spec, g, w) + + def get_supported_types(self) -> Set[type]: + dumpers = self.conn.adapters._dumpers[self.format] + rv = set() + for cls in dumpers.keys(): + if isinstance(cls, str): + cls = deep_import(cls) + if issubclass(cls, Multirange) and self.conn.info.server_version < 140000: + continue + + rv.add(cls) + + # check all the types are handled + for cls in rv: + self.get_maker(cls) + + return rv + + def make_schema(self, cls: type) -> Union[Tuple[type, ...], type, None]: + """Create a schema spec from a Python type. + + A schema specifies what Postgres type to generate when a Python type + maps to more than one (e.g. tuple -> composite, list -> array[], + datetime -> timestamp[tz]). + + A schema for a type is represented by a tuple (type, ...) which the + matching make_*() method can interpret, or just type if the type + doesn't require further specification. + + A `None` means that the type is not supported. 
+ """ + meth = self._get_method("schema", cls) + return meth(cls) if meth else cls + + def get_maker(self, spec): + cls = spec if isinstance(spec, type) else spec[0] + + try: + return self._makers[cls] + except KeyError: + pass + + meth = self._get_method("make", cls) + if meth: + self._makers[cls] = meth + return meth + else: + raise NotImplementedError(f"cannot make fake objects of class {cls}") + + def get_matcher(self, spec): + cls = spec if isinstance(spec, type) else spec[0] + meth = self._get_method("match", cls) + return meth if meth else self.match_any + + def _get_method(self, prefix, cls): + name = cls.__name__ + if cls.__module__ != "builtins": + name = f"{cls.__module__}.{name}" + + parts = name.split(".") + for i in range(len(parts)): + mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}" + meth = getattr(self, mname, None) + if meth: + return meth + + return None + + def make(self, spec): + # spec can be a type or a tuple (type, options) + return self.get_maker(spec)(spec) + + def example(self, spec): + # A good representative of the object - no degenerate case + cls = spec if isinstance(spec, type) else spec[0] + meth = self._get_method("example", cls) + if meth: + return meth(spec) + else: + return self.make(spec) + + def match_any(self, spec, got, want): + assert got == want + + # methods to generate samples of specific types + + def make_Binary(self, spec): + return self.make_bytes(spec) + + def match_Binary(self, spec, got, want): + return want.obj == got + + def make_bool(self, spec): + return choice((True, False)) + + def make_bytearray(self, spec): + return self.make_bytes(spec) + + def make_bytes(self, spec): + length = randrange(self.str_max_length) + return spec(bytes([randrange(256) for i in range(length)])) + + def make_date(self, spec): + day = randrange(dt.date.max.toordinal()) + return dt.date.fromordinal(day + 1) + + def schema_datetime(self, cls): + return self.schema_time(cls) + + def make_datetime(self, spec): + # Add a day because with timezone we might go BC + dtmin = dt.datetime.min + dt.timedelta(days=1) + delta = dt.datetime.max - dtmin + micros = randrange((delta.days + 1) * 24 * 60 * 60 * 1_000_000) + rv = dtmin + dt.timedelta(microseconds=micros) + if spec[1]: + rv = rv.replace(tzinfo=self._make_tz(spec)) + return rv + + def match_datetime(self, spec, got, want): + # Comparisons with different timezones is unreliable: certain pairs + # are reported different but their delta is 0 + # https://bugs.python.org/issue45347 + assert not (got - want) + + def make_Decimal(self, spec): + if random() >= 0.99: + return Decimal(choice(self._decimal_special_values())) + + sign = choice("+-") + num = choice(["0.zd", "d", "d.d"]) + while "z" in num: + ndigits = randrange(1, 20) + num = num.replace("z", "0" * ndigits, 1) + while "d" in num: + ndigits = randrange(1, 20) + num = num.replace( + "d", "".join([str(randrange(10)) for i in range(ndigits)]), 1 + ) + expsign = choice(["e+", "e-", ""]) + exp = randrange(20) if expsign else "" + rv = Decimal(f"{sign}{num}{expsign}{exp}") + return rv + + def match_Decimal(self, spec, got, want): + if got is not None and got.is_nan(): + assert want.is_nan() + else: + assert got == want + + def _decimal_special_values(self): + values = ["NaN", "sNaN"] + + if self.conn.info.vendor == "PostgreSQL": + if self.conn.info.server_version >= 140000: + values.extend(["Inf", "-Inf"]) + elif self.conn.info.vendor == "CockroachDB": + if self.conn.info.server_version >= 220100: + values.extend(["Inf", "-Inf"]) + else: + 
pytest.fail(f"unexpected vendor: {self.conn.info.vendor}") + + return values + + def schema_Enum(self, cls): + # TODO: can't fake those as we would need to create temporary types + return None + + def make_Enum(self, spec): + return None + + def make_float(self, spec, double=True): + if random() <= 0.99: + # These exponents should generate no inf + return float( + f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}" + if double + else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}" + ) + else: + return choice((0.0, -0.0, float("-inf"), float("inf"), float("nan"))) + + def match_float(self, spec, got, want, approx=False, rel=None): + if got is not None and isnan(got): + assert isnan(want) + else: + if approx or self._server_rounds(): + assert got == pytest.approx(want, rel=rel) + else: + assert got == want + + def _server_rounds(self): + """Return True if the connected server perform float rounding""" + if self.conn.info.vendor == "CockroachDB": + return True + else: + # Versions older than 12 make some rounding. e.g. in Postgres 10.4 + # select '-1.409006204063909e+112'::float8 + # -> -1.40900620406391e+112 + return self.conn.info.server_version < 120000 + + def make_Float4(self, spec): + return spec(self.make_float(spec, double=False)) + + def match_Float4(self, spec, got, want): + self.match_float(spec, got, want, approx=True, rel=1e-5) + + def make_Float8(self, spec): + return spec(self.make_float(spec)) + + match_Float8 = match_float + + def make_int(self, spec): + return randrange(-(1 << 90), 1 << 90) + + def make_Int2(self, spec): + return spec(randrange(-(1 << 15), 1 << 15)) + + def make_Int4(self, spec): + return spec(randrange(-(1 << 31), 1 << 31)) + + def make_Int8(self, spec): + return spec(randrange(-(1 << 63), 1 << 63)) + + def make_IntNumeric(self, spec): + return spec(randrange(-(1 << 100), 1 << 100)) + + def make_IPv4Address(self, spec): + return ipaddress.IPv4Address(bytes(randrange(256) for _ in range(4))) + + def make_IPv4Interface(self, spec): + prefix = randrange(32) + return ipaddress.IPv4Interface( + (bytes(randrange(256) for _ in range(4)), prefix) + ) + + def make_IPv4Network(self, spec): + return self.make_IPv4Interface(spec).network + + def make_IPv6Address(self, spec): + return ipaddress.IPv6Address(bytes(randrange(256) for _ in range(16))) + + def make_IPv6Interface(self, spec): + prefix = randrange(128) + return ipaddress.IPv6Interface( + (bytes(randrange(256) for _ in range(16)), prefix) + ) + + def make_IPv6Network(self, spec): + return self.make_IPv6Interface(spec).network + + def make_Json(self, spec): + return spec(self._make_json()) + + def match_Json(self, spec, got, want): + if want is not None: + want = want.obj + assert got == want + + def make_Jsonb(self, spec): + return spec(self._make_json()) + + def match_Jsonb(self, spec, got, want): + self.match_Json(spec, got, want) + + def make_JsonFloat(self, spec): + # A float limited to what json accepts + # this exponent should generate no inf + return float(f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}") + + def schema_list(self, cls): + while True: + scls = choice(self.types) + if scls is cls: + continue + if scls is float: + # TODO: float lists are currently adapted as decimal. + # There may be rounding errors or problems with inf. 
+ continue + + # CRDB doesn't support arrays of json + # https://github.com/cockroachdb/cockroach/issues/23468 + if self.conn.info.vendor == "CockroachDB" and scls in (Json, Jsonb): + continue + + schema = self.make_schema(scls) + if schema is not None: + break + + return (cls, schema) + + def make_list(self, spec): + # don't make empty lists because they regularly fail cast + length = randrange(1, self.list_max_length) + spec = spec[1] + while True: + rv = [self.make(spec) for i in range(length)] + + # TODO multirange lists fail binary dump if the last element is + # empty and there is no type annotation. See xfail in + # test_multirange::test_dump_builtin_array + if rv and isinstance(rv[-1], Multirange) and not rv[-1]: + continue + + return rv + + def example_list(self, spec): + return [self.example(spec[1])] + + def match_list(self, spec, got, want): + assert len(got) == len(want) + m = self.get_matcher(spec[1]) + for g, w in zip(got, want): + m(spec[1], g, w) + + def make_memoryview(self, spec): + return self.make_bytes(spec) + + def schema_Multirange(self, cls): + return self.schema_Range(cls) + + def make_Multirange(self, spec, length=None, **kwargs): + if length is None: + length = randrange(0, self.list_max_length) + + def overlap(r1, r2): + l1, u1 = r1.lower, r1.upper + l2, u2 = r2.lower, r2.upper + if l1 is None and l2 is None: + return True + elif l1 is None: + l1 = l2 + elif l2 is None: + l2 = l1 + + if u1 is None and u2 is None: + return True + elif u1 is None: + u1 = u2 + elif u2 is None: + u2 = u1 + + return l1 <= u2 and l2 <= u1 + + out: List[Range[Any]] = [] + for i in range(length): + r = self.make_Range((Range, spec[1]), **kwargs) + if r.isempty: + continue + for r2 in out: + if overlap(r, r2): + insert = False + break + else: + insert = True + if insert: + out.append(r) # alternatively, we could merge + + return spec[0](sorted(out)) + + def example_Multirange(self, spec): + return self.make_Multirange(spec, length=1, empty_chance=0, no_bound_chance=0) + + def make_Int4Multirange(self, spec): + return self.make_Multirange((spec, Int4)) + + def make_Int8Multirange(self, spec): + return self.make_Multirange((spec, Int8)) + + def make_NumericMultirange(self, spec): + return self.make_Multirange((spec, Decimal)) + + def make_DateMultirange(self, spec): + return self.make_Multirange((spec, dt.date)) + + def make_TimestampMultirange(self, spec): + return self.make_Multirange((spec, (dt.datetime, False))) + + def make_TimestamptzMultirange(self, spec): + return self.make_Multirange((spec, (dt.datetime, True))) + + def match_Multirange(self, spec, got, want): + assert len(got) == len(want) + for ig, iw in zip(got, want): + self.match_Range(spec, ig, iw) + + def match_Int4Multirange(self, spec, got, want): + return self.match_Multirange((spec, Int4), got, want) + + def match_Int8Multirange(self, spec, got, want): + return self.match_Multirange((spec, Int8), got, want) + + def match_NumericMultirange(self, spec, got, want): + return self.match_Multirange((spec, Decimal), got, want) + + def match_DateMultirange(self, spec, got, want): + return self.match_Multirange((spec, dt.date), got, want) + + def match_TimestampMultirange(self, spec, got, want): + return self.match_Multirange((spec, (dt.datetime, False)), got, want) + + def match_TimestamptzMultirange(self, spec, got, want): + return self.match_Multirange((spec, (dt.datetime, True)), got, want) + + def schema_NoneType(self, cls): + return None + + def make_NoneType(self, spec): + return None + + def make_Oid(self, spec): + 
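+        # OIDs are unsigned 32-bit integers, so any value in [0, 2**32) is a valid fake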
return spec(randrange(1 << 32)) + + def schema_Range(self, cls): + subtypes = [ + Decimal, + Int4, + Int8, + dt.date, + (dt.datetime, True), + (dt.datetime, False), + ] + + return (cls, choice(subtypes)) + + def make_Range(self, spec, empty_chance=0.02, no_bound_chance=0.05): + # TODO: drop format check after fixing binary dumping of empty ranges + # (an array starting with an empty range will get the wrong type currently) + if ( + random() < empty_chance + and spec[0] is Range + and self.format == PyFormat.TEXT + ): + return spec[0](empty=True) + + while True: + bounds: List[Union[Any, None]] = [] + while len(bounds) < 2: + if random() < no_bound_chance: + bounds.append(None) + continue + + val = self.make(spec[1]) + # NaN are allowed in a range, but comparison in Python get tricky. + if spec[1] is Decimal and val.is_nan(): + continue + + bounds.append(val) + + if bounds[0] is not None and bounds[1] is not None: + if bounds[0] == bounds[1]: + # It would come out empty + continue + + if bounds[0] > bounds[1]: + bounds.reverse() + + # avoid generating ranges with no type info if dumping in binary + # TODO: lift this limitation after test_copy_in_empty xfail is fixed + if spec[0] is Range and self.format == PyFormat.BINARY: + if bounds[0] is bounds[1] is None: + continue + + break + + r = spec[0](bounds[0], bounds[1], choice("[(") + choice("])")) + return r + + def example_Range(self, spec): + return self.make_Range(spec, empty_chance=0, no_bound_chance=0) + + def make_Int4Range(self, spec): + return self.make_Range((spec, Int4)) + + def make_Int8Range(self, spec): + return self.make_Range((spec, Int8)) + + def make_NumericRange(self, spec): + return self.make_Range((spec, Decimal)) + + def make_DateRange(self, spec): + return self.make_Range((spec, dt.date)) + + def make_TimestampRange(self, spec): + return self.make_Range((spec, (dt.datetime, False))) + + def make_TimestamptzRange(self, spec): + return self.make_Range((spec, (dt.datetime, True))) + + def match_Range(self, spec, got, want): + # normalise the bounds of unbounded ranges + if want.lower is None and want.lower_inc: + want = type(want)(want.lower, want.upper, "(" + want.bounds[1]) + if want.upper is None and want.upper_inc: + want = type(want)(want.lower, want.upper, want.bounds[0] + ")") + + # Normalise discrete ranges + unit: Union[dt.timedelta, int, None] + if spec[1] is dt.date: + unit = dt.timedelta(days=1) + elif type(spec[1]) is type and issubclass(spec[1], int): + unit = 1 + else: + unit = None + + if unit is not None: + if want.lower is not None and not want.lower_inc: + want = type(want)(want.lower + unit, want.upper, "[" + want.bounds[1]) + if want.upper_inc: + want = type(want)(want.lower, want.upper + unit, want.bounds[0] + ")") + + if spec[1] == (dt.datetime, True) and not want.isempty: + # work around https://bugs.python.org/issue45347 + def fix_dt(x): + return x.astimezone(dt.timezone.utc) if x is not None else None + + def fix_range(r): + return type(r)(fix_dt(r.lower), fix_dt(r.upper), r.bounds) + + want = fix_range(want) + got = fix_range(got) + + assert got == want + + def match_Int4Range(self, spec, got, want): + return self.match_Range((spec, Int4), got, want) + + def match_Int8Range(self, spec, got, want): + return self.match_Range((spec, Int8), got, want) + + def match_NumericRange(self, spec, got, want): + return self.match_Range((spec, Decimal), got, want) + + def match_DateRange(self, spec, got, want): + return self.match_Range((spec, dt.date), got, want) + + def match_TimestampRange(self, spec, got, 
want): + return self.match_Range((spec, (dt.datetime, False)), got, want) + + def match_TimestamptzRange(self, spec, got, want): + return self.match_Range((spec, (dt.datetime, True)), got, want) + + def make_str(self, spec, length=0): + if not length: + length = randrange(self.str_max_length) + + rv: List[int] = [] + while len(rv) < length: + c = randrange(1, 128) if random() < 0.5 else randrange(1, 0x110000) + if not (0xD800 <= c <= 0xDBFF or 0xDC00 <= c <= 0xDFFF): + rv.append(c) + + return "".join(map(chr, rv)) + + def schema_time(self, cls): + # Choose timezone yes/no + return (cls, choice([True, False])) + + def make_time(self, spec): + val = randrange(24 * 60 * 60 * 1_000_000) + val, ms = divmod(val, 1_000_000) + val, s = divmod(val, 60) + h, m = divmod(val, 60) + tz = self._make_tz(spec) if spec[1] else None + return dt.time(h, m, s, ms, tz) + + CRDB_TIMEDELTA_MAX = dt.timedelta(days=1281239) + + def make_timedelta(self, spec): + if self.conn.info.vendor == "CockroachDB": + rng = [-self.CRDB_TIMEDELTA_MAX, self.CRDB_TIMEDELTA_MAX] + else: + rng = [dt.timedelta.min, dt.timedelta.max] + + return choice(rng) * random() + + def schema_tuple(self, cls): + # TODO: this is a complicated matter as it would involve creating + # temporary composite types. + # length = randrange(1, self.tuple_max_length) + # return (cls, self.make_random_schema(ncols=length)) + return None + + def make_tuple(self, spec): + return tuple(self.make(s) for s in spec[1]) + + def match_tuple(self, spec, got, want): + assert len(got) == len(want) == len(spec[1]) + for g, w, s in zip(got, want, spec): + if g is None or w is None: + assert g is w + else: + m = self.get_matcher(s) + m(s, g, w) + + def make_UUID(self, spec): + return UUID(bytes=bytes([randrange(256) for i in range(16)])) + + def _make_json(self, container_chance=0.66): + rec_types = [list, dict] + scal_types = [type(None), int, JsonFloat, bool, str] + if random() < container_chance: + cls = choice(rec_types) + if cls is list: + return [ + self._make_json(container_chance=container_chance / 2.0) + for i in range(randrange(self.json_max_length)) + ] + elif cls is dict: + return { + self.make_str(str, 15): self._make_json( + container_chance=container_chance / 2.0 + ) + for i in range(randrange(self.json_max_length)) + } + else: + assert False, f"unknown rec type: {cls}" + + else: + cls = choice(scal_types) # type: ignore[assignment] + return self.make(cls) + + def _make_tz(self, spec): + minutes = randrange(-12 * 60, 12 * 60 + 1) + return dt.timezone(dt.timedelta(minutes=minutes)) + + +class JsonFloat: + pass + + +def deep_import(name): + parts = Deque(name.split(".")) + seen = [] + if not parts: + raise ValueError("name must be a dot-separated name") + + seen.append(parts.popleft()) + thing = importlib.import_module(seen[-1]) + while parts: + attr = parts.popleft() + seen.append(attr) + + if hasattr(thing, attr): + thing = getattr(thing, attr) + else: + thing = importlib.import_module(".".join(seen)) + + return thing diff --git a/tests/fix_mypy.py b/tests/fix_mypy.py new file mode 100644 index 0000000..b860a32 --- /dev/null +++ b/tests/fix_mypy.py @@ -0,0 +1,54 @@ +import re +import subprocess as sp + +import pytest + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "mypy: the test uses mypy (the marker is set automatically" + " on tests using the fixture)", + ) + + +def pytest_collection_modifyitems(items): + for item in items: + if "mypy" in item.fixturenames: + # add a mypy tag so we can address these tests only + 
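+            # e.g. run only these tests with `pytest -m mypy`, or skip them with `-m "not mypy"`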
item.add_marker(pytest.mark.mypy) + + # All the tests using mypy are slow + item.add_marker(pytest.mark.slow) + + +@pytest.fixture(scope="session") +def mypy(tmp_path_factory): + cache_dir = tmp_path_factory.mktemp(basename="mypy_cache") + src_dir = tmp_path_factory.mktemp("source") + + class MypyRunner: + def run_on_file(self, filename): + cmdline = f""" + mypy + --strict + --show-error-codes --no-color-output --no-error-summary + --config-file= --cache-dir={cache_dir} + """.split() + cmdline.append(filename) + return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT) + + def run_on_source(self, source): + fn = src_dir / "tmp.py" + with fn.open("w") as f: + f.write(source) + + return self.run_on_file(str(fn)) + + def get_revealed(self, line): + """return the type from an output of reveal_type""" + return re.sub( + r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line + ).replace("*", "") + + return MypyRunner() diff --git a/tests/fix_pq.py b/tests/fix_pq.py new file mode 100644 index 0000000..6811a26 --- /dev/null +++ b/tests/fix_pq.py @@ -0,0 +1,141 @@ +import os +import sys +import ctypes +from typing import Iterator, List, NamedTuple +from tempfile import TemporaryFile + +import pytest + +from .utils import check_libpq_version + +try: + from psycopg import pq +except ImportError: + pq = None # type: ignore + + +def pytest_report_header(config): + try: + from psycopg import pq + except ImportError: + return [] + + return [ + f"libpq wrapper implementation: {pq.__impl__}", + f"libpq used: {pq.version()}", + f"libpq compiled: {pq.__build_version__}", + ] + + +def pytest_configure(config): + # register libpq marker + config.addinivalue_line( + "markers", + "libpq(version_expr): run the test only with matching libpq" + " (e.g. '>= 10', '< 9.6')", + ) + + +def pytest_runtest_setup(item): + for m in item.iter_markers(name="libpq"): + assert len(m.args) == 1 + msg = check_libpq_version(pq.version(), m.args[0]) + if msg: + pytest.skip(msg) + + +@pytest.fixture +def libpq(): + """Return a ctypes wrapper to access the libpq.""" + try: + from psycopg.pq.misc import find_libpq_full_path + + # Not available when testing the binary package + libname = find_libpq_full_path() + assert libname, "libpq libname not found" + return ctypes.pydll.LoadLibrary(libname) + except Exception as e: + if pq.__impl__ == "binary": + pytest.skip(f"can't load libpq for testing: {e}") + else: + raise + + +@pytest.fixture +def setpgenv(monkeypatch): + """Replace the PG* env vars with the vars provided.""" + + def setpgenv_(env): + ks = [k for k in os.environ if k.startswith("PG")] + for k in ks: + monkeypatch.delenv(k) + + if env: + for k, v in env.items(): + monkeypatch.setenv(k, v) + + return setpgenv_ + + +@pytest.fixture +def trace(libpq): + pqver = pq.__build_version__ + if pqver < 140000: + pytest.skip(f"trace not available on libpq {pqver}") + if sys.platform != "linux": + pytest.skip(f"trace not available on {sys.platform}") + + yield Tracer() + + +class Tracer: + def trace(self, conn): + pgconn: "pq.abc.PGconn" + + if hasattr(conn, "exec_"): + pgconn = conn + elif hasattr(conn, "cursor"): + pgconn = conn.pgconn + else: + raise Exception() + + return TraceLog(pgconn) + + +class TraceLog: + def __init__(self, pgconn: "pq.abc.PGconn"): + self.pgconn = pgconn + self.tempfile = TemporaryFile(buffering=0) + pgconn.trace(self.tempfile.fileno()) + pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS) + + def __del__(self): + if self.pgconn.status == pq.ConnStatus.OK: + self.pgconn.untrace() + self.tempfile.close() + + def 
__iter__(self) -> "Iterator[TraceEntry]": + self.tempfile.seek(0) + data = self.tempfile.read() + for entry in self._parse_entries(data): + yield entry + + def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]": + for line in data.splitlines(): + direction, length, type, *content = line.split(b"\t") + yield TraceEntry( + direction=direction.decode(), + length=int(length.decode()), + type=type.decode(), + # Note: the items encoding is not very solid: no escaped + # backslash, no escaped quotes. + # At the moment we don't need a proper parser. + content=[content[0]] if content else [], + ) + + +class TraceEntry(NamedTuple): + direction: str + length: int + type: str + content: List[bytes] diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py new file mode 100644 index 0000000..e50f5ec --- /dev/null +++ b/tests/fix_proxy.py @@ -0,0 +1,127 @@ +import os +import time +import socket +import logging +import subprocess as sp +from shutil import which + +import pytest + +import psycopg +from psycopg import conninfo + + +def pytest_collection_modifyitems(items): + for item in items: + # TODO: there is a race condition on macOS and Windows in the CI: + # listen returns before really listening and tests based on 'deaf_port' + # fail 50% of the times. Just add the 'proxy' mark on these tests + # because they are already skipped in the CI. + if "proxy" in item.fixturenames or "deaf_port" in item.fixturenames: + item.add_marker(pytest.mark.proxy) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "proxy: the test uses pproxy (the marker is set automatically" + " on tests using the fixture)", + ) + + +@pytest.fixture +def proxy(dsn): + """Return a proxy to the --test-dsn database""" + p = Proxy(dsn) + yield p + p.stop() + + +@pytest.fixture +def deaf_port(dsn): + """Return a port number with a socket open but not answering""" + with socket.socket(socket.AF_INET) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + s.listen(0) + yield port + + +class Proxy: + """ + Proxy a Postgres service for testing purpose. + + Allow to lose connectivity and restart it using stop/start. 
+ """ + + def __init__(self, server_dsn): + cdict = conninfo.conninfo_to_dict(server_dsn) + + # Get server params + host = cdict.get("host") or os.environ.get("PGHOST") + self.server_host = host if host and not host.startswith("/") else "localhost" + self.server_port = cdict.get("port", "5432") + + # Get client params + self.client_host = "localhost" + self.client_port = self._get_random_port() + + # Make a connection string to the proxy + cdict["host"] = self.client_host + cdict["port"] = self.client_port + cdict["sslmode"] = "disable" # not supported by the proxy + self.client_dsn = conninfo.make_conninfo(**cdict) + + # The running proxy process + self.proc = None + + def start(self): + if self.proc: + logging.info("proxy already started") + return + + logging.info("starting proxy") + pproxy = which("pproxy") + if not pproxy: + raise ValueError("pproxy program not found") + cmdline = [pproxy, "--reuse"] + cmdline.extend(["-l", f"tunnel://:{self.client_port}"]) + cmdline.extend(["-r", f"tunnel://{self.server_host}:{self.server_port}"]) + + self.proc = sp.Popen(cmdline, stdout=sp.DEVNULL) + logging.info("proxy started") + self._wait_listen() + + # verify that the proxy works + try: + with psycopg.connect(self.client_dsn): + pass + except Exception as e: + pytest.fail(f"failed to create a working proxy: {e}") + + def stop(self): + if not self.proc: + return + + logging.info("stopping proxy") + self.proc.terminate() + self.proc.wait() + logging.info("proxy stopped") + self.proc = None + + @classmethod + def _get_random_port(cls): + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + def _wait_listen(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + for i in range(20): + if 0 == sock.connect_ex((self.client_host, self.client_port)): + break + time.sleep(0.1) + else: + raise ValueError("the proxy didn't start listening in time") + + logging.info("proxy listening") diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py new file mode 100644 index 0000000..80e0c62 --- /dev/null +++ b/tests/fix_psycopg.py @@ -0,0 +1,98 @@ +from copy import deepcopy + +import pytest + + +@pytest.fixture +def global_adapters(): + """Restore the global adapters after a test has changed them.""" + from psycopg import adapters + + dumpers = deepcopy(adapters._dumpers) + dumpers_by_oid = deepcopy(adapters._dumpers_by_oid) + loaders = deepcopy(adapters._loaders) + types = list(adapters.types) + + yield None + + adapters._dumpers = dumpers + adapters._dumpers_by_oid = dumpers_by_oid + adapters._loaders = loaders + adapters.types.clear() + for t in types: + adapters.types.add(t) + + +@pytest.fixture +@pytest.mark.crdb_skip("2-phase commit") +def tpc(svcconn): + tpc = Tpc(svcconn) + tpc.check_tpc() + tpc.clear_test_xacts() + tpc.make_test_table() + yield tpc + tpc.clear_test_xacts() + + +class Tpc: + """Helper object to test two-phase transactions""" + + def __init__(self, conn): + assert conn.autocommit + self.conn = conn + + def check_tpc(self): + from .fix_crdb import is_crdb, crdb_skip_message + + if is_crdb(self.conn): + pytest.skip(crdb_skip_message("2-phase commit")) + + val = int(self.conn.execute("show max_prepared_transactions").fetchone()[0]) + if not val: + pytest.skip("prepared transactions disabled in the database") + + def clear_test_xacts(self): + """Rollback all the prepared transaction in the testing db.""" + from psycopg import sql + + cur = self.conn.execute( + "select gid from pg_prepared_xacts where database = %s", + (self.conn.info.dbname,), + ) + 
gids = [r[0] for r in cur] + for gid in gids: + self.conn.execute(sql.SQL("rollback prepared {}").format(gid)) + + def make_test_table(self): + self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)") + self.conn.execute("TRUNCATE test_tpc") + + def count_xacts(self): + """Return the number of prepared xacts currently in the test db.""" + cur = self.conn.execute( + """ + select count(*) from pg_prepared_xacts + where database = %s""", + (self.conn.info.dbname,), + ) + return cur.fetchone()[0] + + def count_test_records(self): + """Return the number of records in the test table.""" + cur = self.conn.execute("select count(*) from test_tpc") + return cur.fetchone()[0] + + +@pytest.fixture(scope="module") +def generators(): + """Return the 'generators' module for selected psycopg implementation.""" + from psycopg import pq + + if pq.__impl__ == "c": + from psycopg._cmodule import _psycopg + + return _psycopg + else: + import psycopg.generators + + return psycopg.generators diff --git a/tests/pool/__init__.py b/tests/pool/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/pool/__init__.py diff --git a/tests/pool/fix_pool.py b/tests/pool/fix_pool.py new file mode 100644 index 0000000..12e4f39 --- /dev/null +++ b/tests/pool/fix_pool.py @@ -0,0 +1,12 @@ +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "pool: test related to the psycopg_pool package") + + +def pytest_collection_modifyitems(items): + # Add the pool markers to all the tests in the pool package + for item in items: + if "/pool/" in item.nodeid: + item.add_marker(pytest.mark.pool) diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py new file mode 100644 index 0000000..c0e8060 --- /dev/null +++ b/tests/pool/test_null_pool.py @@ -0,0 +1,896 @@ +import logging +from time import sleep, time +from threading import Thread, Event +from typing import Any, List, Tuple + +import pytest +from packaging.version import parse as ver # noqa: F401 # used in skipif + +import psycopg +from psycopg.pq import TransactionStatus + +from .test_pool import delay_connection, ensure_waiting + +try: + from psycopg_pool import NullConnectionPool + from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests +except ImportError: + pass + + +def test_defaults(dsn): + with NullConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 0 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +def test_min_size_max_size(dsn): + with NullConnectionPool(dsn, min_size=0, max_size=2) as p: + assert p.min_size == 0 + assert p.max_size == 2 + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + NullConnectionPool(min_size=min_size, max_size=max_size) + + +def test_connection_class(dsn): + class MyConn(psycopg.Connection[Any]): + pass + + with NullConnectionPool(dsn, connection_class=MyConn) as p: + with p.connection() as conn: + assert isinstance(conn, MyConn) + + +def test_kwargs(dsn): + with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p: + with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +def test_its_no_pool_at_all(dsn): + with NullConnectionPool(dsn, max_size=2) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + with p.connection() as conn: 
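+            # the null pool opens a fresh connection for every checkout, so this
+            # backend pid differs from both of the previous ones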
+ assert conn.info.backend_pid not in (pid1, pid2) + + +def test_context(dsn): + with NullConnectionPool(dsn) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.slow +@pytest.mark.timing +def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.2) + with pytest.raises(PoolTimeout): + with NullConnectionPool(dsn, num_workers=1) as p: + p.wait(0.1) + + with NullConnectionPool(dsn, num_workers=1) as p: + p.wait(0.4) + + +def test_wait_closed(dsn): + with NullConnectionPool(dsn) as p: + pass + + with pytest.raises(PoolClosed): + p.wait() + + +@pytest.mark.slow +def test_setup_no_timeout(dsn, proxy): + with pytest.raises(PoolTimeout): + with NullConnectionPool(proxy.client_dsn, num_workers=1) as p: + p.wait(0.2) + + with NullConnectionPool(proxy.client_dsn, num_workers=1) as p: + sleep(0.5) + assert not p._pool + proxy.start() + + with p.connection() as conn: + conn.execute("select 1") + + +def test_configure(dsn): + inits = 0 + + def configure(conn): + nonlocal inits + inits += 1 + with conn.transaction(): + conn.execute("set default_transaction_read_only to on") + + with NullConnectionPool(dsn, configure=configure) as p: + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + with p.connection() as conn: + assert inits == 2 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + conn.close() + + with p.connection() as conn: + assert inits == 3 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + conn.execute("select 1") + + with NullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + with conn.transaction(): + conn.execute("WAT") + + with NullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset(dsn): + resets = 0 + + def setup(conn): + with conn.transaction(): + conn.execute("set timezone to '+1:00'") + + def reset(conn): + nonlocal resets + resets += 1 + with conn.transaction(): + conn.execute("set timezone to utc") + + pids = [] + + def worker(): + with p.connection() as conn: + assert resets == 1 + with conn.execute("show timezone") as cur: + assert cur.fetchone() == ("UTC",) + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + # Queue the worker so it will take the same connection a second time + # instead of making a new one. 
+ t = Thread(target=worker) + t.start() + ensure_waiting(p) + + assert resets == 0 + conn.execute("set timezone to '+2:00'") + pids.append(conn.info.backend_pid) + + t.join() + p.wait() + + assert resets == 1 + assert pids[0] == pids[1] + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + conn.execute("reset all") + + pids = [] + + def worker(): + with p.connection() as conn: + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + t.join() + + assert pids[0] != pids[1] + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + with conn.transaction(): + conn.execute("WAT") + + pids = [] + + def worker(): + with p.connection() as conn: + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + t.join() + + assert pids[0] != pids[1] + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')") +def test_no_queue_timeout(deaf_port): + with NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p: + with pytest.raises(PoolTimeout): + with p.connection(timeout=1): + pass + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue(dsn): + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + with NullConnectionPool(dsn, max_size=2) as p: + p.wait() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +def test_queue_size(dsn): + def worker(t, ev=None): + try: + with p.connection(): + if ev: + ev.set() + sleep(t) + except TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p: + p.wait() + ev = Event() + t = Thread(target=worker, args=(0.3, ev)) + t.start() + ev.wait() + + ts = [Thread(target=worker, args=(0.1,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout(dsn): + def worker(n): + t0 = time() + try: + with p.connection() as conn: 
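+                # hold the connection for 0.2s: with max_size=2 and timeout=0.1,
+                # two of the four workers time out while waiting in the queue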
+ conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +def test_dead_client(dsn): + def worker(i, timeout): + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.3)") + results.append(i) + except PoolTimeout: + if timeout > 0.2: + raise + + results: List[int] = [] + + with NullConnectionPool(dsn, max_size=2) as p: + ts = [ + Thread(target=worker, args=(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout_override(dsn): + def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +def test_broken_reconnect(dsn): + with NullConnectionPool(dsn, max_size=1) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + conn.close() + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + assert not conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ).fetchone() + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. 
+ t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + p.putconn(conn) + t.join() + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + t.join() + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + p.putconn(conn) + t.join() + + assert pids[0] != pids[1] + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(p): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + def bad_rollback(): + conn.pgconn.finish() + orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + t = Thread(target=worker, args=(p,)) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + t.join() + + assert pids[0] != pids[1] + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +def test_close_no_threads(dsn): + p = NullConnectionPool(dsn) + assert p._sched_runner and p._sched_runner.is_alive() + workers = p._workers[:] + assert workers + for t in workers: + assert t.is_alive() + + p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert not t.is_alive() + + +def test_putconn_no_pool(conn_cls, dsn): + with NullConnectionPool(dsn) as p: + conn = 
conn_cls.connect(dsn) + with pytest.raises(ValueError): + p.putconn(conn) + + conn.close() + + +def test_putconn_wrong_pool(dsn): + with NullConnectionPool(dsn) as p1: + with NullConnectionPool(dsn) as p2: + conn = p1.getconn() + with pytest.raises(ValueError): + p2.putconn(conn) + + +@pytest.mark.slow +def test_del_stop_threads(dsn): + p = NullConnectionPool(dsn) + assert p._sched_runner is not None + ts = [p._sched_runner] + p._workers + del p + sleep(0.1) + for t in ts: + assert not t.is_alive() + + +def test_closed_getconn(dsn): + p = NullConnectionPool(dsn) + assert not p.closed + with p.connection(): + pass + + p.close() + assert p.closed + + with pytest.raises(PoolClosed): + with p.connection(): + pass + + +def test_closed_putconn(dsn): + p = NullConnectionPool(dsn) + + with p.connection() as conn: + pass + assert conn.closed + + with p.connection() as conn: + p.close() + assert conn.closed + + +def test_closed_queue(dsn): + def w1(): + with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + def w2(): + try: + with p.connection(): + pass # unexpected + except PoolClosed: + success.append("w2") + + e1 = Event() + e2 = Event() + + p = NullConnectionPool(dsn, max_size=1) + p.wait() + success: List[str] = [] + + t1 = Thread(target=w1) + t1.start() + # Wait until w1 has received a connection + e1.wait() + + t2 = Thread(target=w2) + t2.start() + # Wait until w2 is in the queue + ensure_waiting(p) + + p.close(0) + + # Wait for the workers to finish + e2.set() + t1.join() + t2.join() + assert len(success) == 2 + + +def test_open_explicit(dsn): + p = NullConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(PoolClosed, match="is not open yet"): + p.getconn() + + with pytest.raises(PoolClosed): + with p.connection(): + pass + + p.open() + try: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + with pytest.raises(PoolClosed, match="is already closed"): + p.getconn() + + +def test_open_context(dsn): + p = NullConnectionPool(dsn, open=False) + assert p.closed + + with p: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + assert p.closed + + +def test_open_no_op(dsn): + p = NullConnectionPool(dsn) + try: + assert not p.closed + p.open() + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + +def test_reopen(dsn): + p = NullConnectionPool(dsn) + with p.connection() as conn: + conn.execute("select 1") + p.close() + assert p._sched_runner is None + assert not p._workers + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + p.open() + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +def test_bad_resize(dsn, min_size, max_size): + with NullConnectionPool() as p: + with pytest.raises(ValueError): + p.resize(min_size=min_size, max_size=max_size) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_max_lifetime(dsn): + pids = [] + + def worker(p): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + sleep(0.1) + + ts = [] + with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p: + for i in range(5): + ts.append(Thread(target=worker, args=(p,))) + 
ts[-1].start() + + for t in ts: + t.join() + + assert pids[0] == pids[1] != pids[4], pids + + +def test_check(dsn): + with NullConnectionPool(dsn) as p: + # No-op + p.check() + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_measures(dsn): + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + + with NullConnectionPool(dsn, max_size=4) as p: + p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 0 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + ts = [Thread(target=worker, args=(i,)) for i in range(3)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + p.wait(2.0) + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_usage(dsn): + def worker(n): + try: + with p.connection(timeout=0.3) as conn: + conn.execute("select pg_sleep(0.2)") + except PoolTimeout: + pass + + with NullConnectionPool(dsn, max_size=3) as p: + p.wait(2.0) + + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + for t in ts: + t.join() + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + with p.connection() as conn: + conn.close() + p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + with NullConnectionPool(proxy.client_dsn, max_size=3) as p: + p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 1 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 200 <= stats["connections_ms"] < 300 diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py new file mode 100644 index 0000000..23a1a52 --- /dev/null +++ b/tests/pool/test_null_pool_async.py @@ -0,0 +1,844 @@ +import asyncio +import logging +from time import time +from typing import Any, List, Tuple + +import pytest +from packaging.version import parse as ver # noqa: F401 # used in skipif + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import create_task +from .test_pool_async import delay_connection, ensure_waiting + +pytestmark = [pytest.mark.asyncio] + +try: + from psycopg_pool import AsyncNullConnectionPool # noqa: F401 + from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests +except ImportError: + pass + + +async def test_defaults(dsn): + async with AsyncNullConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 0 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 
3 + + +async def test_min_size_max_size(dsn): + async with AsyncNullConnectionPool(dsn, min_size=0, max_size=2) as p: + assert p.min_size == 0 + assert p.max_size == 2 + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +async def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + AsyncNullConnectionPool(min_size=min_size, max_size=max_size) + + +async def test_connection_class(dsn): + class MyConn(psycopg.AsyncConnection[Any]): + pass + + async with AsyncNullConnectionPool(dsn, connection_class=MyConn) as p: + async with p.connection() as conn: + assert isinstance(conn, MyConn) + + +async def test_kwargs(dsn): + async with AsyncNullConnectionPool(dsn, kwargs={"autocommit": True}) as p: + async with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +async def test_its_no_pool_at_all(dsn): + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + async with p.connection() as conn: + assert conn.info.backend_pid not in (pid1, pid2) + + +async def test_context(dsn): + async with AsyncNullConnectionPool(dsn) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.slow +@pytest.mark.timing +async def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.2) + with pytest.raises(PoolTimeout): + async with AsyncNullConnectionPool(dsn, num_workers=1) as p: + await p.wait(0.1) + + async with AsyncNullConnectionPool(dsn, num_workers=1) as p: + await p.wait(0.4) + + +async def test_wait_closed(dsn): + async with AsyncNullConnectionPool(dsn) as p: + pass + + with pytest.raises(PoolClosed): + await p.wait() + + +@pytest.mark.slow +async def test_setup_no_timeout(dsn, proxy): + with pytest.raises(PoolTimeout): + async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p: + await p.wait(0.2) + + async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p: + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() + + async with p.connection() as conn: + await conn.execute("select 1") + + +async def test_configure(dsn): + inits = 0 + + async def configure(conn): + nonlocal inits + inits += 1 + async with conn.transaction(): + await conn.execute("set default_transaction_read_only to on") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + async with p.connection() as conn: + assert inits == 2 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + await conn.close() + + async with p.connection() as conn: + assert inits == 3 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +async def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + await conn.execute("select 1") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +async def test_configure_broken(dsn, caplog): + 
caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset(dsn): + resets = 0 + + async def setup(conn): + async with conn.transaction(): + await conn.execute("set timezone to '+1:00'") + + async def reset(conn): + nonlocal resets + resets += 1 + async with conn.transaction(): + await conn.execute("set timezone to utc") + + pids = [] + + async def worker(): + async with p.connection() as conn: + assert resets == 1 + cur = await conn.execute("show timezone") + assert (await cur.fetchone()) == ("UTC",) + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + # Queue the worker so it will take the same connection a second time + # instead of making a new one. + t = create_task(worker()) + await ensure_waiting(p) + + assert resets == 0 + await conn.execute("set timezone to '+2:00'") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + await p.wait() + + assert resets == 1 + assert pids[0] == pids[1] + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + await conn.execute("reset all") + + pids = [] + + async def worker(): + async with p.connection() as conn: + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + t = create_task(worker()) + await ensure_waiting(p) + + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + async with conn.transaction(): + await conn.execute("WAT") + + pids = [] + + async def worker(): + async with p.connection() as conn: + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + t = create_task(worker()) + await ensure_waiting(p) + + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')") +async def test_no_queue_timeout(deaf_port): + async with AsyncNullConnectionPool( + kwargs={"host": "localhost", "port": deaf_port} + ) as p: + with pytest.raises(PoolTimeout): + async with p.connection(timeout=1): + pass + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue(dsn): + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + async with AsyncNullConnectionPool(dsn, 
max_size=2) as p: + await p.wait() + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +async def test_queue_size(dsn): + async def worker(t, ev=None): + try: + async with p.connection(): + if ev: + ev.set() + await asyncio.sleep(t) + except TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + async with AsyncNullConnectionPool(dsn, max_size=1, max_waiting=3) as p: + await p.wait() + ev = asyncio.Event() + create_task(worker(0.3, ev)) + await ev.wait() + + ts = [create_task(worker(0.1)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout(dsn): + async def worker(n): + t0 = time() + try: + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_dead_client(dsn): + async def worker(i, timeout): + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.3)") + results.append(i) + except PoolTimeout: + if timeout > 0.2: + raise + + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + results: List[int] = [] + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) + + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout_override(dsn): + async def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +async def test_broken_reconnect(dsn): + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + await conn.close() + + async with p.connection() as 
conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + cur = await conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + await conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +async def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + t = create_task(worker()) + await ensure_waiting(p) + + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + 
await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +async def test_close_no_tasks(dsn): + p = AsyncNullConnectionPool(dsn) + assert p._sched_runner and not p._sched_runner.done() + assert p._workers + workers = p._workers[:] + for t in workers: + assert not t.done() + + await p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert t.done() + + +async def test_putconn_no_pool(aconn_cls, dsn): + async with AsyncNullConnectionPool(dsn) as p: + conn = await aconn_cls.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) + + await conn.close() + + +async def test_putconn_wrong_pool(dsn): + async with AsyncNullConnectionPool(dsn) as p1: + async with AsyncNullConnectionPool(dsn) as p2: + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) + + +async def test_closed_getconn(dsn): + p = AsyncNullConnectionPool(dsn) + assert not p.closed + async with p.connection(): + pass + + await p.close() + assert p.closed + + with pytest.raises(PoolClosed): + async with p.connection(): + pass + + +async def test_closed_putconn(dsn): + p = AsyncNullConnectionPool(dsn) + + async with p.connection() as conn: + pass + assert conn.closed + + async with p.connection() as conn: + await p.close() + assert conn.closed + + +async def test_closed_queue(dsn): + async def w1(): + async with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + await e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + async def w2(): + try: + async with p.connection(): + pass # unexpected + except PoolClosed: + success.append("w2") + + e1 = asyncio.Event() + e2 = asyncio.Event() + + p = AsyncNullConnectionPool(dsn, max_size=1) + await p.wait() + success: List[str] = [] + + t1 = create_task(w1()) + # Wait until w1 has received a connection + await e1.wait() + + t2 = create_task(w2()) + # Wait until w2 is in the queue + await ensure_waiting(p) + await p.close() + + # Wait for the workers to finish + e2.set() + await asyncio.gather(t1, t2) + assert len(success) == 2 + + +async def test_open_explicit(dsn): + p = AsyncNullConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(PoolClosed): + await p.getconn() + + with pytest.raises(PoolClosed, match="is not open yet"): + async with p.connection(): + pass + + await p.open() + try: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + with pytest.raises(PoolClosed, match="is already closed"): + await p.getconn() + + +async def test_open_context(dsn): + p = AsyncNullConnectionPool(dsn, open=False) + assert p.closed + + async with p: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + assert p.closed + + +async def test_open_no_op(dsn): + p = AsyncNullConnectionPool(dsn) + try: + assert not p.closed + await p.open() + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() 
+ + +async def test_reopen(dsn): + p = AsyncNullConnectionPool(dsn) + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + assert p._sched_runner is None + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + await p.open() + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +async def test_bad_resize(dsn, min_size, max_size): + async with AsyncNullConnectionPool() as p: + with pytest.raises(ValueError): + await p.resize(min_size=min_size, max_size=max_size) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_max_lifetime(dsn): + pids: List[int] = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + await asyncio.sleep(0.1) + + async with AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p: + ts = [create_task(worker()) for i in range(5)] + await asyncio.gather(*ts) + + assert pids[0] == pids[1] != pids[4], pids + + +async def test_check(dsn): + # no.op + async with AsyncNullConnectionPool(dsn) as p: + await p.check() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_measures(dsn): + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + + async with AsyncNullConnectionPool(dsn, max_size=4) as p: + await p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 0 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + ts = [create_task(worker(i)) for i in range(3)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + await p.wait(2.0) + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_usage(dsn): + async def worker(n): + try: + async with p.connection(timeout=0.3) as conn: + await conn.execute("select pg_sleep(0.2)") + except PoolTimeout: + pass + + async with AsyncNullConnectionPool(dsn, max_size=3) as p: + await p.wait(2.0) + + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.gather(*ts) + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + async with p.connection() as conn: + await conn.close() + await p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + async with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +async def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + async with AsyncNullConnectionPool(proxy.client_dsn, max_size=3) as p: + await p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 1 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 200 <= 
stats["connections_ms"] < 300 diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py new file mode 100644 index 0000000..30c790b --- /dev/null +++ b/tests/pool/test_pool.py @@ -0,0 +1,1265 @@ +import logging +import weakref +from time import sleep, time +from threading import Thread, Event +from typing import Any, List, Tuple + +import pytest + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import Counter + +try: + import psycopg_pool as pool +except ImportError: + # Tests should have been skipped if the package is not available + pass + + +def test_package_version(mypy): + cp = mypy.run_on_source( + """\ +from psycopg_pool import __version__ +assert __version__ +""" + ) + assert not cp.stdout + + +def test_defaults(dsn): + with pool.ConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 4 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)]) +def test_min_size_max_size(dsn, min_size, max_size): + with pool.ConnectionPool(dsn, min_size=min_size, max_size=max_size) as p: + assert p.min_size == min_size + assert p.max_size == max_size if max_size is not None else min_size + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]) +def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + pool.ConnectionPool(min_size=min_size, max_size=max_size) + + +def test_connection_class(dsn): + class MyConn(psycopg.Connection[Any]): + pass + + with pool.ConnectionPool(dsn, connection_class=MyConn, min_size=1) as p: + with p.connection() as conn: + assert isinstance(conn, MyConn) + + +def test_kwargs(dsn): + with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, min_size=1) as p: + with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +def test_its_really_a_pool(dsn): + with pool.ConnectionPool(dsn, min_size=2) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + with p.connection() as conn: + assert conn.info.backend_pid in (pid1, pid2) + + +def test_context(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.crdb_skip("backend pid") +def test_connection_not_lost(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + with pytest.raises(ZeroDivisionError): + with p.connection() as conn: + pid = conn.info.backend_pid + 1 / 0 + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + + +@pytest.mark.slow +@pytest.mark.timing +def test_concurrent_filling(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + + def add_time(self, conn): + times.append(time() - t0) + add_orig(self, conn) + + add_orig = pool.ConnectionPool._add_to_pool + monkeypatch.setattr(pool.ConnectionPool, "_add_to_pool", add_time) + + times: List[float] = [] + t0 = time() + + with pool.ConnectionPool(dsn, min_size=5, num_workers=2) as p: + p.wait(1.0) + want_times = [0.1, 0.1, 0.2, 0.2, 0.3] + assert len(times) == len(want_times) + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.wait(0.3) + + with 
pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.wait(0.5) + + with pool.ConnectionPool(dsn, min_size=4, num_workers=2) as p: + p.wait(0.3) + p.wait(0.0001) # idempotent + + +def test_wait_closed(dsn): + with pool.ConnectionPool(dsn) as p: + pass + + with pytest.raises(pool.PoolClosed): + p.wait() + + +@pytest.mark.slow +def test_setup_no_timeout(dsn, proxy): + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p: + p.wait(0.2) + + with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p: + sleep(0.5) + assert not p._pool + proxy.start() + + with p.connection() as conn: + conn.execute("select 1") + + +def test_configure(dsn): + inits = 0 + + def configure(conn): + nonlocal inits + inits += 1 + with conn.transaction(): + conn.execute("set default_transaction_read_only to on") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + p.wait() + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + conn.close() + + with p.connection() as conn: + assert inits == 2 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + conn.execute("select 1") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + with conn.transaction(): + conn.execute("WAT") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +def test_reset(dsn): + resets = 0 + + def setup(conn): + with conn.transaction(): + conn.execute("set timezone to '+1:00'") + + def reset(conn): + nonlocal resets + resets += 1 + with conn.transaction(): + conn.execute("set timezone to utc") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + assert resets == 0 + conn.execute("set timezone to '+2:00'") + + p.wait() + assert resets == 1 + + with p.connection() as conn: + with conn.execute("show timezone") as cur: + assert cur.fetchone() == ("UTC",) + + p.wait() + assert resets == 2 + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + conn.execute("reset all") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + conn.execute("select 1") + pid1 = conn.info.backend_pid + + with p.connection() as conn: + conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + 
with conn.transaction(): + conn.execute("WAT") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + conn.execute("select 1") + pid1 = conn.info.backend_pid + + with p.connection() as conn: + conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue(dsn): + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + with pool.ConnectionPool(dsn, min_size=2) as p: + p.wait() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +def test_queue_size(dsn): + def worker(t, ev=None): + try: + with p.connection(): + if ev: + ev.set() + sleep(t) + except pool.TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + with pool.ConnectionPool(dsn, min_size=1, max_waiting=3) as p: + p.wait() + ev = Event() + t = Thread(target=worker, args=(0.3, ev)) + t.start() + ev.wait() + + ts = [Thread(target=worker, args=(0.1,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], pool.TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout(dsn): + def worker(n): + t0 = time() + try: + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +def test_dead_client(dsn): + def worker(i, timeout): + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.3)") + results.append(i) + except pool.PoolTimeout: + if timeout > 0.2: + raise + + results: List[int] = [] + + with pool.ConnectionPool(dsn, min_size=2) as p: + ts = [ + Thread(target=worker, args=(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout_override(dsn): + def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + with p.connection(timeout=timeout) as conn: + 
conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +def test_broken_reconnect(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + conn.close() + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + assert not conn2.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ).fetchone() + + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + + def bad_rollback(): + conn.pgconn.finish() + orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + + with p.connection() as conn2: + 
assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +def test_close_no_threads(dsn): + p = pool.ConnectionPool(dsn) + assert p._sched_runner and p._sched_runner.is_alive() + workers = p._workers[:] + assert workers + for t in workers: + assert t.is_alive() + + p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert not t.is_alive() + + +def test_putconn_no_pool(conn_cls, dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = conn_cls.connect(dsn) + with pytest.raises(ValueError): + p.putconn(conn) + + conn.close() + + +def test_putconn_wrong_pool(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p1: + with pool.ConnectionPool(dsn, min_size=1) as p2: + conn = p1.getconn() + with pytest.raises(ValueError): + p2.putconn(conn) + + +def test_del_no_warning(dsn, recwarn): + p = pool.ConnectionPool(dsn, min_size=2) + with p.connection() as conn: + conn.execute("select 1") + + p.wait() + ref = weakref.ref(p) + del p + assert not ref() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +@pytest.mark.slow +def test_del_stop_threads(dsn): + p = pool.ConnectionPool(dsn) + assert p._sched_runner is not None + ts = [p._sched_runner] + p._workers + del p + sleep(0.1) + for t in ts: + assert not t.is_alive() + + +def test_closed_getconn(dsn): + p = pool.ConnectionPool(dsn, min_size=1) + assert not p.closed + with p.connection(): + pass + + p.close() + assert p.closed + + with pytest.raises(pool.PoolClosed): + with p.connection(): + pass + + +def test_closed_putconn(dsn): + p = pool.ConnectionPool(dsn, min_size=1) + + with p.connection() as conn: + pass + assert not conn.closed + + with p.connection() as conn: + p.close() + assert conn.closed + + +def test_closed_queue(dsn): + def w1(): + with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + def w2(): + try: + with p.connection(): + pass # unexpected + except pool.PoolClosed: + success.append("w2") + + e1 = Event() + e2 = Event() + + p = pool.ConnectionPool(dsn, min_size=1) + p.wait() + success: List[str] = [] + + t1 = Thread(target=w1) + t1.start() + # Wait until w1 has received a connection + e1.wait() + + t2 = Thread(target=w2) + t2.start() + # Wait until w2 is in the queue + ensure_waiting(p) + + p.close(0) + + # Wait for the workers to finish + e2.set() + t1.join() + t2.join() + assert len(success) == 2 + + +def test_open_explicit(dsn): + p = pool.ConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(pool.PoolClosed, match="is not open yet"): + p.getconn() + + with pytest.raises(pool.PoolClosed): + with p.connection(): + pass + + p.open() + try: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + with pytest.raises(pool.PoolClosed, match="is already closed"): + p.getconn() + + +def test_open_context(dsn): + p = pool.ConnectionPool(dsn, open=False) + assert p.closed + + with p: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + assert p.closed + + +def test_open_no_op(dsn): + p = pool.ConnectionPool(dsn) + try: + 
assert not p.closed + p.open() + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + +@pytest.mark.slow +@pytest.mark.timing +def test_open_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + p.open(wait=True, timeout=0.3) + finally: + p.close() + + p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + p.open(wait=True, timeout=0.5) + finally: + p.close() + + +@pytest.mark.slow +@pytest.mark.timing +def test_open_as_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.open(wait=True, timeout=0.3) + + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.open(wait=True, timeout=0.5) + + +def test_reopen(dsn): + p = pool.ConnectionPool(dsn) + with p.connection() as conn: + conn.execute("select 1") + p.close() + assert p._sched_runner is None + assert not p._workers + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + p.open() + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.parametrize( + "min_size, want_times", + [ + (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]), + (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]), + ], +) +def test_grow(dsn, monkeypatch, min_size, want_times): + delay_connection(monkeypatch, 0.1) + + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select 1 from pg_sleep(0.25)") + t1 = time() + results.append((n, t1 - t0)) + + with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p: + p.wait(1.0) + results: List[Tuple[int, float]] = [] + + ts = [Thread(target=worker, args=(i,)) for i in range(len(want_times))] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +def test_shrink(dsn, monkeypatch): + + from psycopg_pool.pool import ShrinkPool + + results: List[Tuple[int, int]] = [] + + def run_hacked(self, pool): + n0 = pool._nconns + orig_run(self, pool) + n1 = pool._nconns + results.append((n0, n1)) + + orig_run = ShrinkPool._run + monkeypatch.setattr(ShrinkPool, "_run", run_hacked) + + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.1)") + + with pool.ConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p: + p.wait(5.0) + assert p.max_idle == 0.2 + + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(1) + + assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] + + +@pytest.mark.slow +def test_reconnect(proxy, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0 + assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1 + monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1) + monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) + + caplog.clear() + proxy.start() + with pool.ConnectionPool(proxy.client_dsn, min_size=1) as p: + p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + with p.connection() as conn: + conn.execute("select 1") + + sleep(1.0) + proxy.start() + p.wait() + + 
with p.connection() as conn: + conn.execute("select 1") + + assert "BAD" in caplog.messages[0] + times = [rec.created for rec in caplog.records] + assert times[1] - times[0] < 0.05 + deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)] + assert len(deltas) == 3 + want = 0.1 + for delta in deltas: + assert delta == pytest.approx(want, 0.05), deltas + want *= 2 + + +@pytest.mark.slow +@pytest.mark.timing +def test_reconnect_failure(proxy): + proxy.start() + + t1 = None + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + with pool.ConnectionPool( + proxy.client_dsn, + name="this-one", + min_size=1, + reconnect_timeout=1.0, + reconnect_failed=failed, + ) as p: + p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + with p.connection() as conn: + conn.execute("select 1") + + t0 = time() + sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 + + proxy.start() + t0 = time() + with p.connection() as conn: + conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 + + +@pytest.mark.slow +def test_reconnect_after_grow_failed(proxy): + # Retry reconnection after a failed connection attempt has put the pool + # in grow mode. See issue #370. + proxy.stop() + + ev = Event() + + def failed(pool): + ev.set() + + with pool.ConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + assert ev.wait(timeout=2) + + with pytest.raises(pool.PoolTimeout): + with p.connection(timeout=0.5) as conn: + pass + + ev.clear() + assert ev.wait(timeout=2) + + proxy.start() + + with p.connection(timeout=2) as conn: + conn.execute("select 1") + + p.wait(timeout=3.0) + assert len(p._pool) == p.min_size == 4 + + +@pytest.mark.slow +def test_refill_on_check(proxy): + proxy.start() + ev = Event() + + def failed(pool): + ev.set() + + with pool.ConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + # The pool is full + p.wait(timeout=2) + + # Break all the connection + proxy.stop() + + # Checking the pool will empty it + p.check() + assert ev.wait(timeout=2) + assert len(p._pool) == 0 + + # Allow to connect again + proxy.start() + + # Make sure that check has refilled the pool + p.check() + p.wait(timeout=2) + assert len(p._pool) == 4 + + +@pytest.mark.slow +def test_uniform_use(dsn): + with pool.ConnectionPool(dsn, min_size=4) as p: + counts = Counter[int]() + for i in range(8): + with p.connection() as conn: + sleep(0.1) + counts[id(conn)] += 1 + + assert len(counts) == 4 + assert set(counts.values()) == set([2]) + + +@pytest.mark.slow +@pytest.mark.timing +def test_resize(dsn): + def sampler(): + sleep(0.05) # ensure sampling happens after shrink check + while True: + sleep(0.2) + if p.closed: + break + size.append(len(p._pool)) + + def client(t): + with p.connection() as conn: + conn.execute("select pg_sleep(%s)", [t]) + + size: List[int] = [] + + with pool.ConnectionPool(dsn, min_size=2, max_idle=0.2) as p: + s = Thread(target=sampler) + s.start() + + sleep(0.3) + c = Thread(target=client, args=(0.4,)) + c.start() + + sleep(0.2) + p.resize(4) + assert p.min_size == 4 + assert p.max_size == 4 + + sleep(0.4) + p.resize(2) + assert p.min_size == 2 + assert p.max_size == 2 + + sleep(0.6) + + s.join() + assert size == [2, 1, 3, 4, 3, 2, 2] + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)]) +def test_bad_resize(dsn, min_size, max_size): + with pool.ConnectionPool() as p: + with 
pytest.raises(ValueError): + p.resize(min_size=min_size, max_size=max_size) + + +def test_jitter(): + rnds = [pool.ConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)] + assert 27 <= min(rnds) <= 28 + assert 35 < max(rnds) < 36 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_max_lifetime(dsn): + with pool.ConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p: + sleep(0.1) + pids = [] + for i in range(5): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + sleep(0.2) + + assert pids[0] == pids[1] != pids[4], pids + + +@pytest.mark.crdb_skip("backend pid") +def test_check(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + with pool.ConnectionPool(dsn, min_size=4) as p: + p.wait(1.0) + with p.connection() as conn: + pid = conn.info.backend_pid + + p.wait(1.0) + pids = set(conn.info.backend_pid for conn in p._pool) + assert pid in pids + conn.close() + + assert len(caplog.records) == 0 + p.check() + assert len(caplog.records) == 1 + p.wait(1.0) + pids2 = set(conn.info.backend_pid for conn in p._pool) + assert len(pids & pids2) == 3 + assert pid not in pids2 + + +def test_check_idle(dsn): + with pool.ConnectionPool(dsn, min_size=2) as p: + p.wait(1.0) + p.check() + with p.connection() as conn: + assert conn.info.transaction_status == TransactionStatus.IDLE + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_measures(dsn): + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + + with pool.ConnectionPool(dsn, min_size=2, max_size=4) as p: + p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 2 + assert stats["pool_available"] == 2 + assert stats["requests_waiting"] == 0 + + ts = [Thread(target=worker, args=(i,)) for i in range(3)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + p.wait(2.0) + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_usage(dsn): + def worker(n): + try: + with p.connection(timeout=0.3) as conn: + conn.execute("select pg_sleep(0.2)") + except pool.PoolTimeout: + pass + + with pool.ConnectionPool(dsn, min_size=3) as p: + p.wait(2.0) + + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + for t in ts: + t.join() + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + with p.connection() as conn: + conn.close() + p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + with pool.ConnectionPool(proxy.client_dsn, min_size=3) as 
p: + p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 3 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 600 <= stats["connections_ms"] < 1200 + + proxy.stop() + p.check() + sleep(0.1) + stats = p.get_stats() + assert stats["connections_num"] > 3 + assert stats["connections_errors"] > 0 + assert stats["connections_lost"] == 3 + + +@pytest.mark.slow +def test_spike(dsn, monkeypatch): + # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/ + # documents/Welcome-To-The-Jungle.md + delay_connection(monkeypatch, 0.15) + + def worker(): + with p.connection(): + sleep(0.002) + + with pool.ConnectionPool(dsn, min_size=5, max_size=10) as p: + p.wait() + + ts = [Thread(target=worker) for i in range(50)] + for t in ts: + t.start() + for t in ts: + t.join() + p.wait() + + assert len(p._pool) < 7 + + +def test_debug_deadlock(dsn): + # https://github.com/psycopg/psycopg/issues/230 + logger = logging.getLogger("psycopg") + handler = logging.StreamHandler() + old_level = logger.level + logger.setLevel(logging.DEBUG) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + try: + with pool.ConnectionPool(dsn, min_size=4, open=True) as p: + try: + p.wait(timeout=2) + finally: + print(p.get_stats()) + finally: + logger.removeHandler(handler) + logger.setLevel(old_level) + + +def delay_connection(monkeypatch, sec): + """ + Return a _connect_gen function delayed by the amount of seconds + """ + + def connect_delay(*args, **kwargs): + t0 = time() + rv = connect_orig(*args, **kwargs) + t1 = time() + sleep(max(0, sec - (t1 - t0))) + return rv + + connect_orig = psycopg.Connection.connect + monkeypatch.setattr(psycopg.Connection, "connect", connect_delay) + + +def ensure_waiting(p, num=1): + """ + Wait until there are at least *num* clients waiting in the queue. 
+ """ + while len(p._waiting) < num: + sleep(0) diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py new file mode 100644 index 0000000..286a775 --- /dev/null +++ b/tests/pool/test_pool_async.py @@ -0,0 +1,1198 @@ +import asyncio +import logging +from time import time +from typing import Any, List, Tuple + +import pytest + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import create_task, Counter + +try: + import psycopg_pool as pool +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.asyncio] + + +async def test_defaults(dsn): + async with pool.AsyncConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 4 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)]) +async def test_min_size_max_size(dsn, min_size, max_size): + async with pool.AsyncConnectionPool(dsn, min_size=min_size, max_size=max_size) as p: + assert p.min_size == min_size + assert p.max_size == max_size if max_size is not None else min_size + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]) +async def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + pool.AsyncConnectionPool(min_size=min_size, max_size=max_size) + + +async def test_connection_class(dsn): + class MyConn(psycopg.AsyncConnection[Any]): + pass + + async with pool.AsyncConnectionPool(dsn, connection_class=MyConn, min_size=1) as p: + async with p.connection() as conn: + assert isinstance(conn, MyConn) + + +async def test_kwargs(dsn): + async with pool.AsyncConnectionPool( + dsn, kwargs={"autocommit": True}, min_size=1 + ) as p: + async with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +async def test_its_really_a_pool(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + async with p.connection() as conn: + assert conn.info.backend_pid in (pid1, pid2) + + +async def test_context(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.crdb_skip("backend pid") +async def test_connection_not_lost(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + with pytest.raises(ZeroDivisionError): + async with p.connection() as conn: + pid = conn.info.backend_pid + 1 / 0 + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + + +@pytest.mark.slow +@pytest.mark.timing +async def test_concurrent_filling(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + + async def add_time(self, conn): + times.append(time() - t0) + await add_orig(self, conn) + + add_orig = pool.AsyncConnectionPool._add_to_pool + monkeypatch.setattr(pool.AsyncConnectionPool, "_add_to_pool", add_time) + + times: List[float] = [] + t0 = time() + + async with pool.AsyncConnectionPool(dsn, min_size=5, num_workers=2) as p: + await p.wait(1.0) + want_times = [0.1, 0.1, 0.2, 0.2, 0.3] + assert len(times) == len(want_times) + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +async def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with 
pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.wait(0.3) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.wait(0.5) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=2) as p: + await p.wait(0.3) + await p.wait(0.0001) # idempotent + + +async def test_wait_closed(dsn): + async with pool.AsyncConnectionPool(dsn) as p: + pass + + with pytest.raises(pool.PoolClosed): + await p.wait() + + +@pytest.mark.slow +async def test_setup_no_timeout(dsn, proxy): + with pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=1, num_workers=1 + ) as p: + await p.wait(0.2) + + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=1, num_workers=1 + ) as p: + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() + + async with p.connection() as conn: + await conn.execute("select 1") + + +async def test_configure(dsn): + inits = 0 + + async def configure(conn): + nonlocal inits + inits += 1 + async with conn.transaction(): + await conn.execute("set default_transaction_read_only to on") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + await p.wait(timeout=1.0) + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + await conn.close() + + async with p.connection() as conn: + assert inits == 2 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +async def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + await conn.execute("select 1") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +async def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +async def test_reset(dsn): + resets = 0 + + async def setup(conn): + async with conn.transaction(): + await conn.execute("set timezone to '+1:00'") + + async def reset(conn): + nonlocal resets + resets += 1 + async with conn.transaction(): + await conn.execute("set timezone to utc") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + assert resets == 0 + await conn.execute("set timezone to '+2:00'") + + await p.wait() + assert resets == 1 + + async with p.connection() as conn: + cur = await conn.execute("show timezone") + assert (await cur.fetchone()) == ("UTC",) + + await p.wait() + assert resets == 2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, 
logger="psycopg.pool") + + async def reset(conn): + await conn.execute("reset all") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + await conn.execute("select 1") + pid1 = conn.info.backend_pid + + async with p.connection() as conn: + await conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + await conn.execute("select 1") + pid1 = conn.info.backend_pid + + async with p.connection() as conn: + await conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue(dsn): + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + await p.wait() + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +async def test_queue_size(dsn): + async def worker(t, ev=None): + try: + async with p.connection(): + if ev: + ev.set() + await asyncio.sleep(t) + except pool.TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=1, max_waiting=3) as p: + await p.wait() + ev = asyncio.Event() + create_task(worker(0.3, ev)) + await ev.wait() + + ts = [create_task(worker(0.1)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], pool.TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout(dsn): + async def worker(n): + t0 = time() + try: + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_dead_client(dsn): + async def worker(i, timeout): + try: + async with p.connection(timeout=timeout) as conn: + await 
conn.execute("select pg_sleep(0.3)") + results.append(i) + except pool.PoolTimeout: + if timeout > 0.2: + raise + + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + results: List[int] = [] + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) + + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout_override(dsn): + async def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +async def test_broken_reconnect(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + await conn.close() + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + await conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + cur = await conn2.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() + + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +async def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) + + async with p.connection() 
as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +async def test_close_no_tasks(dsn): + p = pool.AsyncConnectionPool(dsn) + assert p._sched_runner and not p._sched_runner.done() + assert p._workers + workers = p._workers[:] + for t in workers: + assert not t.done() + + await p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert t.done() + + +async def test_putconn_no_pool(aconn_cls, dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await aconn_cls.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) + + await conn.close() + + +async def test_putconn_wrong_pool(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p1: + async with pool.AsyncConnectionPool(dsn, min_size=1) as p2: + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) + + +async def test_closed_getconn(dsn): + p = pool.AsyncConnectionPool(dsn, min_size=1) + assert not p.closed + async with p.connection(): + pass + + await p.close() + assert p.closed + + with pytest.raises(pool.PoolClosed): + async with p.connection(): + pass + + +async def test_closed_putconn(dsn): + p = pool.AsyncConnectionPool(dsn, min_size=1) + + async with p.connection() as conn: + pass + assert not conn.closed + + async with p.connection() as conn: + await p.close() + assert conn.closed + + +async def test_closed_queue(dsn): + async def w1(): + async with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + await e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + async def w2(): + try: + async with p.connection(): + pass # unexpected + except pool.PoolClosed: + success.append("w2") + + e1 = asyncio.Event() + e2 = asyncio.Event() + + p = pool.AsyncConnectionPool(dsn, min_size=1) + await p.wait() + success: List[str] = [] + + t1 = create_task(w1()) + # Wait until w1 has received a connection + await e1.wait() + + t2 = create_task(w2()) + # Wait until w2 is in the queue + await ensure_waiting(p) + await p.close() + + # Wait for the workers to finish + e2.set() + await asyncio.gather(t1, t2) + assert len(success) == 2 + + +async def test_open_explicit(dsn): + p = pool.AsyncConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(pool.PoolClosed): + await 
p.getconn() + + with pytest.raises(pool.PoolClosed, match="is not open yet"): + async with p.connection(): + pass + + await p.open() + try: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + with pytest.raises(pool.PoolClosed, match="is already closed"): + await p.getconn() + + +async def test_open_context(dsn): + p = pool.AsyncConnectionPool(dsn, open=False) + assert p.closed + + async with p: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + assert p.closed + + +async def test_open_no_op(dsn): + p = pool.AsyncConnectionPool(dsn) + try: + assert not p.closed + await p.open() + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_open_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + await p.open(wait=True, timeout=0.3) + finally: + await p.close() + + p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + await p.open(wait=True, timeout=0.5) + finally: + await p.close() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_open_as_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.open(wait=True, timeout=0.3) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.open(wait=True, timeout=0.5) + + +async def test_reopen(dsn): + p = pool.AsyncConnectionPool(dsn) + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + assert p._sched_runner is None + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + await p.open() + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.parametrize( + "min_size, want_times", + [ + (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]), + (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]), + ], +) +async def test_grow(dsn, monkeypatch, min_size, want_times): + delay_connection(monkeypatch, 0.1) + + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1 from pg_sleep(0.25)") + t1 = time() + results.append((n, t1 - t0)) + + async with pool.AsyncConnectionPool( + dsn, min_size=min_size, max_size=4, num_workers=3 + ) as p: + await p.wait(1.0) + ts = [] + results: List[Tuple[int, float]] = [] + + ts = [create_task(worker(i)) for i in range(len(want_times))] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +async def test_shrink(dsn, monkeypatch): + + from psycopg_pool.pool_async import ShrinkPool + + results: List[Tuple[int, int]] = [] + + async def run_hacked(self, pool): + n0 = pool._nconns + await orig_run(self, pool) + n1 = pool._nconns + results.append((n0, n1)) + + orig_run = ShrinkPool._run + monkeypatch.setattr(ShrinkPool, "_run", run_hacked) + + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.1)") + + async with 
pool.AsyncConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p: + await p.wait(5.0) + assert p.max_idle == 0.2 + + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + await asyncio.sleep(1) + + assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] + + +@pytest.mark.slow +async def test_reconnect(proxy, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0 + assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1 + monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1) + monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) + + caplog.clear() + proxy.start() + async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=1) as p: + await p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") + + await asyncio.sleep(1.0) + proxy.start() + await p.wait() + + async with p.connection() as conn: + await conn.execute("select 1") + + assert "BAD" in caplog.messages[0] + times = [rec.created for rec in caplog.records] + assert times[1] - times[0] < 0.05 + deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)] + assert len(deltas) == 3 + want = 0.1 + for delta in deltas: + assert delta == pytest.approx(want, 0.05), deltas + want *= 2 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_reconnect_failure(proxy): + proxy.start() + + t1 = None + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + async with pool.AsyncConnectionPool( + proxy.client_dsn, + name="this-one", + min_size=1, + reconnect_timeout=1.0, + reconnect_failed=failed, + ) as p: + await p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") + + t0 = time() + await asyncio.sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 + + proxy.start() + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 + + +@pytest.mark.slow +async def test_reconnect_after_grow_failed(proxy): + # Retry reconnection after a failed connection attempt has put the pool + # in grow mode. See issue #370. 
+    proxy.stop()
+
+    ev = asyncio.Event()
+
+    def failed(pool):
+        ev.set()
+
+    async with pool.AsyncConnectionPool(
+        proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+    ) as p:
+        await asyncio.wait_for(ev.wait(), 2.0)
+
+        with pytest.raises(pool.PoolTimeout):
+            async with p.connection(timeout=0.5) as conn:
+                pass
+
+        ev.clear()
+        await asyncio.wait_for(ev.wait(), 2.0)
+
+        proxy.start()
+
+        async with p.connection(timeout=2) as conn:
+            await conn.execute("select 1")
+
+        await p.wait(timeout=3.0)
+        assert len(p._pool) == p.min_size == 4
+
+
+@pytest.mark.slow
+async def test_refill_on_check(proxy):
+    proxy.start()
+    ev = asyncio.Event()
+
+    def failed(pool):
+        ev.set()
+
+    async with pool.AsyncConnectionPool(
+        proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+    ) as p:
+        # The pool is full
+        await p.wait(timeout=2)
+
+        # Break all the connections
+        proxy.stop()
+
+        # Checking the pool will empty it
+        await p.check()
+        await asyncio.wait_for(ev.wait(), 2.0)
+        assert len(p._pool) == 0
+
+        # Allow connecting again
+        proxy.start()
+
+        # Make sure that check has refilled the pool
+        await p.check()
+        await p.wait(timeout=2)
+        assert len(p._pool) == 4
+
+
+@pytest.mark.slow
+async def test_uniform_use(dsn):
+    async with pool.AsyncConnectionPool(dsn, min_size=4) as p:
+        counts = Counter[int]()
+        for i in range(8):
+            async with p.connection() as conn:
+                await asyncio.sleep(0.1)
+                counts[id(conn)] += 1
+
+    assert len(counts) == 4
+    assert set(counts.values()) == set([2])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_resize(dsn):
+    async def sampler():
+        await asyncio.sleep(0.05)  # ensure sampling happens after shrink check
+        while True:
+            await asyncio.sleep(0.2)
+            if p.closed:
+                break
+            size.append(len(p._pool))
+
+    async def client(t):
+        async with p.connection() as conn:
+            await conn.execute("select pg_sleep(%s)", [t])
+
+    size: List[int] = []
+
+    async with pool.AsyncConnectionPool(dsn, min_size=2, max_idle=0.2) as p:
+        s = create_task(sampler())
+
+        await asyncio.sleep(0.3)
+
+        c = create_task(client(0.4))
+
+        await asyncio.sleep(0.2)
+        await p.resize(4)
+        assert p.min_size == 4
+        assert p.max_size == 4
+
+        await asyncio.sleep(0.4)
+        await p.resize(2)
+        assert p.min_size == 2
+        assert p.max_size == 2
+
+        await asyncio.sleep(0.6)
+
+    await asyncio.gather(s, c)
+    assert size == [2, 1, 3, 4, 3, 2, 2]
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)])
+async def test_bad_resize(dsn, min_size, max_size):
+    async with pool.AsyncConnectionPool() as p:
+        with pytest.raises(ValueError):
+            await p.resize(min_size=min_size, max_size=max_size)
+
+
+async def test_jitter():
+    rnds = [pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)]
+    assert 27 <= min(rnds) <= 28
+    assert 35 < max(rnds) < 36
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_max_lifetime(dsn):
+    async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p:
+        await asyncio.sleep(0.1)
+        pids = []
+        for i in range(5):
+            async with p.connection() as conn:
+                pids.append(conn.info.backend_pid)
+            await asyncio.sleep(0.2)
+
+    assert pids[0] == pids[1] != pids[4], pids
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_check(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    async with pool.AsyncConnectionPool(dsn, min_size=4) as p:
+        await p.wait(1.0)
+        async with p.connection() as conn:
+            pid = conn.info.backend_pid
+
+        await p.wait(1.0)
+        pids = set(conn.info.backend_pid for conn in p._pool)
+        assert pid in pids
+        await conn.close()
+
+        assert len(caplog.records) == 0
+        await p.check()
+        assert len(caplog.records) == 1
+        await p.wait(1.0)
+        pids2 = set(conn.info.backend_pid for conn in p._pool)
+        assert len(pids & pids2) == 3
+        assert pid not in pids2
+
+
+async def test_check_idle(dsn):
+    async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+        await p.wait(1.0)
+        await p.check()
+        async with p.connection() as conn:
+            assert conn.info.transaction_status == TransactionStatus.IDLE
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_measures(dsn):
+    async def worker(n):
+        async with p.connection() as conn:
+            await conn.execute("select pg_sleep(0.2)")
+
+    async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4) as p:
+        await p.wait(2.0)
+
+        stats = p.get_stats()
+        assert stats["pool_min"] == 2
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 2
+        assert stats["pool_available"] == 2
+        assert stats["requests_waiting"] == 0
+
+        ts = [create_task(worker(i)) for i in range(3)]
+        await asyncio.sleep(0.1)
+        stats = p.get_stats()
+        await asyncio.gather(*ts)
+        assert stats["pool_min"] == 2
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 3
+        assert stats["pool_available"] == 0
+        assert stats["requests_waiting"] == 0
+
+        await p.wait(2.0)
+        ts = [create_task(worker(i)) for i in range(7)]
+        await asyncio.sleep(0.1)
+        stats = p.get_stats()
+        await asyncio.gather(*ts)
+        assert stats["pool_min"] == 2
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 4
+        assert stats["pool_available"] == 0
+        assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_usage(dsn):
+    async def worker(n):
+        try:
+            async with p.connection(timeout=0.3) as conn:
+                await conn.execute("select pg_sleep(0.2)")
+        except pool.PoolTimeout:
+            pass
+
+    async with pool.AsyncConnectionPool(dsn, min_size=3) as p:
+        await p.wait(2.0)
+
+        ts = [create_task(worker(i)) for i in range(7)]
+        await asyncio.gather(*ts)
+        stats = p.get_stats()
+        assert stats["requests_num"] == 7
+        assert stats["requests_queued"] == 4
+        assert 850 <= stats["requests_wait_ms"] <= 950
+        assert stats["requests_errors"] == 1
+        assert 1150 <= stats["usage_ms"] <= 1250
+        assert stats.get("returns_bad", 0) == 0
+
+        async with p.connection() as conn:
+            await conn.close()
+        await p.wait()
+        stats = p.pop_stats()
+        assert stats["requests_num"] == 8
+        assert stats["returns_bad"] == 1
+        async with p.connection():
+            pass
+        assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+async def test_stats_connect(dsn, proxy, monkeypatch):
+    proxy.start()
+    delay_connection(monkeypatch, 0.2)
+    async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=3) as p:
+        await p.wait()
+        stats = p.get_stats()
+        assert stats["connections_num"] == 3
+        assert stats.get("connections_errors", 0) == 0
+        assert stats.get("connections_lost", 0) == 0
+        assert 580 <= stats["connections_ms"] < 1200
+
+        proxy.stop()
+        await p.check()
+        await asyncio.sleep(0.1)
+        stats = p.get_stats()
+        assert stats["connections_num"] > 3
+        assert stats["connections_errors"] > 0
+        assert stats["connections_lost"] == 3
+
+
+@pytest.mark.slow
+async def test_spike(dsn, monkeypatch):
+    # Inspired by https://github.com/brettwooldridge/HikariCP/blob/dev/
+    # documents/Welcome-To-The-Jungle.md
+    delay_connection(monkeypatch, 0.15)
+
+    async def worker():
+        async with p.connection():
+            await asyncio.sleep(0.002)
+
+    async with pool.AsyncConnectionPool(dsn, min_size=5, max_size=10) as p:
+        await p.wait()
+
+        ts = [create_task(worker()) for i in range(50)]
+        await asyncio.gather(*ts)
+        await p.wait()
+
+        assert len(p._pool) < 7
+
+
+async def test_debug_deadlock(dsn):
+    # https://github.com/psycopg/psycopg/issues/230
+    logger = logging.getLogger("psycopg")
+    handler = logging.StreamHandler()
+    old_level = logger.level
+    logger.setLevel(logging.DEBUG)
+    handler.setLevel(logging.DEBUG)
+    logger.addHandler(handler)
+    try:
+        async with pool.AsyncConnectionPool(dsn, min_size=4, open=True) as p:
+            await p.wait(timeout=2)
+    finally:
+        logger.removeHandler(handler)
+        logger.setLevel(old_level)
+
+
+def delay_connection(monkeypatch, sec):
+    """
+    Make AsyncConnection.connect take at least *sec* seconds to complete.
+    """
+
+    async def connect_delay(*args, **kwargs):
+        t0 = time()
+        rv = await connect_orig(*args, **kwargs)
+        t1 = time()
+        await asyncio.sleep(max(0, sec - (t1 - t0)))
+        return rv
+
+    connect_orig = psycopg.AsyncConnection.connect
+    monkeypatch.setattr(psycopg.AsyncConnection, "connect", connect_delay)
+
+
+async def ensure_waiting(p, num=1):
+    while len(p._waiting) < num:
+        await asyncio.sleep(0)
diff --git a/tests/pool/test_pool_async_noasyncio.py b/tests/pool/test_pool_async_noasyncio.py
new file mode 100644
index 0000000..f6e34e4
--- /dev/null
+++ b/tests/pool/test_pool_async_noasyncio.py
@@ -0,0 +1,78 @@
+# These tests relate to AsyncConnectionPool, but are not marked asyncio
+# because they rely on the pool initialization outside the asyncio loop.
+
+import asyncio
+
+import pytest
+
+from ..utils import gc_collect
+
+try:
+    import psycopg_pool as pool
+except ImportError:
+    # Tests should have been skipped if the package is not available
+    pass
+
+
+@pytest.mark.slow
+def test_reconnect_after_max_lifetime(dsn, asyncio_run):
+    # See issue #219, pool created before the loop.
+    p = pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2, open=False)
+
+    async def test():
+        try:
+            await p.open()
+            ns = []
+            for i in range(5):
+                async with p.connection() as conn:
+                    cur = await conn.execute("select 1")
+                    ns.append(await cur.fetchone())
+                await asyncio.sleep(0.2)
+            assert len(ns) == 5
+        finally:
+            await p.close()
+
+    asyncio_run(asyncio.wait_for(test(), timeout=2.0))
+
+
+@pytest.mark.slow
+def test_working_created_before_loop(dsn, asyncio_run):
+    p = pool.AsyncNullConnectionPool(dsn, open=False)
+
+    async def test():
+        try:
+            await p.open()
+            ns = []
+            for i in range(5):
+                async with p.connection() as conn:
+                    cur = await conn.execute("select 1")
+                    ns.append(await cur.fetchone())
+                await asyncio.sleep(0.2)
+            assert len(ns) == 5
+        finally:
+            await p.close()
+
+    asyncio_run(asyncio.wait_for(test(), timeout=2.0))
+
+
+def test_cant_create_open_outside_loop(dsn):
+    with pytest.raises(RuntimeError):
+        pool.AsyncConnectionPool(dsn, open=True)
+
+
+@pytest.fixture
+def asyncio_run(recwarn):
+    """Fixture returning asyncio.run, but managing resources at exit.
+
+    In certain runs, fd objects are leaked and the error will only be caught
+    downstream, by some innocent test calling gc_collect().
+ """ + recwarn.clear() + try: + yield asyncio.run + finally: + gc_collect() + if recwarn: + warn = recwarn.pop(ResourceWarning) + assert "unclosed event loop" in str(warn.message) + assert not recwarn diff --git a/tests/pool/test_sched.py b/tests/pool/test_sched.py new file mode 100644 index 0000000..b3d2572 --- /dev/null +++ b/tests/pool/test_sched.py @@ -0,0 +1,154 @@ +import logging +from time import time, sleep +from functools import partial +from threading import Thread + +import pytest + +try: + from psycopg_pool.sched import Scheduler +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.timing] + + +@pytest.mark.slow +def test_sched(): + s = Scheduler() + results = [] + + def worker(i): + results.append((i, time())) + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, partial(worker, 3)) + s.enter(0.3, None) + s.enter(0.2, partial(worker, 2)) + s.run() + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.1) + + +@pytest.mark.slow +def test_sched_thread(): + s = Scheduler() + t = Thread(target=s.run, daemon=True) + t.start() + + results = [] + + def worker(i): + results.append((i, time())) + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, partial(worker, 3)) + s.enter(0.3, None) + s.enter(0.2, partial(worker, 2)) + + t.join() + t1 = time() + assert t1 - t0 == pytest.approx(0.3, 0.2) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.2) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.2) + + +@pytest.mark.slow +def test_sched_error(caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + s = Scheduler() + t = Thread(target=s.run, daemon=True) + t.start() + + results = [] + + def worker(i): + results.append((i, time())) + + def error(): + 1 / 0 + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, None) + s.enter(0.3, partial(worker, 2)) + s.enter(0.2, error) + + t.join() + t1 = time() + assert t1 - t0 == pytest.approx(0.4, 0.1) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.3, 0.1) + + assert len(caplog.records) == 1 + assert "ZeroDivisionError" in caplog.records[0].message + + +@pytest.mark.slow +def test_empty_queue_timeout(): + s = Scheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + def wait_logging(timeout=None): + rv = wait_orig(timeout) + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.2 + + t = Thread(target=s.run) + t.start() + sleep(0.5) + s.enter(0.5, None) + t.join() + times.append(time() - t0) + for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]): + assert got == pytest.approx(want, 0.2), times + + +@pytest.mark.slow +def test_first_task_rescheduling(): + s = Scheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + def wait_logging(timeout=None): + rv = wait_orig(timeout) + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.1 + + s.enter(0.4, lambda: None) + t = Thread(target=s.run) + t.start() + s.enter(0.6, None) # this task doesn't trigger a reschedule + sleep(0.1) + s.enter(0.1, lambda: None) # this triggers a reschedule + t.join() + 
times.append(time() - t0) + for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]): + assert got == pytest.approx(want, 0.2), times diff --git a/tests/pool/test_sched_async.py b/tests/pool/test_sched_async.py new file mode 100644 index 0000000..492d620 --- /dev/null +++ b/tests/pool/test_sched_async.py @@ -0,0 +1,159 @@ +import asyncio +import logging +from time import time +from functools import partial + +import pytest + +from psycopg._compat import create_task + +try: + from psycopg_pool.sched import AsyncScheduler +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.asyncio, pytest.mark.timing] + + +@pytest.mark.slow +async def test_sched(): + s = AsyncScheduler() + results = [] + + async def worker(i): + results.append((i, time())) + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, partial(worker, 3)) + await s.enter(0.3, None) + await s.enter(0.2, partial(worker, 2)) + await s.run() + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.1) + + +@pytest.mark.slow +async def test_sched_task(): + s = AsyncScheduler() + t = create_task(s.run()) + + results = [] + + async def worker(i): + results.append((i, time())) + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, partial(worker, 3)) + await s.enter(0.3, None) + await s.enter(0.2, partial(worker, 2)) + + await asyncio.gather(t) + t1 = time() + assert t1 - t0 == pytest.approx(0.3, 0.2) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.2) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.2) + + +@pytest.mark.slow +async def test_sched_error(caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + s = AsyncScheduler() + t = create_task(s.run()) + + results = [] + + async def worker(i): + results.append((i, time())) + + async def error(): + 1 / 0 + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, None) + await s.enter(0.3, partial(worker, 2)) + await s.enter(0.2, error) + + await asyncio.gather(t) + t1 = time() + assert t1 - t0 == pytest.approx(0.4, 0.1) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.3, 0.1) + + assert len(caplog.records) == 1 + assert "ZeroDivisionError" in caplog.records[0].message + + +@pytest.mark.slow +async def test_empty_queue_timeout(): + s = AsyncScheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + async def wait_logging(): + try: + rv = await wait_orig() + finally: + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.2 + + t = create_task(s.run()) + await asyncio.sleep(0.5) + await s.enter(0.5, None) + await asyncio.gather(t) + times.append(time() - t0) + for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]): + assert got == pytest.approx(want, 0.2), times + + +@pytest.mark.slow +async def test_first_task_rescheduling(): + s = AsyncScheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + async def wait_logging(): + try: + rv = await wait_orig() + finally: + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.1 + + async def noop(): + pass + + await 
s.enter(0.4, noop)
+    t = create_task(s.run())
+    await s.enter(0.6, None)  # this task doesn't trigger a reschedule
+    await asyncio.sleep(0.1)
+    await s.enter(0.1, noop)  # this triggers a reschedule
+    await asyncio.gather(t)
+    times.append(time() - t0)
+    for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+        assert got == pytest.approx(want, 0.2), times
diff --git a/tests/pq/__init__.py b/tests/pq/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/pq/__init__.py
diff --git a/tests/pq/test_async.py b/tests/pq/test_async.py
new file mode 100644
index 0000000..2c3de98
--- /dev/null
+++ b/tests/pq/test_async.py
@@ -0,0 +1,210 @@
+from select import select
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg.generators import execute
+
+
+def execute_wait(pgconn):
+    return psycopg.waiting.wait(execute(pgconn), pgconn.socket)
+
+
+def test_send_query(pgconn):
+    # This test shows how to process an async query in all its glory
+    pgconn.nonblocking = 1
+
+    # Long query to make sure we have to wait on send
+    pgconn.send_query(
+        b"/* %s */ select 'x' as f from pg_sleep(0.01); select 1 as foo;"
+        % (b"x" * 1_000_000)
+    )
+
+    # send loop
+    waited_on_send = 0
+    while True:
+        f = pgconn.flush()
+        if f == 0:
+            break
+
+        waited_on_send += 1
+
+        rl, wl, xl = select([pgconn.socket], [pgconn.socket], [])
+        assert not (rl and wl)
+        if wl:
+            continue  # call flush() again
+        if rl:
+            pgconn.consume_input()
+            continue
+
+    # TODO: this check is not reliable, it fails on travis sometimes
+    # assert waited_on_send
+
+    # read loop
+    results = []
+    while True:
+        pgconn.consume_input()
+        if pgconn.is_busy():
+            select([pgconn.socket], [], [])
+            continue
+        res = pgconn.get_result()
+        if res is None:
+            break
+        assert res.status == pq.ExecStatus.TUPLES_OK
+        results.append(res)
+
+    assert len(results) == 2
+    assert results[0].nfields == 1
+    assert results[0].fname(0) == b"f"
+    assert results[0].get_value(0, 0) == b"x"
+    assert results[1].nfields == 1
+    assert results[1].fname(0) == b"foo"
+    assert results[1].get_value(0, 0) == b"1"
+
+
+def test_send_query_compact_test(pgconn):
+    # Like the above test but use psycopg facilities for compactness
+    pgconn.send_query(
+        b"/* %s */ select 'x' as f from pg_sleep(0.01); select 1 as foo;"
+        % (b"x" * 1_000_000)
+    )
+    results = execute_wait(pgconn)
+
+    assert len(results) == 2
+    assert results[0].nfields == 1
+    assert results[0].fname(0) == b"f"
+    assert results[0].get_value(0, 0) == b"x"
+    assert results[1].nfields == 1
+    assert results[1].fname(0) == b"foo"
+    assert results[1].get_value(0, 0) == b"1"
+
+    pgconn.finish()
+    with pytest.raises(psycopg.OperationalError):
+        pgconn.send_query(b"select 1")
+
+
+def test_single_row_mode(pgconn):
+    pgconn.send_query(b"select generate_series(1,2)")
+    pgconn.set_single_row_mode()
+
+    results = execute_wait(pgconn)
+    assert len(results) == 3
+
+    res = results[0]
+    assert res.status == pq.ExecStatus.SINGLE_TUPLE
+    assert res.ntuples == 1
+    assert res.get_value(0, 0) == b"1"
+
+    res = results[1]
+    assert res.status == pq.ExecStatus.SINGLE_TUPLE
+    assert res.ntuples == 1
+    assert res.get_value(0, 0) == b"2"
+
+    res = results[2]
+    assert res.status == pq.ExecStatus.TUPLES_OK
+    assert res.ntuples == 0
+
+
+def test_send_query_params(pgconn):
+    pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
+    (res,) = execute_wait(pgconn)
+    assert res.status == pq.ExecStatus.TUPLES_OK
+    assert res.get_value(0, 0) == b"8"
+
+    pgconn.finish()
+    with pytest.raises(psycopg.OperationalError):
pgconn.send_query_params(b"select $1", [b"1"]) + + +def test_send_prepare(pgconn): + pgconn.send_prepare(b"prep", b"select $1::int + $2::int") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_query_prepared(b"prep", [b"3", b"5"]) + (res,) = execute_wait(pgconn) + assert res.get_value(0, 0) == b"8" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.send_prepare(b"prep", b"select $1::int + $2::int") + with pytest.raises(psycopg.OperationalError): + pgconn.send_query_prepared(b"prep", [b"3", b"5"]) + + +def test_send_prepare_types(pgconn): + pgconn.send_prepare(b"prep", b"select $1 + $2", [23, 23]) + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_query_prepared(b"prep", [b"3", b"5"]) + (res,) = execute_wait(pgconn) + assert res.get_value(0, 0) == b"8" + + +def test_send_prepared_binary_in(pgconn): + val = b"foo\00bar" + pgconn.send_prepare(b"", b"select length($1::bytea), length($2::bytea)") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_query_prepared(b"", [val, val], param_formats=[0, 1]) + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"3" + assert res.get_value(0, 1) == b"7" + + with pytest.raises(ValueError): + pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1]) + + +@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]) +def test_send_prepared_binary_out(pgconn, fmt, out): + val = b"foo\00bar" + pgconn.send_prepare(b"", b"select $1::bytea") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_query_prepared(b"", [val], param_formats=[1], result_format=fmt) + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == out + + +def test_send_describe_prepared(pgconn): + pgconn.send_prepare(b"prep", b"select $1::int8 + $2::int8 as fld") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_describe_prepared(b"prep") + (res,) = execute_wait(pgconn) + assert res.nfields == 1 + assert res.ntuples == 0 + assert res.fname(0) == b"fld" + assert res.ftype(0) == 20 + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.send_describe_prepared(b"prep") + + +@pytest.mark.crdb_skip("server-side cursor") +def test_send_describe_portal(pgconn): + res = pgconn.exec_( + b""" + begin; + declare cur cursor for select * from generate_series(1,10) foo; + """ + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_describe_portal(b"cur") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + assert res.nfields == 1 + assert res.fname(0) == b"foo" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.send_describe_portal(b"cur") diff --git a/tests/pq/test_conninfo.py b/tests/pq/test_conninfo.py new file mode 100644 index 0000000..64d8b8f --- /dev/null +++ b/tests/pq/test_conninfo.py @@ -0,0 +1,48 @@ +import pytest + +import psycopg +from psycopg import pq + + +def test_defaults(monkeypatch): + monkeypatch.setenv("PGPORT", "15432") + defs = pq.Conninfo.get_defaults() + assert len(defs) > 20 + port = [d for d in defs if d.keyword == b"port"][0] + assert port.envvar == b"PGPORT" + assert 
port.compiled == b"5432" + assert port.val == b"15432" + assert port.label == b"Database-Port" + assert port.dispchar == b"" + assert port.dispsize == 6 + + +@pytest.mark.libpq(">= 10") +def test_conninfo_parse(): + infos = pq.Conninfo.parse( + b"postgresql://host1:123,host2:456/somedb" + b"?target_session_attrs=any&application_name=myapp" + ) + info = {i.keyword: i.val for i in infos if i.val is not None} + assert info[b"host"] == b"host1,host2" + assert info[b"port"] == b"123,456" + assert info[b"dbname"] == b"somedb" + assert info[b"application_name"] == b"myapp" + + +@pytest.mark.libpq("< 10") +def test_conninfo_parse_96(): + conninfo = pq.Conninfo.parse( + b"postgresql://other@localhost/otherdb" + b"?connect_timeout=10&application_name=myapp" + ) + info = {i.keyword: i.val for i in conninfo if i.val is not None} + assert info[b"host"] == b"localhost" + assert info[b"dbname"] == b"otherdb" + assert info[b"application_name"] == b"myapp" + + +def test_conninfo_parse_bad(): + with pytest.raises(psycopg.OperationalError) as e: + pq.Conninfo.parse(b"bad_conninfo=") + assert "bad_conninfo" in str(e.value) diff --git a/tests/pq/test_copy.py b/tests/pq/test_copy.py new file mode 100644 index 0000000..383d272 --- /dev/null +++ b/tests/pq/test_copy.py @@ -0,0 +1,174 @@ +import pytest + +import psycopg +from psycopg import pq + +pytestmark = pytest.mark.crdb_skip("copy") + +sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')" + +sample_tabledef = "col1 int primary key, col2 int, data text" + +sample_text = b"""\ +10\t20\thello +40\t\\N\tworld +""" + +sample_binary_value = """ +5047 434f 5059 0aff 0d0a 00 +00 0000 0000 0000 00 +00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f + +0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64 + +ff ff +""" + +sample_binary_rows = [ + bytes.fromhex("".join(row.split())) for row in sample_binary_value.split("\n\n") +] + +sample_binary = b"".join(sample_binary_rows) + + +def test_put_data_no_copy(pgconn): + with pytest.raises(psycopg.OperationalError): + pgconn.put_copy_data(b"wat") + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.put_copy_data(b"wat") + + +def test_put_end_no_copy(pgconn): + with pytest.raises(psycopg.OperationalError): + pgconn.put_copy_end() + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.put_copy_end() + + +def test_copy_out(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)} +""" + ) + rv = pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end() + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_( + b"select min(col1), max(col1), count(*), max(length(data)) from copy_in" + ) + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + assert res.get_value(0, 1) == b"199" + assert res.get_value(0, 2) == b"200" + assert res.get_value(0, 3) == b"199" + + +def test_copy_out_err(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\thardly a number\tnope +""" + ) + rv = 
pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end() + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.FATAL_ERROR + assert b"hardly a number" in res.error_message + + res = pgconn.exec_(b"select count(*) from copy_in") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + + +def test_copy_out_error_end(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)} +""" + ) + rv = pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end(b"nuttengoggenio") + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.FATAL_ERROR + assert b"nuttengoggenio" in res.error_message + + res = pgconn.exec_(b"select count(*) from copy_in") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + + +def test_get_data_no_copy(pgconn): + with pytest.raises(psycopg.OperationalError): + pgconn.get_copy_data(0) + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.get_copy_data(0) + + +@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY]) +def test_copy_out_read(pgconn, format): + stmt = f"copy ({sample_values}) to stdout (format {format.name})" + res = pgconn.exec_(stmt.encode("ascii")) + assert res.status == pq.ExecStatus.COPY_OUT + assert res.binary_tuples == format + + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + for row in want: + nbytes, data = pgconn.get_copy_data(0) + assert nbytes == len(data) + assert data == row + + assert pgconn.get_copy_data(0) == (-1, b"") + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + +def ensure_table(pgconn, tabledef, name="copy_in"): + pgconn.exec_(f"drop table if exists {name}".encode("ascii")) + pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii")) diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py new file mode 100644 index 0000000..ad88d8a --- /dev/null +++ b/tests/pq/test_escaping.py @@ -0,0 +1,188 @@ +import pytest + +import psycopg +from psycopg import pq + +from ..fix_crdb import crdb_scs_off + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b"''"), + (b"hello", b"'hello'"), + (b"foo'bar", b"'foo''bar'"), + (b"foo\\bar", b" E'foo\\\\bar'"), + ], +) +def test_escape_literal(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_literal(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_literal_1char(pgconn, scs): + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + esc = pq.Escaping(pgconn) + special = {b"'": b"''''", b"\\": b" E'\\\\'"} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_literal(data) + exp = special.get(data) or b"'%s'" % data + assert rv == exp + + +def test_escape_literal_noconn(pgconn): + esc = pq.Escaping() + with pytest.raises(psycopg.OperationalError): + esc.escape_literal(b"hi") + + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_literal(b"hi") + + +@pytest.mark.parametrize( + 
"data, want", + [ + (b"", b'""'), + (b"hello", b'"hello"'), + (b'foo"bar', b'"foo""bar"'), + (b"foo\\bar", b'"foo\\bar"'), + ], +) +def test_escape_identifier(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_identifier(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_identifier_1char(pgconn, scs): + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + esc = pq.Escaping(pgconn) + special = {b'"': b'""""', b"\\": b'"\\"'} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_identifier(data) + exp = special.get(data) or b'"%s"' % data + assert rv == exp + + +def test_escape_identifier_noconn(pgconn): + esc = pq.Escaping() + with pytest.raises(psycopg.OperationalError): + esc.escape_identifier(b"hi") + + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_identifier(b"hi") + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b""), + (b"hello", b"hello"), + (b"foo'bar", b"foo''bar"), + (b"foo\\bar", b"foo\\bar"), + ], +) +def test_escape_string(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_string(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_string_1char(pgconn, scs): + esc = pq.Escaping(pgconn) + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + special = {b"'": b"''", b"\\": b"\\" if scs == "on" else b"\\\\"} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_string(data) + exp = special.get(data) or b"%s" % data + assert rv == exp + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b""), + (b"hello", b"hello"), + (b"foo'bar", b"foo''bar"), + # This libpq function behaves unpredictably when not passed a conn + (b"foo\\bar", (b"foo\\\\bar", b"foo\\bar")), + ], +) +def test_escape_string_noconn(data, want): + esc = pq.Escaping() + out = esc.escape_string(data) + if isinstance(want, bytes): + assert out == want + else: + assert out in want + + +def test_escape_string_badconn(pgconn): + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_string(b"hi") + + +def test_escape_string_badenc(pgconn): + res = pgconn.exec_(b"set client_encoding to 'UTF8'") + assert res.status == pq.ExecStatus.COMMAND_OK + data = "\u20ac".encode()[:-1] + esc = pq.Escaping(pgconn) + with pytest.raises(psycopg.OperationalError): + esc.escape_string(data) + + +@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"]) +def test_escape_bytea(pgconn, data): + exp = rb"\x" + b"".join(b"%02x" % c for c in data) + esc = pq.Escaping(pgconn) + rv = esc.escape_bytea(data) + assert rv == exp + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_bytea(data) + + +def test_escape_noconn(pgconn): + data = bytes(range(256)) + esc = pq.Escaping() + escdata = esc.escape_bytea(data) + res = pgconn.exec_params(b"select '%s'::bytea" % escdata, [], result_format=1) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == data + + +def test_escape_1char(pgconn): + esc = pq.Escaping(pgconn) + for c in range(256): + rv = esc.escape_bytea(bytes([c])) + exp = rb"\x%02x" % c + assert rv == exp + + +@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"]) +def test_unescape_bytea(pgconn, data): + enc = rb"\x" + 
b"".join(b"%02x" % c for c in data) + esc = pq.Escaping(pgconn) + rv = esc.unescape_bytea(enc) + assert rv == data + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.unescape_bytea(data) diff --git a/tests/pq/test_exec.py b/tests/pq/test_exec.py new file mode 100644 index 0000000..86c30c0 --- /dev/null +++ b/tests/pq/test_exec.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 + +import pytest + +import psycopg +from psycopg import pq + + +def test_exec_none(pgconn): + with pytest.raises(TypeError): + pgconn.exec_(None) + + +def test_exec(pgconn): + res = pgconn.exec_(b"select 'hel' || 'lo'") + assert res.get_value(0, 0) == b"hello" + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.exec_(b"select 'hello'") + + +def test_exec_params(pgconn): + res = pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"]) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"8" + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"]) + + +def test_exec_params_empty(pgconn): + res = pgconn.exec_params(b"select 8::int", []) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"8" + + +def test_exec_params_types(pgconn): + res = pgconn.exec_params(b"select $1, $2", [b"8", b"8"], [1700, 23]) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"8" + assert res.ftype(0) == 1700 + assert res.get_value(0, 1) == b"8" + assert res.ftype(1) == 23 + + with pytest.raises(ValueError): + pgconn.exec_params(b"select $1, $2", [b"8", b"8"], [1700]) + + +def test_exec_params_nulls(pgconn): + res = pgconn.exec_params(b"select $1::text, $2::text, $3::text", [b"hi", b"", None]) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"hi" + assert res.get_value(0, 1) == b"" + assert res.get_value(0, 2) is None + + +def test_exec_params_binary_in(pgconn): + val = b"foo\00bar" + res = pgconn.exec_params( + b"select length($1::bytea), length($2::bytea)", + [val, val], + param_formats=[0, 1], + ) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"3" + assert res.get_value(0, 1) == b"7" + + with pytest.raises(ValueError): + pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1]) + + +@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]) +def test_exec_params_binary_out(pgconn, fmt, out): + val = b"foo\00bar" + res = pgconn.exec_params( + b"select $1::bytea", [val], param_formats=[1], result_format=fmt + ) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == out + + +def test_prepare(pgconn): + res = pgconn.prepare(b"prep", b"select $1::int + $2::int") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_prepared(b"prep", [b"3", b"5"]) + assert res.get_value(0, 0) == b"8" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.prepare(b"prep", b"select $1::int + $2::int") + with pytest.raises(psycopg.OperationalError): + pgconn.exec_prepared(b"prep", [b"3", b"5"]) + + +def test_prepare_types(pgconn): + res = pgconn.prepare(b"prep", b"select $1 + $2", [23, 23]) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_prepared(b"prep", [b"3", b"5"]) + assert res.get_value(0, 0) == b"8" + + +def test_exec_prepared_binary_in(pgconn): + val = b"foo\00bar" + res = pgconn.prepare(b"", b"select length($1::bytea), length($2::bytea)") + assert 
res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_prepared(b"", [val, val], param_formats=[0, 1]) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"3" + assert res.get_value(0, 1) == b"7" + + with pytest.raises(ValueError): + pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1]) + + +@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]) +def test_exec_prepared_binary_out(pgconn, fmt, out): + val = b"foo\00bar" + res = pgconn.prepare(b"", b"select $1::bytea") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_prepared(b"", [val], param_formats=[1], result_format=fmt) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == out + + +@pytest.mark.crdb_skip("server-side cursor") +def test_describe_portal(pgconn): + res = pgconn.exec_( + b""" + begin; + declare cur cursor for select * from generate_series(1,10) foo; + """ + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.describe_portal(b"cur") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + assert res.nfields == 1 + assert res.fname(0) == b"foo" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.describe_portal(b"cur") diff --git a/tests/pq/test_misc.py b/tests/pq/test_misc.py new file mode 100644 index 0000000..599758f --- /dev/null +++ b/tests/pq/test_misc.py @@ -0,0 +1,83 @@ +import pytest + +import psycopg +from psycopg import pq + + +def test_error_message(pgconn): + res = pgconn.exec_(b"wat") + assert res.status == pq.ExecStatus.FATAL_ERROR + msg = pq.error_message(pgconn) + assert "wat" in msg + assert msg == pq.error_message(res) + primary = res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) + assert primary.decode("ascii") in msg + + with pytest.raises(TypeError): + pq.error_message(None) # type: ignore[arg-type] + + res.clear() + assert pq.error_message(res) == "no details available" + pgconn.finish() + assert "NULL" in pq.error_message(pgconn) + + +@pytest.mark.crdb_skip("encoding") +def test_error_message_encoding(pgconn): + res = pgconn.exec_(b"set client_encoding to latin9") + assert res.status == pq.ExecStatus.COMMAND_OK + + res = pgconn.exec_('select 1 from "foo\u20acbar"'.encode("latin9")) + assert res.status == pq.ExecStatus.FATAL_ERROR + + msg = pq.error_message(pgconn) + assert "foo\u20acbar" in msg + + msg = pq.error_message(res) + assert "foo\ufffdbar" in msg + + msg = pq.error_message(res, encoding="latin9") + assert "foo\u20acbar" in msg + + msg = pq.error_message(res, encoding="ascii") + assert "foo\ufffdbar" in msg + + +def test_make_empty_result(pgconn): + pgconn.exec_(b"wat") + res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR) + assert res.status == pq.ExecStatus.FATAL_ERROR + assert b"wat" in res.error_message + + pgconn.finish() + res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR) + assert res.status == pq.ExecStatus.FATAL_ERROR + assert res.error_message == b"" + + +def test_result_set_attrs(pgconn): + res = pgconn.make_empty_result(pq.ExecStatus.COPY_OUT) + assert res.status == pq.ExecStatus.COPY_OUT + + attrs = [ + pq.PGresAttDesc(b"an_int", 0, 0, 0, 23, 0, 0), + pq.PGresAttDesc(b"a_num", 0, 0, 0, 1700, 0, 0), + pq.PGresAttDesc(b"a_bin_text", 0, 0, 1, 25, 0, 0), + ] + res.set_attributes(attrs) + assert res.nfields == 3 + + assert res.fname(0) == b"an_int" + assert res.fname(1) == b"a_num" + assert res.fname(2) == b"a_bin_text" + + assert 
res.fformat(0) == 0 + assert res.fformat(1) == 0 + assert res.fformat(2) == 1 + + assert res.ftype(0) == 23 + assert res.ftype(1) == 1700 + assert res.ftype(2) == 25 + + with pytest.raises(psycopg.OperationalError): + res.set_attributes(attrs) diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py new file mode 100644 index 0000000..0566151 --- /dev/null +++ b/tests/pq/test_pgconn.py @@ -0,0 +1,585 @@ +import os +import sys +import ctypes +import logging +import weakref +from select import select + +import pytest + +import psycopg +from psycopg import pq +import psycopg.generators + +from ..utils import gc_collect + + +def test_connectdb(dsn): + conn = pq.PGconn.connect(dsn.encode()) + assert conn.status == pq.ConnStatus.OK, conn.error_message + + +def test_connectdb_error(): + conn = pq.PGconn.connect(b"dbname=psycopg_test_not_for_real") + assert conn.status == pq.ConnStatus.BAD + + +@pytest.mark.parametrize("baddsn", [None, 42]) +def test_connectdb_badtype(baddsn): + with pytest.raises(TypeError): + pq.PGconn.connect(baddsn) + + +def test_connect_async(dsn): + conn = pq.PGconn.connect_start(dsn.encode()) + conn.nonblocking = 1 + while True: + assert conn.status != pq.ConnStatus.BAD + rv = conn.connect_poll() + if rv == pq.PollingStatus.OK: + break + elif rv == pq.PollingStatus.READING: + select([conn.socket], [], []) + elif rv == pq.PollingStatus.WRITING: + select([], [conn.socket], []) + else: + assert False, rv + + assert conn.status == pq.ConnStatus.OK + + conn.finish() + with pytest.raises(psycopg.OperationalError): + conn.connect_poll() + + +@pytest.mark.crdb("skip", reason="connects to any db name") +def test_connect_async_bad(dsn): + parsed_dsn = {e.keyword: e.val for e in pq.Conninfo.parse(dsn.encode()) if e.val} + parsed_dsn[b"dbname"] = b"psycopg_test_not_for_real" + dsn = b" ".join(b"%s='%s'" % item for item in parsed_dsn.items()) + conn = pq.PGconn.connect_start(dsn) + while True: + assert conn.status != pq.ConnStatus.BAD, conn.error_message + rv = conn.connect_poll() + if rv == pq.PollingStatus.FAILED: + break + elif rv == pq.PollingStatus.READING: + select([conn.socket], [], []) + elif rv == pq.PollingStatus.WRITING: + select([], [conn.socket], []) + else: + assert False, rv + + assert conn.status == pq.ConnStatus.BAD + + +def test_finish(pgconn): + assert pgconn.status == pq.ConnStatus.OK + pgconn.finish() + assert pgconn.status == pq.ConnStatus.BAD + pgconn.finish() + assert pgconn.status == pq.ConnStatus.BAD + + +@pytest.mark.slow +def test_weakref(dsn): + conn = pq.PGconn.connect(dsn.encode()) + w = weakref.ref(conn) + conn.finish() + del conn + gc_collect() + assert w() is None + + +@pytest.mark.skipif( + sys.platform == "win32" + and os.environ.get("CI") == "true" + and pq.__impl__ != "python", + reason="can't figure out how to make ctypes run, don't care", +) +def test_pgconn_ptr(pgconn, libpq): + assert isinstance(pgconn.pgconn_ptr, int) + + f = libpq.PQserverVersion + f.argtypes = [ctypes.c_void_p] + f.restype = ctypes.c_int + ver = f(pgconn.pgconn_ptr) + assert ver == pgconn.server_version + + pgconn.finish() + assert pgconn.pgconn_ptr is None + + +def test_info(dsn, pgconn): + info = pgconn.info + assert len(info) > 20 + dbname = [d for d in info if d.keyword == b"dbname"][0] + assert dbname.envvar == b"PGDATABASE" + assert dbname.label == b"Database-Name" + assert dbname.dispchar == b"" + assert dbname.dispsize == 20 + + parsed = pq.Conninfo.parse(dsn.encode()) + # take the name and the user either from params or from env vars + name = [ + o.val or 
os.environ.get(o.envvar.decode(), "").encode() + for o in parsed + if o.keyword == b"dbname" and o.envvar + ][0] + user = [ + o.val or os.environ.get(o.envvar.decode(), "").encode() + for o in parsed + if o.keyword == b"user" and o.envvar + ][0] + assert dbname.val == (name or user) + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.info + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_reset(pgconn): + assert pgconn.status == pq.ConnStatus.OK + pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())") + assert pgconn.status == pq.ConnStatus.BAD + pgconn.reset() + assert pgconn.status == pq.ConnStatus.OK + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.reset() + + assert pgconn.status == pq.ConnStatus.BAD + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_reset_async(pgconn): + assert pgconn.status == pq.ConnStatus.OK + pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())") + assert pgconn.status == pq.ConnStatus.BAD + pgconn.reset_start() + while True: + rv = pgconn.reset_poll() + if rv == pq.PollingStatus.READING: + select([pgconn.socket], [], []) + elif rv == pq.PollingStatus.WRITING: + select([], [pgconn.socket], []) + else: + break + + assert rv == pq.PollingStatus.OK + assert pgconn.status == pq.ConnStatus.OK + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.reset_start() + + with pytest.raises(psycopg.OperationalError): + pgconn.reset_poll() + + +def test_ping(dsn): + rv = pq.PGconn.ping(dsn.encode()) + assert rv == pq.Ping.OK + + rv = pq.PGconn.ping(b"port=9999") + assert rv == pq.Ping.NO_RESPONSE + + +def test_db(pgconn): + name = [o.val for o in pgconn.info if o.keyword == b"dbname"][0] + assert pgconn.db == name + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.db + + +def test_user(pgconn): + user = [o.val for o in pgconn.info if o.keyword == b"user"][0] + assert pgconn.user == user + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.user + + +def test_password(pgconn): + # not in info + assert isinstance(pgconn.password, bytes) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.password + + +def test_host(pgconn): + # might be not in info + assert isinstance(pgconn.host, bytes) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.host + + +@pytest.mark.libpq(">= 12") +def test_hostaddr(pgconn): + # not in info + assert isinstance(pgconn.hostaddr, bytes), pgconn.hostaddr + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.hostaddr + + +@pytest.mark.libpq("< 12") +def test_hostaddr_missing(pgconn): + with pytest.raises(psycopg.NotSupportedError): + pgconn.hostaddr + + +def test_port(pgconn): + port = [o.val for o in pgconn.info if o.keyword == b"port"][0] + assert pgconn.port == port + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.port + + +@pytest.mark.libpq("< 14") +def test_tty(pgconn): + tty = [o.val for o in pgconn.info if o.keyword == b"tty"][0] + assert pgconn.tty == tty + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.tty + + +@pytest.mark.libpq(">= 14") +def test_tty_noop(pgconn): + assert not any(o.val for o in pgconn.info if o.keyword == b"tty") + assert pgconn.tty == b"" + + +def test_transaction_status(pgconn): + assert pgconn.transaction_status == pq.TransactionStatus.IDLE + pgconn.exec_(b"begin") + assert pgconn.transaction_status == pq.TransactionStatus.INTRANS + 
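+ # While a query has been sent and its results not yet consumed, the
+ # connection reports ACTIVE; once the execute generator has drained the
+ # results it drops back to INTRANS, and a closed connection reports UNKNOWN.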
pgconn.send_query(b"select 1") + assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE + psycopg.waiting.wait(psycopg.generators.execute(pgconn), pgconn.socket) + assert pgconn.transaction_status == pq.TransactionStatus.INTRANS + pgconn.finish() + assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN + + +def test_parameter_status(dsn, monkeypatch): + monkeypatch.setenv("PGAPPNAME", "psycopg tests") + pgconn = pq.PGconn.connect(dsn.encode()) + assert pgconn.parameter_status(b"application_name") == b"psycopg tests" + assert pgconn.parameter_status(b"wat") is None + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.parameter_status(b"application_name") + + +@pytest.mark.crdb_skip("encoding") +def test_encoding(pgconn): + res = pgconn.exec_(b"set client_encoding to latin1") + assert res.status == pq.ExecStatus.COMMAND_OK + assert pgconn.parameter_status(b"client_encoding") == b"LATIN1" + + res = pgconn.exec_(b"set client_encoding to 'utf-8'") + assert res.status == pq.ExecStatus.COMMAND_OK + assert pgconn.parameter_status(b"client_encoding") == b"UTF8" + + res = pgconn.exec_(b"set client_encoding to wat") + assert res.status == pq.ExecStatus.FATAL_ERROR + assert pgconn.parameter_status(b"client_encoding") == b"UTF8" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.parameter_status(b"client_encoding") + + +def test_protocol_version(pgconn): + assert pgconn.protocol_version == 3 + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.protocol_version + + +def test_server_version(pgconn): + assert pgconn.server_version >= 90400 + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.server_version + + +def test_socket(pgconn): + socket = pgconn.socket + assert socket > 0 + pgconn.exec_(f"select pg_terminate_backend({pgconn.backend_pid})".encode()) + # TODO: on my box it raises OperationalError as it should. Not on Travis, + # so let's see if at least an ok value comes out of it. + try: + assert pgconn.socket == socket + except psycopg.OperationalError: + pass + + +def test_error_message(pgconn): + assert pgconn.error_message == b"" + res = pgconn.exec_(b"wat") + assert res.status == pq.ExecStatus.FATAL_ERROR + msg = pgconn.error_message + assert b"wat" in msg + pgconn.finish() + assert b"NULL" in pgconn.error_message # TODO: i10n? + + +def test_backend_pid(pgconn): + assert isinstance(pgconn.backend_pid, int) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.backend_pid + + +def test_needs_password(pgconn): + # assume connection worked so an eventually needed password wasn't missing + assert pgconn.needs_password is False + pgconn.finish() + pgconn.needs_password + + +def test_used_password(pgconn, dsn, monkeypatch): + assert isinstance(pgconn.used_password, bool) + + # Assume that if a password was passed then it was needed. + # Note that the server may still need a password passed via pgpass + # so it may be that has_password is false but still a password was + # requested by the server and passed by libpq. 
+ info = pq.Conninfo.parse(dsn.encode()) + has_password = ( + "PGPASSWORD" in os.environ + or [i for i in info if i.keyword == b"password"][0].val is not None + ) + if has_password: + assert pgconn.used_password + + pgconn.finish() + pgconn.used_password + + +def test_ssl_in_use(pgconn): + assert isinstance(pgconn.ssl_in_use, bool) + + # If connecting via socket then ssl is not in use + if pgconn.host.startswith(b"/"): + assert not pgconn.ssl_in_use + else: + sslmode = [i.val for i in pgconn.info if i.keyword == b"sslmode"][0] + if sslmode not in (b"disable", b"allow", b"prefer"): + assert pgconn.ssl_in_use + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.ssl_in_use + + +def test_set_single_row_mode(pgconn): + with pytest.raises(psycopg.OperationalError): + pgconn.set_single_row_mode() + + pgconn.send_query(b"select 1") + pgconn.set_single_row_mode() + + +def test_cancel(pgconn): + cancel = pgconn.get_cancel() + cancel.cancel() + cancel.cancel() + pgconn.finish() + cancel.cancel() + with pytest.raises(psycopg.OperationalError): + pgconn.get_cancel() + + +def test_cancel_free(pgconn): + cancel = pgconn.get_cancel() + cancel.free() + with pytest.raises(psycopg.OperationalError): + cancel.cancel() + cancel.free() + + +@pytest.mark.crdb_skip("notify") +def test_notify(pgconn): + assert pgconn.notifies() is None + + pgconn.exec_(b"listen foo") + pgconn.exec_(b"listen bar") + pgconn.exec_(b"notify foo, '1'") + pgconn.exec_(b"notify bar, '2'") + pgconn.exec_(b"notify foo, '3'") + + n = pgconn.notifies() + assert n.relname == b"foo" + assert n.be_pid == pgconn.backend_pid + assert n.extra == b"1" + + n = pgconn.notifies() + assert n.relname == b"bar" + assert n.be_pid == pgconn.backend_pid + assert n.extra == b"2" + + n = pgconn.notifies() + assert n.relname == b"foo" + assert n.be_pid == pgconn.backend_pid + assert n.extra == b"3" + + assert pgconn.notifies() is None + + +@pytest.mark.crdb_skip("do") +def test_notice_nohandler(pgconn): + pgconn.exec_(b"set client_min_messages to notice") + res = pgconn.exec_( + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" + ) + assert res.status == pq.ExecStatus.COMMAND_OK + + +@pytest.mark.crdb_skip("do") +def test_notice(pgconn): + msgs = [] + + def callback(res): + assert res.status == pq.ExecStatus.NONFATAL_ERROR + msgs.append(res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)) + + pgconn.exec_(b"set client_min_messages to notice") + pgconn.notice_handler = callback + res = pgconn.exec_( + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" + ) + + assert res.status == pq.ExecStatus.COMMAND_OK + assert msgs and msgs[0] == b"hello notice" + + +@pytest.mark.crdb_skip("do") +def test_notice_error(pgconn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + def callback(res): + raise Exception("hello error") + + pgconn.exec_(b"set client_min_messages to notice") + pgconn.notice_handler = callback + res = pgconn.exec_( + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" + ) + + assert res.status == pq.ExecStatus.COMMAND_OK + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.ERROR + assert "hello error" in rec.message + + +@pytest.mark.libpq("< 14") +@pytest.mark.skipif("sys.platform != 'linux'") +def test_trace_pre14(pgconn, tmp_path): + tracef = tmp_path / "trace" + with tracef.open("w") as f: + pgconn.trace(f.fileno()) + with pytest.raises(psycopg.NotSupportedError): + pgconn.set_trace_flags(0) + pgconn.exec_(b"select 1") + 
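+ # Tracing stops here: the "select 2" issued after untrace() must not appear
+ # in the trace file read back below.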
pgconn.untrace() + pgconn.exec_(b"select 2") + traces = tracef.read_text() + assert "select 1" in traces + assert "select 2" not in traces + + +@pytest.mark.libpq(">= 14") +@pytest.mark.skipif("sys.platform != 'linux'") +def test_trace(pgconn, tmp_path): + tracef = tmp_path / "trace" + with tracef.open("w") as f: + pgconn.trace(f.fileno()) + pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + pgconn.exec_(b"select 1::int4 as foo") + pgconn.untrace() + pgconn.exec_(b"select 2::int4 as foo") + traces = [line.split("\t") for line in tracef.read_text().splitlines()] + assert traces == [ + ["F", "26", "Query", ' "select 1::int4 as foo"'], + ["B", "28", "RowDescription", ' 1 "foo" NNNN 0 NNNN 4 -1 0'], + ["B", "11", "DataRow", " 1 1 '1'"], + ["B", "13", "CommandComplete", ' "SELECT 1"'], + ["B", "5", "ReadyForQuery", " I"], + ] + + +@pytest.mark.skipif("sys.platform == 'linux'") +def test_trace_nonlinux(pgconn): + with pytest.raises(psycopg.NotSupportedError): + pgconn.trace(1) + + +@pytest.mark.libpq(">= 10") +def test_encrypt_password(pgconn): + enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5") + assert enc == b"md594839d658c28a357126f105b9cb14cfc" + + +@pytest.mark.libpq(">= 10") +def test_encrypt_password_scram(pgconn): + enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"scram-sha-256") + assert enc.startswith(b"SCRAM-SHA-256$") + + +@pytest.mark.libpq(">= 10") +def test_encrypt_password_badalgo(pgconn): + with pytest.raises(psycopg.OperationalError): + assert pgconn.encrypt_password(b"psycopg2", b"ashesh", b"wat") + + +@pytest.mark.libpq(">= 10") +@pytest.mark.crdb_skip("password_encryption") +def test_encrypt_password_query(pgconn): + res = pgconn.exec_(b"set password_encryption to 'md5'") + assert res.status == pq.ExecStatus.COMMAND_OK, pgconn.error_message.decode() + enc = pgconn.encrypt_password(b"psycopg2", b"ashesh") + assert enc == b"md594839d658c28a357126f105b9cb14cfc" + + res = pgconn.exec_(b"set password_encryption to 'scram-sha-256'") + assert res.status == pq.ExecStatus.COMMAND_OK + enc = pgconn.encrypt_password(b"psycopg2", b"ashesh") + assert enc.startswith(b"SCRAM-SHA-256$") + + +@pytest.mark.libpq(">= 10") +def test_encrypt_password_closed(pgconn): + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + assert pgconn.encrypt_password(b"psycopg2", b"ashesh") + + +@pytest.mark.libpq("< 10") +def test_encrypt_password_not_supported(pgconn): + # it might even be supported, but not worth the lifetime + with pytest.raises(psycopg.NotSupportedError): + pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5") + + with pytest.raises(psycopg.NotSupportedError): + pgconn.encrypt_password(b"psycopg2", b"ashesh", b"scram-sha-256") + + +def test_str(pgconn, dsn): + assert "[IDLE]" in str(pgconn) + pgconn.finish() + assert "[BAD]" in str(pgconn) + + pgconn2 = pq.PGconn.connect_start(dsn.encode()) + assert "[" in str(pgconn2) + assert "[IDLE]" not in str(pgconn2) diff --git a/tests/pq/test_pgresult.py b/tests/pq/test_pgresult.py new file mode 100644 index 0000000..3ad818d --- /dev/null +++ b/tests/pq/test_pgresult.py @@ -0,0 +1,207 @@ +import ctypes +import pytest + +from psycopg import pq + + +@pytest.mark.parametrize( + "command, status", + [ + (b"", "EMPTY_QUERY"), + (b"select 1", "TUPLES_OK"), + (b"set timezone to utc", "COMMAND_OK"), + (b"wat", "FATAL_ERROR"), + ], +) +def test_status(pgconn, command, status): + res = pgconn.exec_(command) + assert res.status == getattr(pq.ExecStatus, status) + assert status in repr(res) + + +def 
test_clear(pgconn): + res = pgconn.exec_(b"select 1") + assert res.status == pq.ExecStatus.TUPLES_OK + res.clear() + assert res.status == pq.ExecStatus.FATAL_ERROR + res.clear() + assert res.status == pq.ExecStatus.FATAL_ERROR + + +def test_pgresult_ptr(pgconn, libpq): + res = pgconn.exec_(b"select 1") + assert isinstance(res.pgresult_ptr, int) + + f = libpq.PQcmdStatus + f.argtypes = [ctypes.c_void_p] + f.restype = ctypes.c_char_p + assert f(res.pgresult_ptr) == b"SELECT 1" + + res.clear() + assert res.pgresult_ptr is None + + +def test_error_message(pgconn): + res = pgconn.exec_(b"select 1") + assert res.error_message == b"" + res = pgconn.exec_(b"select wat") + assert b"wat" in res.error_message + res.clear() + assert res.error_message == b"" + + +def test_error_field(pgconn): + res = pgconn.exec_(b"select wat") + # https://github.com/cockroachdb/cockroach/issues/81794 + assert ( + res.error_field(pq.DiagnosticField.SEVERITY_NONLOCALIZED) + or res.error_field(pq.DiagnosticField.SEVERITY) + ) == b"ERROR" + assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"42703" + assert b"wat" in res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) + res.clear() + assert res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) is None + + +@pytest.mark.parametrize("n", range(4)) +def test_ntuples(pgconn, n): + res = pgconn.exec_params(b"select generate_series(1, $1)", [str(n).encode("ascii")]) + assert res.ntuples == n + res.clear() + assert res.ntuples == 0 + + +def test_nfields(pgconn): + res = pgconn.exec_(b"select wat") + assert res.nfields == 0 + res = pgconn.exec_(b"select 1, 2, 3") + assert res.nfields == 3 + res.clear() + assert res.nfields == 0 + + +def test_fname(pgconn): + res = pgconn.exec_(b'select 1 as foo, 2 as "BAR"') + assert res.fname(0) == b"foo" + assert res.fname(1) == b"BAR" + assert res.fname(2) is None + assert res.fname(-1) is None + res.clear() + assert res.fname(0) is None + + +@pytest.mark.crdb("skip", reason="ftable") +def test_ftable_and_col(pgconn): + res = pgconn.exec_( + b""" + drop table if exists t1, t2; + create table t1 as select 1 as f1; + create table t2 as select 2 as f2, 3 as f3; + """ + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_( + b"select f1, f3, 't1'::regclass::oid, 't2'::regclass::oid from t1, t2" + ) + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + + assert res.ftable(0) == int(res.get_value(0, 2).decode("ascii")) + assert res.ftable(1) == int(res.get_value(0, 3).decode("ascii")) + assert res.ftablecol(0) == 1 + assert res.ftablecol(1) == 2 + res.clear() + assert res.ftable(0) == 0 + assert res.ftablecol(0) == 0 + + +@pytest.mark.parametrize("fmt", (0, 1)) +def test_fformat(pgconn, fmt): + res = pgconn.exec_params(b"select 1", [], result_format=fmt) + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.fformat(0) == fmt + assert res.binary_tuples == fmt + res.clear() + assert res.fformat(0) == 0 + assert res.binary_tuples == 0 + + +def test_ftype(pgconn): + res = pgconn.exec_(b"select 1::int4, 1::numeric, 1::text") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.ftype(0) == 23 + assert res.ftype(1) == 1700 + assert res.ftype(2) == 25 + res.clear() + assert res.ftype(0) == 0 + + +def test_fmod(pgconn): + res = pgconn.exec_(b"select 1::int, 1::numeric(10), 1::numeric(10,2)") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.fmod(0) == -1 + assert res.fmod(1) == 0xA0004 + assert res.fmod(2) == 0xA0006 + 
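+ # fmod is the attribute typmod: for numeric(p, s) PostgreSQL encodes it as
+ # ((p << 16) | s) + VARHDRSZ (4), hence (10 << 16) + 4 = 0xA0004 for
+ # numeric(10) and ((10 << 16) | 2) + 4 = 0xA0006 for numeric(10, 2).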
res.clear() + assert res.fmod(0) == 0 + + +def test_fsize(pgconn): + res = pgconn.exec_(b"select 1::int4, 1::bigint, 1::text") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.fsize(0) == 4 + assert res.fsize(1) == 8 + assert res.fsize(2) == -1 + res.clear() + assert res.fsize(0) == 0 + + +def test_get_value(pgconn): + res = pgconn.exec_(b"select 'a', '', NULL") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"a" + assert res.get_value(0, 1) == b"" + assert res.get_value(0, 2) is None + res.clear() + assert res.get_value(0, 0) is None + + +def test_nparams_types(pgconn): + res = pgconn.prepare(b"", b"select $1::int4, $2::text") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.describe_prepared(b"") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + assert res.nparams == 2 + assert res.param_type(0) == 23 + assert res.param_type(1) == 25 + + res.clear() + assert res.nparams == 0 + assert res.param_type(0) == 0 + + +def test_command_status(pgconn): + res = pgconn.exec_(b"select 1") + assert res.command_status == b"SELECT 1" + res = pgconn.exec_(b"set timezone to utc") + assert res.command_status == b"SET" + res.clear() + assert res.command_status is None + + +def test_command_tuples(pgconn): + res = pgconn.exec_(b"set timezone to utf8") + assert res.command_tuples is None + res = pgconn.exec_(b"select * from generate_series(1, 10)") + assert res.command_tuples == 10 + res.clear() + assert res.command_tuples is None + + +def test_oid_value(pgconn): + res = pgconn.exec_(b"select 1") + assert res.oid_value == 0 + res.clear() + assert res.oid_value == 0 diff --git a/tests/pq/test_pipeline.py b/tests/pq/test_pipeline.py new file mode 100644 index 0000000..00cd54a --- /dev/null +++ b/tests/pq/test_pipeline.py @@ -0,0 +1,161 @@ +import pytest + +import psycopg +from psycopg import pq + + +@pytest.mark.libpq("< 14") +def test_old_libpq(pgconn): + assert pgconn.pipeline_status == 0 + with pytest.raises(psycopg.NotSupportedError): + pgconn.enter_pipeline_mode() + with pytest.raises(psycopg.NotSupportedError): + pgconn.exit_pipeline_mode() + with pytest.raises(psycopg.NotSupportedError): + pgconn.pipeline_sync() + with pytest.raises(psycopg.NotSupportedError): + pgconn.send_flush_request() + + +@pytest.mark.libpq(">= 14") +def test_work_in_progress(pgconn): + assert not pgconn.nonblocking + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + pgconn.enter_pipeline_mode() + pgconn.send_query_params(b"select $1", [b"1"]) + with pytest.raises(psycopg.OperationalError, match="cannot exit pipeline mode"): + pgconn.exit_pipeline_mode() + + +@pytest.mark.libpq(">= 14") +def test_multi_pipelines(pgconn): + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + pgconn.enter_pipeline_mode() + pgconn.send_query_params(b"select $1", [b"1"], param_types=[25]) + pgconn.pipeline_sync() + pgconn.send_query_params(b"select $1", [b"2"], param_types=[25]) + pgconn.pipeline_sync() + + # result from first query + result1 = pgconn.get_result() + assert result1 is not None + assert result1.status == pq.ExecStatus.TUPLES_OK + + # NULL signals end of result + assert pgconn.get_result() is None + + # first sync result + sync_result = pgconn.get_result() + assert sync_result is not None + assert sync_result.status == pq.ExecStatus.PIPELINE_SYNC + + # result from second query + result2 = pgconn.get_result() + assert result2 is not None + assert result2.status == pq.ExecStatus.TUPLES_OK + + 
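+ # (results keep arriving in the order the queries were queued, and each
+ # pipeline_sync() produces its own PIPELINE_SYNC result)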
# NULL signals end of result + assert pgconn.get_result() is None + + # second sync result + sync_result = pgconn.get_result() + assert sync_result is not None + assert sync_result.status == pq.ExecStatus.PIPELINE_SYNC + + # pipeline still ON + assert pgconn.pipeline_status == pq.PipelineStatus.ON + + pgconn.exit_pipeline_mode() + + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + + assert result1.get_value(0, 0) == b"1" + assert result2.get_value(0, 0) == b"2" + + +@pytest.mark.libpq(">= 14") +def test_flush_request(pgconn): + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + pgconn.enter_pipeline_mode() + pgconn.send_query_params(b"select $1", [b"1"], param_types=[25]) + pgconn.send_flush_request() + r = pgconn.get_result() + assert r.status == pq.ExecStatus.TUPLES_OK + assert r.get_value(0, 0) == b"1" + pgconn.exit_pipeline_mode() + + +@pytest.fixture +def table(pgconn): + tablename = "pipeline" + pgconn.exec_(f"create table {tablename} (s text)".encode("ascii")) + yield tablename + pgconn.exec_(f"drop table if exists {tablename}".encode("ascii")) + + +@pytest.mark.libpq(">= 14") +def test_pipeline_abort(pgconn, table): + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + pgconn.enter_pipeline_mode() + pgconn.send_query_params(b"insert into pipeline values ($1)", [b"1"]) + pgconn.send_query_params(b"select no_such_function($1)", [b"1"]) + pgconn.send_query_params(b"insert into pipeline values ($1)", [b"2"]) + pgconn.pipeline_sync() + pgconn.send_query_params(b"insert into pipeline values ($1)", [b"3"]) + pgconn.pipeline_sync() + + # result from first INSERT + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.COMMAND_OK + + # NULL signals end of result + assert pgconn.get_result() is None + + # error result from second query (SELECT) + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.FATAL_ERROR + + # NULL signals end of result + assert pgconn.get_result() is None + + # pipeline should be aborted, due to previous error + assert pgconn.pipeline_status == pq.PipelineStatus.ABORTED + + # result from second INSERT, aborted due to previous error + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.PIPELINE_ABORTED + + # NULL signals end of result + assert pgconn.get_result() is None + + # pipeline is still aborted + assert pgconn.pipeline_status == pq.PipelineStatus.ABORTED + + # sync result + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.PIPELINE_SYNC + + # aborted flag is clear, pipeline is on again + assert pgconn.pipeline_status == pq.PipelineStatus.ON + + # result from the third INSERT + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.COMMAND_OK + + # NULL signals end of result + assert pgconn.get_result() is None + + # second sync result + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.PIPELINE_SYNC + + # NULL signals end of result + assert pgconn.get_result() is None + + pgconn.exit_pipeline_mode() diff --git a/tests/pq/test_pq.py b/tests/pq/test_pq.py new file mode 100644 index 0000000..076c3b6 --- /dev/null +++ b/tests/pq/test_pq.py @@ -0,0 +1,57 @@ +import os + +import pytest + +import psycopg +from psycopg import pq + +from ..utils import check_libpq_version + + +def test_version(): + rv = pq.version() + assert rv > 90500 + assert rv < 200000 # you are good for a while + + +def test_build_version(): + assert pq.__build_version__ and pq.__build_version__ >= 70400 + 
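+
+# For reference, both pq.version() (PQlibVersion) and pq.__build_version__ use
+# libpq's integer version format: from PostgreSQL 10 onwards it is
+# major * 10000 + minor (e.g. 140005 for 14.5), while older releases used
+# major * 10000 + minor * 100 + revision (e.g. 90605 for 9.6.5). The helper
+# below is only illustrative and is not used by these tests.
+def libpq_version_string(v: int) -> str:
+ # 10+ releases render as "major.minor", older ones as "major.minor.revision"
+ return f"{v // 10000}.{v % 10000}" if v >= 100000 else f"{v // 10000}.{v % 10000 // 100}.{v % 100}"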
+ +@pytest.mark.skipif("not os.environ.get('PSYCOPG_TEST_WANT_LIBPQ_BUILD')") +def test_want_built_version(): + want = os.environ["PSYCOPG_TEST_WANT_LIBPQ_BUILD"] + got = pq.__build_version__ + assert not check_libpq_version(got, want) + + +@pytest.mark.skipif("not os.environ.get('PSYCOPG_TEST_WANT_LIBPQ_IMPORT')") +def test_want_import_version(): + want = os.environ["PSYCOPG_TEST_WANT_LIBPQ_IMPORT"] + got = pq.version() + assert not check_libpq_version(got, want) + + +# Note: These tests are here because test_pipeline.py tests are all skipped +# when pipeline mode is not supported. + + +@pytest.mark.libpq(">= 14") +def test_pipeline_supported(conn): + assert psycopg.Pipeline.is_supported() + assert psycopg.AsyncPipeline.is_supported() + + with conn.pipeline(): + pass + + +@pytest.mark.libpq("< 14") +def test_pipeline_not_supported(conn): + assert not psycopg.Pipeline.is_supported() + assert not psycopg.AsyncPipeline.is_supported() + + with pytest.raises(psycopg.NotSupportedError) as exc: + with conn.pipeline(): + pass + + assert "too old" in str(exc.value) diff --git a/tests/scripts/bench-411.py b/tests/scripts/bench-411.py new file mode 100644 index 0000000..82ea451 --- /dev/null +++ b/tests/scripts/bench-411.py @@ -0,0 +1,300 @@ +import os +import sys +import time +import random +import asyncio +import logging +from enum import Enum +from typing import Any, Dict, List, Generator +from argparse import ArgumentParser, Namespace +from contextlib import contextmanager + +logger = logging.getLogger() +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", +) + + +class Driver(str, Enum): + psycopg2 = "psycopg2" + psycopg = "psycopg" + psycopg_async = "psycopg_async" + asyncpg = "asyncpg" + + +ids: List[int] = [] +data: List[Dict[str, Any]] = [] + + +def main() -> None: + + args = parse_cmdline() + + ids[:] = range(args.ntests) + data[:] = [ + dict( + id=i, + name="c%d" % i, + description="c%d" % i, + q=i * 10, + p=i * 20, + x=i * 30, + y=i * 40, + ) + for i in ids + ] + + # Must be done just on end + drop_at_the_end = args.drop + args.drop = False + + for i, name in enumerate(args.drivers): + if i == len(args.drivers) - 1: + args.drop = drop_at_the_end + + if name == Driver.psycopg2: + import psycopg2 # type: ignore + + run_psycopg2(psycopg2, args) + + elif name == Driver.psycopg: + import psycopg + + run_psycopg(psycopg, args) + + elif name == Driver.psycopg_async: + import psycopg + + if sys.platform == "win32": + if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + asyncio.set_event_loop_policy( + asyncio.WindowsSelectorEventLoopPolicy() + ) + + asyncio.run(run_psycopg_async(psycopg, args)) + + elif name == Driver.asyncpg: + import asyncpg # type: ignore + + asyncio.run(run_asyncpg(asyncpg, args)) + + else: + raise AssertionError(f"unknown driver: {name!r}") + + # Must be done just on start + args.create = False + + +table = """ +CREATE TABLE customer ( + id SERIAL NOT NULL, + name VARCHAR(255), + description VARCHAR(255), + q INTEGER, + p INTEGER, + x INTEGER, + y INTEGER, + z INTEGER, + PRIMARY KEY (id) +) +""" +drop = "DROP TABLE IF EXISTS customer" + +insert = """ +INSERT INTO customer (id, name, description, q, p, x, y) VALUES +(%(id)s, %(name)s, %(description)s, %(q)s, %(p)s, %(x)s, %(y)s) +""" + +select = """ +SELECT customer.id, customer.name, customer.description, customer.q, + customer.p, customer.x, customer.y, customer.z +FROM customer +WHERE customer.id = %(id)s +""" + + +@contextmanager +def time_log(message: str) -> Generator[None, 
None, None]: + start = time.monotonic() + yield + end = time.monotonic() + logger.info(f"Run {message} in {end-start} s") + + +def run_psycopg2(psycopg2: Any, args: Namespace) -> None: + logger.info("Running psycopg2") + + if args.create: + logger.info(f"inserting {args.ntests} test records") + with psycopg2.connect(args.dsn) as conn: + with conn.cursor() as cursor: + cursor.execute(drop) + cursor.execute(table) + cursor.executemany(insert, data) + conn.commit() + + logger.info(f"running {args.ntests} queries") + to_query = random.choices(ids, k=args.ntests) + with psycopg2.connect(args.dsn) as conn: + with time_log("psycopg2"): + for id_ in to_query: + with conn.cursor() as cursor: + cursor.execute(select, {"id": id_}) + cursor.fetchall() + # conn.rollback() + + if args.drop: + logger.info("dropping test records") + with psycopg2.connect(args.dsn) as conn: + with conn.cursor() as cursor: + cursor.execute(drop) + conn.commit() + + +def run_psycopg(psycopg: Any, args: Namespace) -> None: + logger.info("Running psycopg sync") + + if args.create: + logger.info(f"inserting {args.ntests} test records") + with psycopg.connect(args.dsn) as conn: + with conn.cursor() as cursor: + cursor.execute(drop) + cursor.execute(table) + cursor.executemany(insert, data) + conn.commit() + + logger.info(f"running {args.ntests} queries") + to_query = random.choices(ids, k=args.ntests) + with psycopg.connect(args.dsn) as conn: + with time_log("psycopg"): + for id_ in to_query: + with conn.cursor() as cursor: + cursor.execute(select, {"id": id_}) + cursor.fetchall() + # conn.rollback() + + if args.drop: + logger.info("dropping test records") + with psycopg.connect(args.dsn) as conn: + with conn.cursor() as cursor: + cursor.execute(drop) + conn.commit() + + +async def run_psycopg_async(psycopg: Any, args: Namespace) -> None: + logger.info("Running psycopg async") + + conn: Any + + if args.create: + logger.info(f"inserting {args.ntests} test records") + async with await psycopg.AsyncConnection.connect(args.dsn) as conn: + async with conn.cursor() as cursor: + await cursor.execute(drop) + await cursor.execute(table) + await cursor.executemany(insert, data) + await conn.commit() + + logger.info(f"running {args.ntests} queries") + to_query = random.choices(ids, k=args.ntests) + async with await psycopg.AsyncConnection.connect(args.dsn) as conn: + with time_log("psycopg_async"): + for id_ in to_query: + cursor = await conn.execute(select, {"id": id_}) + await cursor.fetchall() + await cursor.close() + # await conn.rollback() + + if args.drop: + logger.info("dropping test records") + async with await psycopg.AsyncConnection.connect(args.dsn) as conn: + async with conn.cursor() as cursor: + await cursor.execute(drop) + await conn.commit() + + +async def run_asyncpg(asyncpg: Any, args: Namespace) -> None: + logger.info("Running asyncpg") + + places = dict(id="$1", name="$2", description="$3", q="$4", p="$5", x="$6", y="$7") + a_insert = insert % places + a_select = select % {"id": "$1"} + + conn: Any + + if args.create: + logger.info(f"inserting {args.ntests} test records") + conn = await asyncpg.connect(args.dsn) + async with conn.transaction(): + await conn.execute(drop) + await conn.execute(table) + await conn.executemany(a_insert, [tuple(d.values()) for d in data]) + await conn.close() + + logger.info(f"running {args.ntests} queries") + to_query = random.choices(ids, k=args.ntests) + conn = await asyncpg.connect(args.dsn) + with time_log("asyncpg"): + for id_ in to_query: + tr = conn.transaction() + await tr.start() + 
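+ # An explicit transaction is opened around each query, presumably to mirror
+ # the implicit transaction psycopg and psycopg2 run their queries in, so
+ # that the timings stay comparable across drivers.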
await conn.fetch(a_select, id_) + # await tr.rollback() + await conn.close() + + if args.drop: + logger.info("dropping test records") + conn = await asyncpg.connect(args.dsn) + async with conn.transaction(): + await conn.execute(drop) + await conn.close() + + +def parse_cmdline() -> Namespace: + parser = ArgumentParser(description=__doc__) + parser.add_argument( + "drivers", + nargs="+", + metavar="DRIVER", + type=Driver, + help=f"the drivers to test [choices: {', '.join(d.value for d in Driver)}]", + ) + + parser.add_argument( + "--ntests", + type=int, + default=10_000, + help="number of tests to perform [default: %(default)s]", + ) + + parser.add_argument( + "--dsn", + default=os.environ.get("PSYCOPG_TEST_DSN", ""), + help="database connection string" + " [default: %(default)r (from PSYCOPG_TEST_DSN env var)]", + ) + + parser.add_argument( + "--no-create", + dest="create", + action="store_false", + default="True", + help="skip data creation before tests (it must exist already)", + ) + + parser.add_argument( + "--no-drop", + dest="drop", + action="store_false", + default="True", + help="skip data drop after tests", + ) + + opt = parser.parse_args() + + return opt + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/dectest.py b/tests/scripts/dectest.py new file mode 100644 index 0000000..a49f116 --- /dev/null +++ b/tests/scripts/dectest.py @@ -0,0 +1,51 @@ +""" +A quick and rough performance comparison of text vs. binary Decimal adaptation +""" +from random import randrange +from decimal import Decimal +import psycopg +from psycopg import sql + +ncols = 10 +nrows = 500000 +format = psycopg.pq.Format.BINARY +test = "copy" + + +def main() -> None: + cnn = psycopg.connect() + + cnn.execute( + sql.SQL("create table testdec ({})").format( + sql.SQL(", ").join( + [ + sql.SQL("{} numeric(10,2)").format(sql.Identifier(f"t{i}")) + for i in range(ncols) + ] + ) + ) + ) + cur = cnn.cursor() + + if test == "copy": + with cur.copy(f"copy testdec from stdin (format {format.name})") as copy: + for j in range(nrows): + copy.write_row( + [Decimal(randrange(10000000000)) / 100 for i in range(ncols)] + ) + + elif test == "insert": + ph = ["%t", "%b"][format] + cur.executemany( + "insert into testdec values (%s)" % ", ".join([ph] * ncols), + ( + [Decimal(randrange(10000000000)) / 100 for i in range(ncols)] + for j in range(nrows) + ), + ) + else: + raise Exception(f"bad test: {test}") + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/pipeline-demo.py b/tests/scripts/pipeline-demo.py new file mode 100644 index 0000000..ec95229 --- /dev/null +++ b/tests/scripts/pipeline-demo.py @@ -0,0 +1,340 @@ +"""Pipeline mode demo + +This reproduces libpq_pipeline::pipelined_insert PostgreSQL test at +src/test/modules/libpq_pipeline/libpq_pipeline.c::test_pipelined_insert(). + +We do not fetch results explicitly (using cursor.fetch*()), this is +handled by execute() calls when pgconn socket is read-ready, which +happens when the output buffer is full. 
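+
+Use -n to set the number of rows to insert, --pq to drive the low-level
+psycopg.pq API directly, --async for the asyncio variant, and --many to
+insert through executemany() (high-level API only).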
+""" +import argparse +import asyncio +import logging +from contextlib import contextmanager +from functools import partial +from typing import Any, Iterator, Optional, Sequence, Tuple + +from psycopg import AsyncConnection, Connection +from psycopg import pq, waiting +from psycopg import errors as e +from psycopg.abc import PipelineCommand +from psycopg.generators import pipeline_communicate +from psycopg.pq import Format, DiagnosticField +from psycopg._compat import Deque + +psycopg_logger = logging.getLogger("psycopg") +pipeline_logger = logging.getLogger("pipeline") +args: argparse.Namespace + + +class LoggingPGconn: + """Wrapper for PGconn that logs fetched results.""" + + def __init__(self, pgconn: pq.abc.PGconn, logger: logging.Logger): + self._pgconn = pgconn + self._logger = logger + + def log_notice(result: pq.abc.PGresult) -> None: + def get_field(field: DiagnosticField) -> Optional[str]: + value = result.error_field(field) + return value.decode("utf-8", "replace") if value else None + + logger.info( + "notice %s %s", + get_field(DiagnosticField.SEVERITY), + get_field(DiagnosticField.MESSAGE_PRIMARY), + ) + + pgconn.notice_handler = log_notice + + if args.trace: + self._trace_file = open(args.trace, "w") + pgconn.trace(self._trace_file.fileno()) + + def __del__(self) -> None: + if hasattr(self, "_trace_file"): + self._pgconn.untrace() + self._trace_file.close() + + def __getattr__(self, name: str) -> Any: + return getattr(self._pgconn, name) + + def send_query(self, command: bytes) -> None: + self._logger.warning("PQsendQuery broken in libpq 14.5") + self._pgconn.send_query(command) + self._logger.info("sent %s", command.decode()) + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + self._pgconn.send_query_params( + command, param_values, param_types, param_formats, result_format + ) + self._logger.info("sent %s", command.decode()) + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional[bytes]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + self._pgconn.send_query_prepared( + name, param_values, param_formats, result_format + ) + self._logger.info("sent prepared '%s' with %s", name.decode(), param_values) + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + self._pgconn.send_prepare(name, command, param_types) + self._logger.info("prepare %s as '%s'", command.decode(), name.decode()) + + def get_result(self) -> Optional[pq.abc.PGresult]: + r = self._pgconn.get_result() + if r is not None: + self._logger.info("got %s result", pq.ExecStatus(r.status).name) + return r + + +@contextmanager +def prepare_pipeline_demo_pq( + pgconn: LoggingPGconn, rows_to_send: int, logger: logging.Logger +) -> Iterator[Tuple[Deque[PipelineCommand], Deque[str]]]: + """Set up pipeline demo with initial queries and yield commands and + results queue for pipeline_communicate(). 
+ """ + logger.debug("enter pipeline") + pgconn.enter_pipeline_mode() + + setup_queries = [ + ("begin", "BEGIN TRANSACTION"), + ("drop table", "DROP TABLE IF EXISTS pq_pipeline_demo"), + ( + "create table", + ( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ), + ), + ( + "prepare", + ("INSERT INTO pq_pipeline_demo(itemno, int8filler)" " VALUES ($1, $2)"), + ), + ] + + commands = Deque[PipelineCommand]() + results_queue = Deque[str]() + + for qname, query in setup_queries: + if qname == "prepare": + pgconn.send_prepare(qname.encode(), query.encode()) + else: + pgconn.send_query_params(query.encode(), None) + results_queue.append(qname) + + committed = False + synced = False + + while True: + if rows_to_send: + params = [f"{rows_to_send}".encode(), f"{1 << 62}".encode()] + commands.append(partial(pgconn.send_query_prepared, b"prepare", params)) + results_queue.append(f"row {rows_to_send}") + rows_to_send -= 1 + + elif not committed: + committed = True + commands.append(partial(pgconn.send_query_params, b"COMMIT", None)) + results_queue.append("commit") + + elif not synced: + + def sync() -> None: + pgconn.pipeline_sync() + logger.info("pipeline sync sent") + + synced = True + commands.append(sync) + results_queue.append("sync") + + else: + break + + try: + yield commands, results_queue + finally: + logger.debug("exit pipeline") + pgconn.exit_pipeline_mode() + + +def pipeline_demo_pq(rows_to_send: int, logger: logging.Logger) -> None: + pgconn = LoggingPGconn(Connection.connect().pgconn, logger) + with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as ( + commands, + results_queue, + ): + while results_queue: + fetched = waiting.wait( + pipeline_communicate( + pgconn, # type: ignore[arg-type] + commands, + ), + pgconn.socket, + ) + assert not commands, commands + for results in fetched: + results_queue.popleft() + for r in results: + if r.status in ( + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + ): + raise e.error_from_result(r) + + +async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> None: + pgconn = LoggingPGconn((await AsyncConnection.connect()).pgconn, logger) + + with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as ( + commands, + results_queue, + ): + while results_queue: + fetched = await waiting.wait_async( + pipeline_communicate( + pgconn, # type: ignore[arg-type] + commands, + ), + pgconn.socket, + ) + assert not commands, commands + for results in fetched: + results_queue.popleft() + for r in results: + if r.status in ( + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + ): + raise e.error_from_result(r) + + +def pipeline_demo(rows_to_send: int, many: bool, logger: logging.Logger) -> None: + """Pipeline demo using sync API.""" + conn = Connection.connect() + conn.autocommit = True + conn.pgconn = LoggingPGconn(conn.pgconn, logger) # type: ignore[assignment] + with conn.pipeline(): + with conn.transaction(): + conn.execute("DROP TABLE IF EXISTS pq_pipeline_demo") + conn.execute( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ) + query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)" + params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1)) + if many: + cur = conn.cursor() + cur.executemany(query, list(params)) + else: + for p in params: + conn.execute(query, p) + + +async def pipeline_demo_async( + rows_to_send: int, many: bool, logger: 
logging.Logger +) -> None: + """Pipeline demo using async API.""" + aconn = await AsyncConnection.connect() + await aconn.set_autocommit(True) + aconn.pgconn = LoggingPGconn(aconn.pgconn, logger) # type: ignore[assignment] + async with aconn.pipeline(): + async with aconn.transaction(): + await aconn.execute("DROP TABLE IF EXISTS pq_pipeline_demo") + await aconn.execute( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ) + query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)" + params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1)) + if many: + cur = aconn.cursor() + await cur.executemany(query, list(params)) + else: + for p in params: + await aconn.execute(query, p) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "-n", + dest="nrows", + metavar="ROWS", + default=10_000, + type=int, + help="number of rows to insert", + ) + parser.add_argument( + "--pq", action="store_true", help="use low-level psycopg.pq API" + ) + parser.add_argument( + "--async", dest="async_", action="store_true", help="use async API" + ) + parser.add_argument( + "--many", + action="store_true", + help="use executemany() (not applicable for --pq)", + ) + parser.add_argument("--trace", help="write trace info into TRACE file") + parser.add_argument("-l", "--log", help="log file (stderr by default)") + + global args + args = parser.parse_args() + + psycopg_logger.setLevel(logging.DEBUG) + pipeline_logger.setLevel(logging.DEBUG) + if args.log: + psycopg_logger.addHandler(logging.FileHandler(args.log)) + pipeline_logger.addHandler(logging.FileHandler(args.log)) + else: + psycopg_logger.addHandler(logging.StreamHandler()) + pipeline_logger.addHandler(logging.StreamHandler()) + + if args.pq: + if args.many: + parser.error("--many cannot be used with --pq") + if args.async_: + asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger)) + else: + pipeline_demo_pq(args.nrows, pipeline_logger) + else: + if pq.__impl__ != "python": + parser.error( + "only supported for Python implementation (set PSYCOPG_IMPL=python)" + ) + if args.async_: + asyncio.run(pipeline_demo_async(args.nrows, args.many, pipeline_logger)) + else: + pipeline_demo(args.nrows, args.many, pipeline_logger) + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/spiketest.py b/tests/scripts/spiketest.py new file mode 100644 index 0000000..2c9cc16 --- /dev/null +++ b/tests/scripts/spiketest.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +""" +Run a connection pool spike test. + +The test is inspired to the `spike analysis`__ illustrated by HikariCP + +.. 
__: https://github.com/brettwooldridge/HikariCP/blob/dev/documents/ + Welcome-To-The-Jungle.md + +""" +# mypy: allow-untyped-defs +# mypy: allow-untyped-calls + +import time +import threading + +import psycopg +import psycopg_pool +from psycopg.rows import Row + +import logging + + +def main() -> None: + opt = parse_cmdline() + if opt.loglevel: + loglevel = getattr(logging, opt.loglevel.upper()) + logging.basicConfig( + level=loglevel, format="%(asctime)s %(levelname)s %(message)s" + ) + + logging.getLogger("psycopg2.pool").setLevel(loglevel) + + with psycopg_pool.ConnectionPool( + opt.dsn, + min_size=opt.min_size, + max_size=opt.max_size, + connection_class=DelayedConnection, + kwargs={"conn_delay": 0.150}, + ) as pool: + pool.wait() + measurer = Measurer(pool) + + # Create and start all the thread: they will get stuck on the event + ev = threading.Event() + threads = [ + threading.Thread(target=worker, args=(pool, 0.002, ev), daemon=True) + for i in range(opt.num_clients) + ] + for t in threads: + t.start() + time.sleep(0.2) + + # Release the threads! + measurer.start(0.00025) + t0 = time.time() + ev.set() + + # Wait for the threads to finish + for t in threads: + t.join() + t1 = time.time() + measurer.stop() + + print(f"time: {(t1 - t0) * 1000} msec") + print("active,idle,total,waiting") + recs = [ + f'{m["pool_size"] - m["pool_available"]}' + f',{m["pool_available"]}' + f',{m["pool_size"]}' + f',{m["requests_waiting"]}' + for m in measurer.measures + ] + print("\n".join(recs)) + + +def worker(p, t, ev): + ev.wait() + with p.connection(): + time.sleep(t) + + +class Measurer: + def __init__(self, pool): + self.pool = pool + self.worker = None + self.stopped = False + self.measures = [] + + def start(self, interval): + self.worker = threading.Thread(target=self._run, args=(interval,), daemon=True) + self.worker.start() + + def stop(self): + self.stopped = True + if self.worker: + self.worker.join() + self.worker = None + + def _run(self, interval): + while not self.stopped: + self.measures.append(self.pool.get_stats()) + time.sleep(interval) + + +class DelayedConnection(psycopg.Connection[Row]): + """A connection adding a delay to the connection time.""" + + @classmethod + def connect(cls, conninfo, conn_delay=0, **kwargs): + t0 = time.time() + conn = super().connect(conninfo, **kwargs) + t1 = time.time() + wait = max(0.0, conn_delay - (t1 - t0)) + if wait: + time.sleep(wait) + return conn + + +def parse_cmdline(): + from argparse import ArgumentParser + + parser = ArgumentParser(description=__doc__) + parser.add_argument("--dsn", default="", help="connection string to the database") + parser.add_argument( + "--min_size", + default=5, + type=int, + help="minimum number of connections in the pool", + ) + parser.add_argument( + "--max_size", + default=20, + type=int, + help="maximum number of connections in the pool", + ) + parser.add_argument( + "--num-clients", + default=50, + type=int, + help="number of threads making a request", + ) + parser.add_argument( + "--loglevel", + default=None, + choices=("DEBUG", "INFO", "WARNING", "ERROR"), + help="level to log at [default: no log]", + ) + + opt = parser.parse_args() + + return opt + + +if __name__ == "__main__": + main() diff --git a/tests/test_adapt.py b/tests/test_adapt.py new file mode 100644 index 0000000..2190a84 --- /dev/null +++ b/tests/test_adapt.py @@ -0,0 +1,530 @@ +import datetime as dt +from types import ModuleType +from typing import Any, List + +import pytest + +import psycopg +from psycopg import pq, sql, postgres +from 
psycopg import errors as e +from psycopg.adapt import Transformer, PyFormat, Dumper, Loader +from psycopg._cmodule import _psycopg +from psycopg.postgres import types as builtins, TEXT_OID +from psycopg.types.array import ListDumper, ListBinaryDumper + + +@pytest.mark.parametrize( + "data, format, result, type", + [ + (1, PyFormat.TEXT, b"1", "int2"), + ("hello", PyFormat.TEXT, b"hello", "text"), + ("hello", PyFormat.BINARY, b"hello", "text"), + ], +) +def test_dump(data, format, result, type): + t = Transformer() + dumper = t.get_dumper(data, format) + assert dumper.dump(data) == result + if type == "text" and format != PyFormat.BINARY: + assert dumper.oid == 0 + else: + assert dumper.oid == builtins[type].oid + + +@pytest.mark.parametrize( + "data, result", + [ + (1, b"1"), + ("hello", b"'hello'"), + ("he'llo", b"'he''llo'"), + (True, b"true"), + (None, b"NULL"), + ], +) +def test_quote(data, result): + t = Transformer() + dumper = t.get_dumper(data, PyFormat.TEXT) + assert dumper.quote(data) == result + + +def test_register_dumper_by_class(conn): + dumper = make_dumper("x") + assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper + conn.adapters.register_dumper(MyStr, dumper) + assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper + + +def test_register_dumper_by_class_name(conn): + dumper = make_dumper("x") + assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper + conn.adapters.register_dumper(f"{MyStr.__module__}.{MyStr.__qualname__}", dumper) + assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper + + +@pytest.mark.crdb("skip", reason="global adapters don't affect crdb") +def test_dump_global_ctx(conn_cls, dsn, global_adapters, pgconn): + psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb")) + psycopg.adapters.register_dumper(MyStr, make_dumper("gt")) + with conn_cls.connect(dsn) as conn: + cur = conn.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + cur = conn.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellogb",) + cur = conn.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + + +def test_dump_connection_ctx(conn): + conn.adapters.register_dumper(MyStr, make_bin_dumper("b")) + conn.adapters.register_dumper(MyStr, make_dumper("t")) + + cur = conn.cursor() + cur.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellot",) + cur.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellot",) + cur.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellob",) + + +def test_dump_cursor_ctx(conn): + conn.adapters.register_dumper(str, make_bin_dumper("b")) + conn.adapters.register_dumper(str, make_dumper("t")) + + cur = conn.cursor() + cur.adapters.register_dumper(str, make_bin_dumper("bc")) + cur.adapters.register_dumper(str, make_dumper("tc")) + + cur.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellotc",) + cur.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellotc",) + cur.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellobc",) + + cur = conn.cursor() + cur.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellot",) + cur.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellot",) + cur.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellob",) + + +def test_dump_subclass(conn): + class MyString(str): + pass + + cur = conn.cursor() + cur.execute("select %s::text, %b::text", 
[MyString("hello"), MyString("world")]) + assert cur.fetchone() == ("hello", "world") + + +def test_subclass_dumper(conn): + # This might be a C fast object: make sure that the Python code is called + from psycopg.types.string import StrDumper + + class MyStrDumper(StrDumper): + def dump(self, obj): + return (obj * 2).encode() + + conn.adapters.register_dumper(str, MyStrDumper) + assert conn.execute("select %t", ["hello"]).fetchone()[0] == "hellohello" + + +def test_dumper_protocol(conn): + + # This class doesn't inherit from adapt.Dumper but passes a mypy check + from .adapters_example import MyStrDumper + + conn.adapters.register_dumper(str, MyStrDumper) + cur = conn.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellohello" + cur = conn.execute("select %s", [["hi", "ha"]]) + assert cur.fetchone()[0] == ["hihi", "haha"] + assert sql.Literal("hello").as_string(conn) == "'qelloqello'" + + +def test_loader_protocol(conn): + + # This class doesn't inherit from adapt.Loader but passes a mypy check + from .adapters_example import MyTextLoader + + conn.adapters.register_loader("text", MyTextLoader) + cur = conn.execute("select 'hello'::text") + assert cur.fetchone()[0] == "hellohello" + cur = conn.execute("select '{hi,ha}'::text[]") + assert cur.fetchone()[0] == ["hihi", "haha"] + + +def test_subclass_loader(conn): + # This might be a C fast object: make sure that the Python code is called + from psycopg.types.string import TextLoader + + class MyTextLoader(TextLoader): + def load(self, data): + return (bytes(data) * 2).decode() + + conn.adapters.register_loader("text", MyTextLoader) + assert conn.execute("select 'hello'::text").fetchone()[0] == "hellohello" + + +@pytest.mark.parametrize( + "data, format, type, result", + [ + (b"1", pq.Format.TEXT, "int4", 1), + (b"hello", pq.Format.TEXT, "text", "hello"), + (b"hello", pq.Format.BINARY, "text", "hello"), + ], +) +def test_cast(data, format, type, result): + t = Transformer() + rv = t.get_loader(builtins[type].oid, format).load(data) + assert rv == result + + +def test_register_loader_by_oid(conn): + assert TEXT_OID == 25 + loader = make_loader("x") + assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader + conn.adapters.register_loader(TEXT_OID, loader) + assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader + + +def test_register_loader_by_type_name(conn): + loader = make_loader("x") + assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader + conn.adapters.register_loader("text", loader) + assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader + + +@pytest.mark.crdb("skip", reason="global adapters don't affect crdb") +def test_load_global_ctx(conn_cls, dsn, global_adapters): + psycopg.adapters.register_loader("text", make_loader("gt")) + psycopg.adapters.register_loader("text", make_bin_loader("gb")) + with conn_cls.connect(dsn) as conn: + cur = conn.cursor(binary=False).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogt",) + cur = conn.cursor(binary=True).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogb",) + + +def test_load_connection_ctx(conn): + conn.adapters.register_loader("text", make_loader("t")) + conn.adapters.register_loader("text", make_bin_loader("b")) + + r = conn.cursor(binary=False).execute("select 'hello'::text").fetchone() + assert r == ("hellot",) + r = conn.cursor(binary=True).execute("select 'hello'::text").fetchone() + assert r == ("hellob",) + + +def test_load_cursor_ctx(conn): + 
conn.adapters.register_loader("text", make_loader("t")) + conn.adapters.register_loader("text", make_bin_loader("b")) + + cur = conn.cursor() + cur.adapters.register_loader("text", make_loader("tc")) + cur.adapters.register_loader("text", make_bin_loader("bc")) + + assert cur.execute("select 'hello'::text").fetchone() == ("hellotc",) + cur.format = pq.Format.BINARY + assert cur.execute("select 'hello'::text").fetchone() == ("hellobc",) + + cur = conn.cursor() + assert cur.execute("select 'hello'::text").fetchone() == ("hellot",) + cur.format = pq.Format.BINARY + assert cur.execute("select 'hello'::text").fetchone() == ("hellob",) + + +def test_cow_dumpers(conn): + conn.adapters.register_dumper(str, make_dumper("t")) + + cur1 = conn.cursor() + cur2 = conn.cursor() + cur2.adapters.register_dumper(str, make_dumper("c2")) + + r = cur1.execute("select %s::text -- 1", ["hello"]).fetchone() + assert r == ("hellot",) + r = cur2.execute("select %s::text -- 1", ["hello"]).fetchone() + assert r == ("helloc2",) + + conn.adapters.register_dumper(str, make_dumper("t1")) + r = cur1.execute("select %s::text -- 2", ["hello"]).fetchone() + assert r == ("hellot",) + r = cur2.execute("select %s::text -- 2", ["hello"]).fetchone() + assert r == ("helloc2",) + + +def test_cow_loaders(conn): + conn.adapters.register_loader("text", make_loader("t")) + + cur1 = conn.cursor() + cur2 = conn.cursor() + cur2.adapters.register_loader("text", make_loader("c2")) + + assert cur1.execute("select 'hello'::text").fetchone() == ("hellot",) + assert cur2.execute("select 'hello'::text").fetchone() == ("helloc2",) + + conn.adapters.register_loader("text", make_loader("t1")) + assert cur1.execute("select 'hello2'::text").fetchone() == ("hello2t",) + assert cur2.execute("select 'hello2'::text").fetchone() == ("hello2c2",) + + +@pytest.mark.parametrize( + "sql, obj", + [("'{hello}'::text[]", ["helloc"]), ("row('hello'::text)", ("helloc",))], +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out): + cur = conn.cursor(binary=fmt_out == pq.Format.BINARY) + if fmt_out == pq.Format.TEXT: + cur.adapters.register_loader("text", make_loader("c")) + else: + cur.adapters.register_loader("text", make_bin_loader("c")) + + cur.execute(f"select {sql}") + res = cur.fetchone()[0] + assert res == obj + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_list_dumper(conn, fmt_out): + t = Transformer(conn) + fmt_in = PyFormat.from_pq(fmt_out) + dint = t.get_dumper([0], fmt_in) + assert isinstance(dint, (ListDumper, ListBinaryDumper)) + assert dint.oid == builtins["int2"].array_oid + assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid + + dstr = t.get_dumper([""], fmt_in) + assert dstr is not dint + + assert t.get_dumper([1], fmt_in) is dint + assert t.get_dumper([None, [1]], fmt_in) is dint + + dempty = t.get_dumper([], fmt_in) + assert t.get_dumper([None, [None]], fmt_in) is dempty + assert dempty.oid == 0 + assert dempty.dump([]) == b"{}" + + L: List[List[Any]] = [] + L.append(L) + with pytest.raises(psycopg.DataError): + assert t.get_dumper(L, fmt_in) + + +@pytest.mark.crdb("skip", reason="test in crdb test suite") +def test_str_list_dumper_text(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.TEXT) + assert isinstance(dstr, ListDumper) + assert dstr.oid == 0 + assert dstr.sub_dumper and dstr.sub_dumper.oid == 0 + + +def test_str_list_dumper_binary(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.BINARY) + assert isinstance(dstr, 
ListBinaryDumper)
+    assert dstr.oid == builtins["text"].array_oid
+    assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
+def test_last_dumper_registered_ctx(conn):
+    cur = conn.cursor()
+
+    bd = make_bin_dumper("b")
+    cur.adapters.register_dumper(str, bd)
+    td = make_dumper("t")
+    cur.adapters.register_dumper(str, td)
+
+    assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellot"
+    assert cur.execute("select %t", ["hello"]).fetchone()[0] == "hellot"
+    assert cur.execute("select %b", ["hello"]).fetchone()[0] == "hellob"
+
+    cur.adapters.register_dumper(str, bd)
+    assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellob"
+
+
+@pytest.mark.parametrize("fmt_in", [PyFormat.TEXT, PyFormat.BINARY])
+def test_none_type_argument(conn, fmt_in):
+    cur = conn.cursor()
+    cur.execute("create table none_args (id serial primary key, num integer)")
+    cur.execute("insert into none_args (num) values (%s) returning id", (None,))
+    assert cur.fetchone()[0]
+
+
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_return_untyped(conn, fmt_in):
+    # Analyze and check for changes using strings in untyped/typed contexts
+    cur = conn.cursor()
+    # Currently strings are passed as unknown oid to libpq. This is because
+    # unknown is more easily cast by postgres to different types (see jsonb
+    # later).
+    cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10])
+    assert cur.fetchone() == ("hello", 10)
+
+    cur.execute("create table testjson(data jsonb)")
+    if fmt_in != PyFormat.BINARY:
+        cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+        assert cur.execute("select data from testjson").fetchone() == ({},)
+    else:
+        # Binary types cannot be passed as unknown oids.
+ with pytest.raises(e.DatatypeMismatch): + cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_no_cast_needed(conn, fmt_in): + # Verify that there is no need of cast in certain common scenario + cur = conn.execute(f"select '2021-01-01'::date + %{fmt_in.value}", [3]) + assert cur.fetchone()[0] == dt.date(2021, 1, 4) + + cur = conn.execute(f"select '[10, 20, 30]'::jsonb -> %{fmt_in.value}", [1]) + assert cur.fetchone()[0] == 20 + + +@pytest.mark.slow +@pytest.mark.skipif(_psycopg is None, reason="C module test") +def test_optimised_adapters(): + + # All the optimised adapters available + c_adapters = {} + for n in dir(_psycopg): + if n.startswith("_") or n in ("CDumper", "CLoader"): + continue + obj = getattr(_psycopg, n) + if not isinstance(obj, type): + continue + if not issubclass( + obj, + (_psycopg.CDumper, _psycopg.CLoader), # type: ignore[attr-defined] + ): + continue + c_adapters[n] = obj + + # All the registered adapters + reg_adapters = set() + adapters = list(postgres.adapters._dumpers.values()) + postgres.adapters._loaders + assert len(adapters) == 5 + for m in adapters: + reg_adapters |= set(m.values()) + + # Check that the registered adapters are the optimised one + i = 0 + for cls in reg_adapters: + if cls.__name__ in c_adapters: + assert cls is c_adapters[cls.__name__] + i += 1 + + assert i >= 10 + + # Check that every optimised adapter is the optimised version of a Py one + for n in dir(psycopg.types): + mod = getattr(psycopg.types, n) + if not isinstance(mod, ModuleType): + continue + for n1 in dir(mod): + obj = getattr(mod, n1) + if not isinstance(obj, type): + continue + if not issubclass(obj, (Dumper, Loader)): + continue + c_adapters.pop(obj.__name__, None) + + assert not c_adapters + + +def test_dumper_init_error(conn): + class BadDumper(Dumper): + def __init__(self, cls, context): + super().__init__(cls, context) + 1 / 0 + + def dump(self, obj): + return obj.encode() + + cur = conn.cursor() + cur.adapters.register_dumper(str, BadDumper) + with pytest.raises(ZeroDivisionError): + cur.execute("select %s::text", ["hi"]) + + +def test_loader_init_error(conn): + class BadLoader(Loader): + def __init__(self, oid, context): + super().__init__(oid, context) + 1 / 0 + + def load(self, data): + return data.decode() + + cur = conn.cursor() + cur.adapters.register_loader("text", BadLoader) + with pytest.raises(ZeroDivisionError): + cur.execute("select 'hi'::text") + assert cur.fetchone() == ("hi",) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt", PyFormat) +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_random(conn, faker, fmt, fmt_out): + faker.format = fmt + faker.choose_schema(ncols=20) + faker.make_records(50) + + with conn.cursor(binary=fmt_out) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + cur.execute(faker.select_stmt) + recs = cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + +class MyStr(str): + pass + + +def make_dumper(suffix): + """Create a test dumper appending a suffix to the bytes representation.""" + + class TestDumper(Dumper): + oid = TEXT_OID + format = pq.Format.TEXT + + def dump(self, s): + return (s + suffix).encode("ascii") + + return TestDumper + + +def make_bin_dumper(suffix): + cls = make_dumper(suffix) + cls.format = pq.Format.BINARY + return cls + + +def 
make_loader(suffix): + """Create a test loader appending a suffix to the data returned.""" + + class TestLoader(Loader): + format = pq.Format.TEXT + + def load(self, b): + return bytes(b).decode("ascii") + suffix + + return TestLoader + + +def make_bin_loader(suffix): + cls = make_loader(suffix) + cls.format = pq.Format.BINARY + return cls diff --git a/tests/test_client_cursor.py b/tests/test_client_cursor.py new file mode 100644 index 0000000..b355604 --- /dev/null +++ b/tests/test_client_cursor.py @@ -0,0 +1,855 @@ +import pickle +import weakref +import datetime as dt +from typing import List + +import pytest + +import psycopg +from psycopg import sql, rows +from psycopg.adapt import PyFormat +from psycopg.postgres import types as builtins + +from .utils import gc_collect, gc_count +from .test_cursor import my_row_factory +from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision + + +@pytest.fixture +def conn(conn): + conn.cursor_factory = psycopg.ClientCursor + return conn + + +def test_init(conn): + cur = psycopg.ClientCursor(conn) + cur.execute("select 1") + assert cur.fetchone() == (1,) + + conn.row_factory = rows.dict_row + cur = psycopg.ClientCursor(conn) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_init_factory(conn): + cur = psycopg.ClientCursor(conn, row_factory=rows.dict_row) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_from_cursor_factory(conn_cls, dsn): + with conn_cls.connect(dsn, cursor_factory=psycopg.ClientCursor) as conn: + cur = conn.cursor() + assert type(cur) is psycopg.ClientCursor + + cur.execute("select %s", (1,)) + assert cur.fetchone() == (1,) + assert cur._query + assert cur._query.query == b"select 1" + + +def test_close(conn): + cur = conn.cursor() + assert not cur.closed + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.execute("select 'foo'") + + cur.close() + assert cur.closed + + +def test_cursor_close_fetchone(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + for _ in range(5): + cur.fetchone() + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchone() + + +def test_cursor_close_fetchmany(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchmany(2)) == 2 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchmany(2) + + +def test_cursor_close_fetchall(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchall()) == 10 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchall() + + +def test_context(conn): + with conn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +def test_weakref(conn): + cur = conn.cursor() + w = weakref.ref(cur) + cur.close() + del cur + gc_collect() + assert w() is None + + +def test_pgresult(conn): + cur = conn.cursor() + cur.execute("select 1") + assert cur.pgresult + cur.close() + assert not cur.pgresult + + +def test_statusmessage(conn): + cur = conn.cursor() + assert cur.statusmessage is None + + cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with 
pytest.raises(psycopg.ProgrammingError): + cur.execute("wat") + assert cur.statusmessage is None + + +def test_execute_sql(conn): + cur = conn.cursor() + cur.execute(sql.SQL("select {value}").format(value="hello")) + assert cur.fetchone() == ("hello",) + + +def test_execute_many_results(conn): + cur = conn.cursor() + assert cur.nextset() is None + + rv = cur.execute("select %s; select generate_series(1,%s)", ("foo", 3)) + assert rv is cur + assert cur.fetchall() == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.nextset() is None + + cur.close() + assert cur.nextset() is None + + +def test_execute_sequence(conn): + cur = conn.cursor() + rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +def test_execute_empty_query(conn, query): + cur = conn.cursor() + cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + +def test_execute_type_change(conn): + # issue #112 + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.execute(sql, (1,)) + cur.execute(sql, (100_000,)) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +def test_executemany_type_change(conn): + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.executemany(sql, [(1,), (100_000,)]) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +def test_execute_copy(conn, query): + cur = conn.cursor() + cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + cur.execute(query) + + +def test_fetchone(conn): + cur = conn.cursor() + cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = cur.fetchone() + assert row == (1, "foo", None) + row = cur.fetchone() + assert row is None + + +def test_binary_cursor_execute(conn): + with pytest.raises(psycopg.NotSupportedError): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None]) + + +def test_execute_binary(conn): + with pytest.raises(psycopg.NotSupportedError): + cur = conn.cursor() + cur.execute("select %s, %s", [1, None], binary=True) + + +def test_binary_cursor_text_override(conn): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None], binary=False) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +def test_query_encode(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("select '\u20ac'").fetchone() + assert res == "\u20ac" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +def test_query_badenc(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + with pytest.raises(UnicodeEncodeError): + 
cur.execute("select '\u20ac'") + + +@pytest.fixture(scope="session") +def _execmany(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop table if exists execmany; + create table execmany (id serial primary key, num integer, data text) + """ + ) + + +@pytest.fixture(scope="function") +def execmany(svcconn, _execmany): + cur = svcconn.cursor() + cur.execute("truncate table execmany") + + +def test_executemany(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(10, "hello"), (20, "world")] + + +def test_executemany_name(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(11, "hello"), (21, "world")] + + +def test_executemany_no_data(conn, execmany): + cur = conn.cursor() + cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +def test_executemany_rowcount(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +def test_executemany_returning(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.fetchone() == (10,) + assert cur.nextset() + assert cur.fetchone() == (20,) + assert cur.nextset() is None + + +def test_executemany_returning_discard(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + assert cur.nextset() is None + + +def test_executemany_no_result(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +def test_executemany_rowcount_no_hit(conn, execmany): + cur = conn.cursor() + cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)]) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + # This fails, but only because we try to copy in pipeline mode, + # crashing the connection. Which would be even fine, but with + # the async cursor it's worse... See test_client_cursor_async.py. 
+ # "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +def test_executemany_badquery(conn, query): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_executemany_null_first(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table testmany (a bigint, b bigint)") + cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, ""], [3, 4]], + ) + + +def test_rowcount(conn): + cur = conn.cursor() + + cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)") + assert cur.rowcount == 42 + + +def test_rownumber(conn): + cur = conn.cursor() + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +def test_iter(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + assert list(cur) == [(1,), (2,), (3,)] + + +def test_iter_stop(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + for rec in cur: + assert rec == (1,) + break + + for rec in cur: + assert rec == (2,) + break + + assert cur.fetchone() == (3,) + assert list(cur) == [] + + +def test_row_factory(conn): + cur = conn.cursor(row_factory=my_row_factory) + + cur.execute("reset search_path") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + cur.execute("select 'foo' as bar") + (r,) = cur.fetchone() + assert r == "FOObar" + + cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert cur.fetchall() == [["Yy", "Zz"]] + + cur.scroll(-1) + cur.row_factory = rows.dict_row + assert cur.fetchone() == {"y": "y", "z": "z"} + + +def test_row_factory_none(conn): + cur = conn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + r = cur.execute("select 1 as a, 2 as b").fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +def test_bad_row_factory(conn): + def broken_factory(cur): + 1 / 0 + + cur = conn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = conn.cursor(row_factory=broken_maker) + cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + cur.fetchone() + + +def test_scroll(conn): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + cur.scroll(0) + + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(-1) + 
assert cur.fetchone() == (8,) + cur.scroll(-2) + assert cur.fetchone() == (7,) + cur.scroll(2, mode="absolute") + assert cur.fetchone() == (2,) + + # on the boundary + cur.scroll(0, mode="absolute") + assert cur.fetchone() == (0,) + with pytest.raises(IndexError): + cur.scroll(-1, mode="absolute") + + cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(-1) + + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + with pytest.raises(IndexError): + cur.scroll(10, mode="absolute") + + cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(1) + + with pytest.raises(ValueError): + cur.scroll(1, "wat") + + +def test_query_params_execute(conn): + cur = conn.cursor() + assert cur._query is None + + cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select 1, NULL::text" + assert cur._query.params == (b"1", b"NULL") + + cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select 'wat'::int" + assert cur._query.params == (b"'wat'",) + + +@pytest.mark.parametrize( + "query, params, want", + [ + ("select %(x)s", {"x": 1}, (1,)), + ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)), + ("select %(x)s, %(x)s", {"x": 1}, (1, 1)), + ], +) +def test_query_params_named(conn, query, params, want): + cur = conn.cursor() + cur.execute(query, params) + rec = cur.fetchone() + assert rec == want + + +def test_query_params_executemany(conn): + cur = conn.cursor() + + cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select 3, 4" + assert cur._query.params == (b"3", b"4") + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +def test_copy_out_param(conn, ph, params): + cur = conn.cursor() + with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert list(copy.rows()) == [(i + 1,) for i in range(10)] + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_stream(conn): + cur = conn.cursor() + recs = [] + for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +class TestColumn: + def test_description_attribs(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + assert len(curs.description) == 3 + for c in curs.description: + len(c) == 7 # DBAPI happy + for i, a in enumerate( + """ + name type_code display_size internal_size precision scale null_ok + """.split() + ): + assert c[i] == getattr(c, a) + + # Won't fill them up + assert c.null_ok is None + + c = curs.description[0] + assert c.name == "pi" + assert c.type_code == builtins["numeric"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision == 10 + assert c.scale == 2 + + c = curs.description[1] + assert c.name == "hi" + assert c.type_code == builtins["text"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision is None + assert c.scale is None + + c = curs.description[2] + assert c.name == "now" + assert c.type_code == builtins["date"].oid + assert c.display_size is None + if 
is_crdb(conn):
+            assert c.internal_size == 16
+        else:
+            assert c.internal_size == 4
+        assert c.precision is None
+        assert c.scale is None
+
+    def test_description_slice(self, conn):
+        curs = conn.cursor()
+        curs.execute("select 1::int as a")
+        assert curs.description[0][0:2] == ("a", 23)
+
+    @pytest.mark.parametrize(
+        "type, precision, scale, dsize, isize",
+        [
+            ("text", None, None, None, None),
+            ("varchar", None, None, None, None),
+            ("varchar(42)", None, None, 42, None),
+            ("int4", None, None, None, 4),
+            ("numeric", None, None, None, None),
+            ("numeric(10)", 10, 0, None, None),
+            ("numeric(10, 3)", 10, 3, None, None),
+            ("time", None, None, None, 8),
+            crdb_time_precision("time(4)", 4, None, None, 8),
+            crdb_time_precision("time(10)", 6, None, None, 8),
+        ],
+    )
+    def test_details(self, conn, type, precision, scale, dsize, isize):
+        cur = conn.cursor()
+        cur.execute(f"select null::{type}")
+        col = cur.description[0]
+        repr(col)
+        assert col.precision == precision
+        assert col.scale == scale
+        assert col.display_size == dsize
+        assert col.internal_size == isize
+
+    def test_pickle(self, conn):
+        curs = conn.cursor()
+        curs.execute(
+            """select
+            3.14::decimal(10,2) as pi,
+            'hello'::text as hi,
+            '2010-02-18'::date as now
+            """
+        )
+        description = curs.description
+        pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
+        unpickled = pickle.loads(pickled)
+        assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]
+
+    @pytest.mark.crdb_skip("no col query")
+    def test_no_col_query(self, conn):
+        cur = conn.execute("select")
+        assert cur.description == []
+        assert cur.fetchall() == [()]
+
+    def test_description_closed_connection(self, conn):
+        # If we have reasons to break this test we will (e.g. we really need
+        # the connection). In #172 it fails just by accident.
+ cur = conn.execute("select 1::int4 as foo") + conn.close() + assert len(cur.description) == 1 + col = cur.description[0] + assert col.name == "foo" + assert col.type_code == 23 + + def test_name_not_a_name(self, conn): + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone() + assert res == "x" + assert cur.description[0].name == "foo-bar" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_name_encode(self, conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone() + assert res == "x" + assert cur.description[0].name == "\u20ac" + + +def test_str(conn): + cur = conn.cursor() + assert "psycopg.ClientCursor" in str(cur) + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_leak(conn_cls, dsn, faker, fetch, row_factory): + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + def work(): + with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True): + with psycopg.ClientCursor(conn, row_factory=row_factory) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + cur.execute(faker.select_stmt) + + if fetch == "one": + while True: + tmp = cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while True: + tmp = cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + cur.fetchall() + elif fetch == "iter": + for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.parametrize( + "query, params, want", + [ + ("select 'hello'", (), "select 'hello'"), + ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"), + ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"), + ("select %%", (), "select %%"), + ("select %%, %s", (["a"],), "select %, 'a'"), + ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"), + ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"), + ], +) +def test_mogrify(conn, query, params, want): + cur = conn.cursor() + got = cur.mogrify(query, *params) + assert got == want + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +def test_mogrify_encoding(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + q = conn.cursor().mogrify("select %(s)s", {"s": "\u20ac"}) + assert q == "select '\u20ac'" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +def test_mogrify_badenc(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + with pytest.raises(UnicodeEncodeError): + conn.cursor().mogrify("select %(s)s", {"s": "\u20ac"}) + + +@pytest.mark.pipeline +def test_message_0x33(conn): + # https://github.com/psycopg/psycopg/issues/314 + notices = [] + conn.add_notice_handler(lambda diag: 
notices.append(diag.message_primary)) + + conn.autocommit = True + with conn.pipeline(): + cur = conn.execute("select 'test'") + assert cur.fetchone() == ("test",) + + assert not notices diff --git a/tests/test_client_cursor_async.py b/tests/test_client_cursor_async.py new file mode 100644 index 0000000..0cf8ec6 --- /dev/null +++ b/tests/test_client_cursor_async.py @@ -0,0 +1,727 @@ +import pytest +import weakref +import datetime as dt +from typing import List + +import psycopg +from psycopg import sql, rows +from psycopg.adapt import PyFormat + +from .utils import alist, gc_collect, gc_count +from .test_cursor import my_row_factory +from .test_cursor import execmany, _execmany # noqa: F401 +from .fix_crdb import crdb_encoding + +execmany = execmany # avoid F811 underneath +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +async def aconn(aconn): + aconn.cursor_factory = psycopg.AsyncClientCursor + return aconn + + +async def test_init(aconn): + cur = psycopg.AsyncClientCursor(aconn) + await cur.execute("select 1") + assert (await cur.fetchone()) == (1,) + + aconn.row_factory = rows.dict_row + cur = psycopg.AsyncClientCursor(aconn) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_init_factory(aconn): + cur = psycopg.AsyncClientCursor(aconn, row_factory=rows.dict_row) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_from_cursor_factory(aconn_cls, dsn): + async with await aconn_cls.connect( + dsn, cursor_factory=psycopg.AsyncClientCursor + ) as aconn: + cur = aconn.cursor() + assert type(cur) is psycopg.AsyncClientCursor + + await cur.execute("select %s", (1,)) + assert await cur.fetchone() == (1,) + assert cur._query + assert cur._query.query == b"select 1" + + +async def test_close(aconn): + cur = aconn.cursor() + assert not cur.closed + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.execute("select 'foo'") + + await cur.close() + assert cur.closed + + +async def test_cursor_close_fetchone(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + for _ in range(5): + await cur.fetchone() + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchone() + + +async def test_cursor_close_fetchmany(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchmany(2)) == 2 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchmany(2) + + +async def test_cursor_close_fetchall(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchall()) == 10 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchall() + + +async def test_context(aconn): + async with aconn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +async def test_weakref(aconn): + cur = aconn.cursor() + w = weakref.ref(cur) + await cur.close() + del cur + gc_collect() + assert w() is None + + +async def test_pgresult(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert cur.pgresult + await cur.close() + assert not cur.pgresult + + +async def test_statusmessage(aconn): + cur = aconn.cursor() + assert 
cur.statusmessage is None + + await cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + await cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + await cur.execute("wat") + assert cur.statusmessage is None + + +async def test_execute_sql(aconn): + cur = aconn.cursor() + await cur.execute(sql.SQL("select {value}").format(value="hello")) + assert await cur.fetchone() == ("hello",) + + +async def test_execute_many_results(aconn): + cur = aconn.cursor() + assert cur.nextset() is None + + rv = await cur.execute("select %s; select generate_series(1,%s)", ("foo", 3)) + assert rv is cur + assert (await cur.fetchall()) == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert (await cur.fetchall()) == [(1,), (2,), (3,)] + assert cur.rowcount == 3 + assert cur.nextset() is None + + await cur.close() + assert cur.nextset() is None + + +async def test_execute_sequence(aconn): + cur = aconn.cursor() + rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +async def test_execute_empty_query(aconn, query): + cur = aconn.cursor() + await cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + + +async def test_execute_type_change(aconn): + # issue #112 + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.execute(sql, (1,)) + await cur.execute(sql, (100_000,)) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +async def test_executemany_type_change(aconn): + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.executemany(sql, [(1,), (100_000,)]) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +async def test_execute_copy(aconn, query): + cur = aconn.cursor() + await cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + await cur.execute(query) + + +async def test_fetchone(aconn): + cur = aconn.cursor() + await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = await cur.fetchone() + assert row == (1, "foo", None) + row = await cur.fetchone() + assert row is None + + +async def test_binary_cursor_execute(aconn): + with pytest.raises(psycopg.NotSupportedError): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None]) + + +async def test_execute_binary(aconn): + with pytest.raises(psycopg.NotSupportedError): + cur = aconn.cursor() + await cur.execute("select %s, %s", [1, None], binary=True) + + +async def test_binary_cursor_text_override(aconn): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None], binary=False) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert 
cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +async def test_query_encode(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + await cur.execute("select '\u20ac'") + (res,) = await cur.fetchone() + assert res == "\u20ac" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +async def test_query_badenc(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + with pytest.raises(UnicodeEncodeError): + await cur.execute("select '\u20ac'") + + +async def test_executemany(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(10, "hello"), (20, "world")] + + +async def test_executemany_name(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(11, "hello"), (21, "world")] + + +async def test_executemany_no_data(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +async def test_executemany_rowcount(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +async def test_executemany_returning(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert (await cur.fetchone()) == (10,) + assert cur.nextset() + assert (await cur.fetchone()) == (20,) + assert cur.nextset() is None + + +async def test_executemany_returning_discard(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + assert cur.nextset() is None + + +async def test_executemany_no_result(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +async def test_executemany_rowcount_no_hit(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + await cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + await cur.executemany( + "delete from execmany where id = %s returning num", [(-1,), (-2,)] + ) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + # This fails because we end up 
trying to copy in pipeline mode. + # However, sometimes (and pretty regularly if we enable pgconn.trace()) + # something goes in a loop and only terminates by OOM. Strace shows + # an allocation loop. I think it's in the libpq. + # "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +async def test_executemany_badquery(aconn, query): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +async def test_executemany_null_first(aconn, fmt_in): + cur = aconn.cursor() + await cur.execute("create table testmany (a bigint, b bigint)") + await cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + await cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, ""], [3, 4]], + ) + + +async def test_rowcount(aconn): + cur = aconn.cursor() + + await cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + await cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + await cur.execute( + "insert into test_rowcount_notuples select generate_series(1, 42)" + ) + assert cur.rowcount == 42 + + +async def test_rownumber(aconn): + cur = aconn.cursor() + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + async for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(await cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +async def test_iter(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + res = [] + async for rec in cur: + res.append(rec) + assert res == [(1,), (2,), (3,)] + + +async def test_iter_stop(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + async for rec in cur: + assert rec == (1,) + break + + async for rec in cur: + assert rec == (2,) + break + + assert (await cur.fetchone()) == (3,) + async for rec in cur: + assert False + + +async def test_row_factory(aconn): + cur = aconn.cursor(row_factory=my_row_factory) + await cur.execute("select 'foo' as bar") + (r,) = await cur.fetchone() + assert r == "FOObar" + + await cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert await cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert await cur.fetchall() == [["Yy", "Zz"]] + + await cur.scroll(-1) + cur.row_factory = rows.dict_row + assert await cur.fetchone() == {"y": "y", "z": "z"} + + +async def test_row_factory_none(aconn): + cur = aconn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + await cur.execute("select 1 as a, 2 as b") + r = await cur.fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +async def test_bad_row_factory(aconn): + def broken_factory(cur): + 1 / 0 + + cur = aconn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + await cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = 
aconn.cursor(row_factory=broken_maker) + await cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + await cur.fetchone() + + +async def test_scroll(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + await cur.scroll(0) + + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(-1) + assert await cur.fetchone() == (8,) + await cur.scroll(-2) + assert await cur.fetchone() == (7,) + await cur.scroll(2, mode="absolute") + assert await cur.fetchone() == (2,) + + # on the boundary + await cur.scroll(0, mode="absolute") + assert await cur.fetchone() == (0,) + with pytest.raises(IndexError): + await cur.scroll(-1, mode="absolute") + + await cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(-1) + + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + with pytest.raises(IndexError): + await cur.scroll(10, mode="absolute") + + await cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(1) + + with pytest.raises(ValueError): + await cur.scroll(1, "wat") + + +async def test_query_params_execute(aconn): + cur = aconn.cursor() + assert cur._query is None + + await cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select 1, NULL::text" + assert cur._query.params == (b"1", b"NULL") + + await cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + await cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select 'wat'::int" + assert cur._query.params == (b"'wat'",) + + +@pytest.mark.parametrize( + "query, params, want", + [ + ("select %(x)s", {"x": 1}, (1,)), + ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)), + ("select %(x)s, %(x)s", {"x": 1}, (1, 1)), + ], +) +async def test_query_params_named(aconn, query, params, want): + cur = aconn.cursor() + await cur.execute(query, params) + rec = await cur.fetchone() + assert rec == want + + +async def test_query_params_executemany(aconn): + cur = aconn.cursor() + + await cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select 3, 4" + assert cur._query.params == (b"3", b"4") + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +async def test_copy_out_param(aconn, ph, params): + cur = aconn.cursor() + async with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert await alist(copy.rows()) == [(i + 1,) for i in range(10)] + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_stream(aconn): + cur = aconn.cursor() + recs = [] + async for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +async def test_str(aconn): + cur = aconn.cursor() + assert "psycopg.AsyncClientCursor" in str(cur) + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + await cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + 
assert "[no result]" not in str(cur) + await cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_leak(aconn_cls, dsn, faker, fetch, row_factory): + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + async def work(): + async with await aconn_cls.connect(dsn) as conn, conn.transaction( + force_rollback=True + ): + async with psycopg.AsyncClientCursor(conn, row_factory=row_factory) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + async with faker.find_insert_problem_async(conn): + await cur.executemany(faker.insert_stmt, faker.records) + await cur.execute(faker.select_stmt) + + if fetch == "one": + while True: + tmp = await cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while True: + tmp = await cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + await cur.fetchall() + elif fetch == "iter": + async for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.parametrize( + "query, params, want", + [ + ("select 'hello'", (), "select 'hello'"), + ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"), + ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"), + ("select %%", (), "select %%"), + ("select %%, %s", (["a"],), "select %, 'a'"), + ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"), + ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"), + ], +) +async def test_mogrify(aconn, query, params, want): + cur = aconn.cursor() + got = cur.mogrify(query, *params) + assert got == want + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +async def test_mogrify_encoding(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + q = aconn.cursor().mogrify("select %(s)s", {"s": "\u20ac"}) + assert q == "select '\u20ac'" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +async def test_mogrify_badenc(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + with pytest.raises(UnicodeEncodeError): + aconn.cursor().mogrify("select %(s)s", {"s": "\u20ac"}) + + +@pytest.mark.pipeline +async def test_message_0x33(aconn): + # https://github.com/psycopg/psycopg/issues/314 + notices = [] + aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + await aconn.set_autocommit(True) + async with aconn.pipeline(): + cur = await aconn.execute("select 'test'") + assert (await cur.fetchone()) == ("test",) + + assert not notices diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..eec24f1 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,327 @@ +""" +Tests dealing with concurrency issues. 
+""" + +import os +import sys +import time +import queue +import signal +import threading +import multiprocessing +import subprocess as sp +from typing import List + +import pytest + +import psycopg +from psycopg import errors as e + + +@pytest.mark.slow +def test_concurrent_execution(conn_cls, dsn): + def worker(): + cnn = conn_cls.connect(dsn) + cur = cnn.cursor() + cur.execute("select pg_sleep(0.5)") + cur.close() + cnn.close() + + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=worker) + t0 = time.time() + t1.start() + t2.start() + t1.join() + t2.join() + assert time.time() - t0 < 0.8, "something broken in concurrency" + + +@pytest.mark.slow +def test_commit_concurrency(conn): + # Check the condition reported in psycopg2#103 + # Because of bad status check, we commit even when a commit is already on + # its way. We can detect this condition by the warnings. + notices = queue.Queue() # type: ignore[var-annotated] + conn.add_notice_handler(lambda diag: notices.put(diag.message_primary)) + stop = False + + def committer(): + nonlocal stop + while not stop: + conn.commit() + + cur = conn.cursor() + t1 = threading.Thread(target=committer) + t1.start() + for i in range(1000): + cur.execute("select %s;", (i,)) + conn.commit() + + # Stop the committer thread + stop = True + t1.join() + + assert notices.empty(), "%d notices raised" % notices.qsize() + + +@pytest.mark.slow +@pytest.mark.subprocess +def test_multiprocess_close(dsn, tmpdir): + # Check the problem reported in psycopg2#829 + # Subprocess gcs the copy of the fd after fork so it closes connection. + module = f"""\ +import time +import psycopg + +def thread(): + conn = psycopg.connect({dsn!r}) + curs = conn.cursor() + for i in range(10): + curs.execute("select 1") + time.sleep(0.1) + +def process(): + time.sleep(0.2) +""" + + script = """\ +import time +import threading +import multiprocessing +import mptest + +t = threading.Thread(target=mptest.thread, name='mythread') +t.start() +time.sleep(0.2) +multiprocessing.Process(target=mptest.process, name='myprocess').start() +t.join() +""" + + with (tmpdir / "mptest.py").open("w") as f: + f.write(module) + env = dict(os.environ) + env["PYTHONPATH"] = str(tmpdir + os.pathsep + env.get("PYTHONPATH", "")) + out = sp.check_output( + [sys.executable, "-c", script], stderr=sp.STDOUT, env=env + ).decode("utf8", "replace") + assert out == "", out.strip().splitlines()[-1] + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("notify") +def test_notifies(conn_cls, conn, dsn): + nconn = conn_cls.connect(dsn, autocommit=True) + npid = nconn.pgconn.backend_pid + + def notifier(): + time.sleep(0.25) + nconn.cursor().execute("notify foo, '1'") + time.sleep(0.25) + nconn.cursor().execute("notify foo, '2'") + nconn.close() + + conn.autocommit = True + conn.cursor().execute("listen foo") + + t0 = time.time() + t = threading.Thread(target=notifier) + t.start() + + ns = [] + gen = conn.notifies() + for n in gen: + ns.append((n, time.time())) + if len(ns) >= 2: + gen.close() + + assert len(ns) == 2 + + n, t1 = ns[0] + assert isinstance(n, psycopg.Notify) + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + t.join() + + +def canceller(conn, errors): + try: + time.sleep(0.5) + conn.cancel() + except Exception as exc: + errors.append(exc) + + +@pytest.mark.slow 
+@pytest.mark.crdb_skip("cancel") +def test_cancel(conn): + errors: List[Exception] = [] + + cur = conn.cursor() + t = threading.Thread(target=canceller, args=(conn, errors)) + t0 = time.time() + t.start() + + with pytest.raises(e.QueryCanceled): + cur.execute("select pg_sleep(2)") + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + conn.rollback() + assert cur.execute("select 1").fetchone()[0] == 1 + + t.join() + + +@pytest.mark.slow +@pytest.mark.crdb_skip("cancel") +def test_cancel_stream(conn): + errors: List[Exception] = [] + + cur = conn.cursor() + t = threading.Thread(target=canceller, args=(conn, errors)) + t0 = time.time() + t.start() + + with pytest.raises(e.QueryCanceled): + for row in cur.stream("select pg_sleep(2)"): + pass + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + conn.rollback() + assert cur.execute("select 1").fetchone()[0] == 1 + + t.join() + + +@pytest.mark.crdb_skip("pg_terminate_backend") +@pytest.mark.slow +def test_identify_closure(conn_cls, dsn): + def closer(): + time.sleep(0.2) + conn2.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]) + + conn = conn_cls.connect(dsn) + conn2 = conn_cls.connect(dsn) + try: + t = threading.Thread(target=closer) + t.start() + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + conn.execute("select pg_sleep(1.0)") + t1 = time.time() + assert 0.2 < t1 - t0 < 0.4 + finally: + t.join() + finally: + conn.close() + conn2.close() + + +@pytest.mark.slow +@pytest.mark.subprocess +@pytest.mark.skipif( + sys.platform == "win32", reason="don't know how to Ctrl-C on Windows" +) +@pytest.mark.crdb_skip("cancel") +def test_ctrl_c(dsn): + if sys.platform == "win32": + sig = int(signal.CTRL_C_EVENT) + # Or pytest will receive the Ctrl-C too + creationflags = sp.CREATE_NEW_PROCESS_GROUP + else: + sig = int(signal.SIGINT) + creationflags = 0 + + script = f"""\ +import os +import time +import psycopg +from threading import Thread + +def tired_of_life(): + time.sleep(1) + os.kill(os.getpid(), {sig!r}) + +t = Thread(target=tired_of_life, daemon=True) +t.start() + +with psycopg.connect({dsn!r}) as conn: + cur = conn.cursor() + ctrl_c = False + try: + cur.execute("select pg_sleep(2)") + except KeyboardInterrupt: + ctrl_c = True + + assert ctrl_c, "ctrl-c not received" + assert ( + conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR + ), f"transaction status: {{conn.info.transaction_status!r}}" + + conn.rollback() + assert ( + conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE + ), f"transaction status: {{conn.info.transaction_status!r}}" + + cur.execute("select 1") + assert cur.fetchone() == (1,) +""" + t0 = time.time() + proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags) + proc.communicate() + t = time.time() - t0 + assert proc.returncode == 0 + assert 1 < t < 2 + + +@pytest.mark.slow +@pytest.mark.subprocess +@pytest.mark.skipif( + multiprocessing.get_all_start_methods()[0] != "fork", + reason="problematic behavior only exhibited via fork", +) +def test_segfault_on_fork_close(dsn): + # https://github.com/psycopg/psycopg/issues/300 + script = f"""\ +import gc +import psycopg +from multiprocessing import Pool + +def test(arg): + conn1 = psycopg.connect({dsn!r}) + conn1.close() + conn1 = None + gc.collect() + return 1 + +if __name__ == '__main__': + conn = psycopg.connect({dsn!r}) + with Pool(2) as p: + pool_result = p.map_async(test, [1, 2]) + 
pool_result.wait(timeout=5) + if pool_result.ready(): + print(pool_result.get(timeout=1)) +""" + env = dict(os.environ) + env["PYTHONFAULTHANDLER"] = "1" + out = sp.check_output([sys.executable, "-s", "-c", script], env=env) + assert out.decode().rstrip() == "[1, 1]" diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py new file mode 100644 index 0000000..29b08cf --- /dev/null +++ b/tests/test_concurrency_async.py @@ -0,0 +1,242 @@ +import sys +import time +import signal +import asyncio +import subprocess as sp +from asyncio.queues import Queue +from typing import List, Tuple + +import pytest + +import psycopg +from psycopg import errors as e +from psycopg._compat import create_task + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.slow +async def test_commit_concurrency(aconn): + # Check the condition reported in psycopg2#103 + # Because of bad status check, we commit even when a commit is already on + # its way. We can detect this condition by the warnings. + notices = Queue() # type: ignore[var-annotated] + aconn.add_notice_handler(lambda diag: notices.put_nowait(diag.message_primary)) + stop = False + + async def committer(): + nonlocal stop + while not stop: + await aconn.commit() + await asyncio.sleep(0) # Allow the other worker to work + + async def runner(): + nonlocal stop + cur = aconn.cursor() + for i in range(1000): + await cur.execute("select %s;", (i,)) + await aconn.commit() + + # Stop the committer thread + stop = True + + await asyncio.gather(committer(), runner()) + assert notices.empty(), "%d notices raised" % notices.qsize() + + +@pytest.mark.slow +async def test_concurrent_execution(aconn_cls, dsn): + async def worker(): + cnn = await aconn_cls.connect(dsn) + cur = cnn.cursor() + await cur.execute("select pg_sleep(0.5)") + await cur.close() + await cnn.close() + + workers = [worker(), worker()] + t0 = time.time() + await asyncio.gather(*workers) + assert time.time() - t0 < 0.8, "something broken in concurrency" + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("notify") +async def test_notifies(aconn_cls, aconn, dsn): + nconn = await aconn_cls.connect(dsn, autocommit=True) + npid = nconn.pgconn.backend_pid + + async def notifier(): + cur = nconn.cursor() + await asyncio.sleep(0.25) + await cur.execute("notify foo, '1'") + await asyncio.sleep(0.25) + await cur.execute("notify foo, '2'") + await nconn.close() + + async def receiver(): + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("listen foo") + gen = aconn.notifies() + async for n in gen: + ns.append((n, time.time())) + if len(ns) >= 2: + await gen.aclose() + + ns: List[Tuple[psycopg.Notify, float]] = [] + t0 = time.time() + workers = [notifier(), receiver()] + await asyncio.gather(*workers) + assert len(ns) == 2 + + n, t1 = ns[0] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + +async def canceller(aconn, errors): + try: + await asyncio.sleep(0.5) + aconn.cancel() + except Exception as exc: + errors.append(exc) + + +@pytest.mark.slow +@pytest.mark.crdb_skip("cancel") +async def test_cancel(aconn): + async def worker(): + cur = aconn.cursor() + with pytest.raises(e.QueryCanceled): + await cur.execute("select pg_sleep(2)") + + errors: List[Exception] = [] + workers = [worker(), canceller(aconn, errors)] + + t0 = 
time.time() + await asyncio.gather(*workers) + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + await aconn.rollback() + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + + +@pytest.mark.slow +@pytest.mark.crdb_skip("cancel") +async def test_cancel_stream(aconn): + async def worker(): + cur = aconn.cursor() + with pytest.raises(e.QueryCanceled): + async for row in cur.stream("select pg_sleep(2)"): + pass + + errors: List[Exception] = [] + workers = [worker(), canceller(aconn, errors)] + + t0 = time.time() + await asyncio.gather(*workers) + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + await aconn.rollback() + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + + +@pytest.mark.slow +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_identify_closure(aconn_cls, dsn): + async def closer(): + await asyncio.sleep(0.2) + await conn2.execute( + "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid] + ) + + aconn = await aconn_cls.connect(dsn) + conn2 = await aconn_cls.connect(dsn) + try: + t = create_task(closer()) + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + await aconn.execute("select pg_sleep(1.0)") + t1 = time.time() + assert 0.2 < t1 - t0 < 0.4 + finally: + await asyncio.gather(t) + finally: + await aconn.close() + await conn2.close() + + +@pytest.mark.slow +@pytest.mark.subprocess +@pytest.mark.skipif( + sys.platform == "win32", reason="don't know how to Ctrl-C on Windows" +) +@pytest.mark.crdb_skip("cancel") +async def test_ctrl_c(dsn): + script = f"""\ +import signal +import asyncio +import psycopg + +async def main(): + ctrl_c = False + loop = asyncio.get_event_loop() + async with await psycopg.AsyncConnection.connect({dsn!r}) as conn: + loop.add_signal_handler(signal.SIGINT, conn.cancel) + cur = conn.cursor() + try: + await cur.execute("select pg_sleep(2)") + except psycopg.errors.QueryCanceled: + ctrl_c = True + + assert ctrl_c, "ctrl-c not received" + assert ( + conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR + ), f"transaction status: {{conn.info.transaction_status!r}}" + + await conn.rollback() + assert ( + conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE + ), f"transaction status: {{conn.info.transaction_status!r}}" + + await cur.execute("select 1") + assert (await cur.fetchone()) == (1,) + +asyncio.run(main()) +""" + if sys.platform == "win32": + creationflags = sp.CREATE_NEW_PROCESS_GROUP + sig = signal.CTRL_C_EVENT + else: + creationflags = 0 + sig = signal.SIGINT + + proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags) + with pytest.raises(sp.TimeoutExpired): + outs, errs = proc.communicate(timeout=1) + + proc.send_signal(sig) + proc.communicate() + assert proc.returncode == 0 diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..57c6c78 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,790 @@ +import time +import pytest +import logging +import weakref +from typing import Any, List +from dataclasses import dataclass + +import psycopg +from psycopg import Notify, errors as e +from psycopg.rows import tuple_row +from psycopg.conninfo import conninfo_to_dict, make_conninfo + +from .utils import gc_collect +from .test_cursor import my_row_factory +from .test_adapt import make_bin_dumper, make_dumper + + +def test_connect(conn_cls, dsn): + conn = 
conn_cls.connect(dsn) + assert not conn.closed + assert conn.pgconn.status == conn.ConnStatus.OK + conn.close() + + +def test_connect_str_subclass(conn_cls, dsn): + class MyString(str): + pass + + conn = conn_cls.connect(MyString(dsn)) + assert not conn.closed + assert conn.pgconn.status == conn.ConnStatus.OK + conn.close() + + +def test_connect_bad(conn_cls): + with pytest.raises(psycopg.OperationalError): + conn_cls.connect("dbname=nosuchdb") + + +@pytest.mark.slow +@pytest.mark.timing +def test_connect_timeout(conn_cls, deaf_port): + t0 = time.time() + with pytest.raises(psycopg.OperationalError, match="timeout expired"): + conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1) + elapsed = time.time() - t0 + assert elapsed == pytest.approx(1.0, abs=0.05) + + +def test_close(conn): + assert not conn.closed + assert not conn.broken + + cur = conn.cursor() + + conn.close() + assert conn.closed + assert not conn.broken + assert conn.pgconn.status == conn.ConnStatus.BAD + + conn.close() + assert conn.closed + assert conn.pgconn.status == conn.ConnStatus.BAD + + with pytest.raises(psycopg.OperationalError): + cur.execute("select 1") + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_broken(conn): + with pytest.raises(psycopg.OperationalError): + conn.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]) + assert conn.closed + assert conn.broken + conn.close() + assert conn.closed + assert conn.broken + + +def test_cursor_closed(conn): + conn.close() + with pytest.raises(psycopg.OperationalError): + with conn.cursor("foo"): + pass + with pytest.raises(psycopg.OperationalError): + conn.cursor() + + +def test_connection_warn_close(conn_cls, dsn, recwarn): + conn = conn_cls.connect(dsn) + conn.close() + del conn + assert not recwarn, [str(w.message) for w in recwarn.list] + + conn = conn_cls.connect(dsn) + del conn + assert "IDLE" in str(recwarn.pop(ResourceWarning).message) + + conn = conn_cls.connect(dsn) + conn.execute("select 1") + del conn + assert "INTRANS" in str(recwarn.pop(ResourceWarning).message) + + conn = conn_cls.connect(dsn) + try: + conn.execute("select wat") + except Exception: + pass + del conn + assert "INERROR" in str(recwarn.pop(ResourceWarning).message) + + with conn_cls.connect(dsn) as conn: + pass + del conn + assert not recwarn, [str(w.message) for w in recwarn.list] + + +@pytest.fixture +def testctx(svcconn): + svcconn.execute("create table if not exists testctx (id int primary key)") + svcconn.execute("delete from testctx") + return None + + +def test_context_commit(conn_cls, testctx, conn, dsn): + with conn: + with conn.cursor() as cur: + cur.execute("insert into testctx values (42)") + + assert conn.closed + assert not conn.broken + + with conn_cls.connect(dsn) as conn: + with conn.cursor() as cur: + cur.execute("select * from testctx") + assert cur.fetchall() == [(42,)] + + +def test_context_rollback(conn_cls, testctx, conn, dsn): + with pytest.raises(ZeroDivisionError): + with conn: + with conn.cursor() as cur: + cur.execute("insert into testctx values (42)") + 1 / 0 + + assert conn.closed + assert not conn.broken + + with conn_cls.connect(dsn) as conn: + with conn.cursor() as cur: + cur.execute("select * from testctx") + assert cur.fetchall() == [] + + +def test_context_close(conn): + with conn: + conn.execute("select 1") + conn.close() + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_context_inerror_rollback_no_clobber(conn_cls, conn, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + with 
pytest.raises(ZeroDivisionError): + with conn_cls.connect(dsn) as conn2: + conn2.execute("select 1") + conn.execute( + "select pg_terminate_backend(%s::int)", + [conn2.pgconn.backend_pid], + ) + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.crdb_skip("copy") +def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + with conn_cls.connect(dsn) as conn: + conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") + assert not conn.pgconn.error_message + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.slow +def test_weakref(conn_cls, dsn): + conn = conn_cls.connect(dsn) + w = weakref.ref(conn) + conn.close() + del conn + gc_collect() + assert w() is None + + +def test_commit(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + conn.pgconn.exec_(b"begin") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + conn.pgconn.exec_(b"insert into foo values (1)") + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + res = conn.pgconn.exec_(b"select id from foo where id = 1") + assert res.get_value(0, 0) == b"1" + + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.commit() + + +@pytest.mark.crdb_skip("deferrable") +def test_commit_error(conn): + conn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + conn.commit() + + conn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + +def test_rollback(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + conn.pgconn.exec_(b"begin") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + conn.pgconn.exec_(b"insert into foo values (1)") + conn.rollback() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + res = conn.pgconn.exec_(b"select id from foo where id = 1") + assert res.ntuples == 0 + + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.rollback() + + +def test_auto_transaction(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = conn.cursor() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + + cur.execute("insert into foo values (1)") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + assert cur.execute("select * from foo").fetchone() == (1,) + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + +def test_auto_transaction_fail(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = conn.cursor() + assert conn.pgconn.transaction_status == 
conn.TransactionStatus.IDLE + + cur.execute("insert into foo values (1)") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + with pytest.raises(psycopg.DatabaseError): + cur.execute("meh") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + + with pytest.raises(psycopg.errors.InFailedSqlTransaction): + cur.execute("select 1") + + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + assert cur.execute("select * from foo").fetchone() is None + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + +def test_autocommit(conn): + assert conn.autocommit is False + conn.autocommit = True + assert conn.autocommit + cur = conn.cursor() + assert cur.execute("select 1").fetchone() == (1,) + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + + conn.autocommit = "" + assert conn.autocommit is False # type: ignore[comparison-overlap] + conn.autocommit = "yeah" + assert conn.autocommit is True + + +def test_autocommit_connect(conn_cls, dsn): + conn = conn_cls.connect(dsn, autocommit=True) + assert conn.autocommit + conn.close() + + +def test_autocommit_intrans(conn): + cur = conn.cursor() + assert cur.execute("select 1").fetchone() == (1,) + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + with pytest.raises(psycopg.ProgrammingError): + conn.autocommit = True + assert not conn.autocommit + + +def test_autocommit_inerror(conn): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.execute("meh") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + with pytest.raises(psycopg.ProgrammingError): + conn.autocommit = True + assert not conn.autocommit + + +def test_autocommit_unknown(conn): + conn.close() + assert conn.pgconn.transaction_status == conn.TransactionStatus.UNKNOWN + with pytest.raises(psycopg.OperationalError): + conn.autocommit = True + assert not conn.autocommit + + +@pytest.mark.parametrize( + "args, kwargs, want", + [ + ((), {}, ""), + (("",), {}, ""), + (("host=foo user=bar",), {}, "host=foo user=bar"), + (("host=foo",), {"user": "baz"}, "host=foo user=baz"), + ( + ("host=foo port=5432",), + {"host": "qux", "user": "joe"}, + "host=qux user=joe port=5432", + ), + (("host=foo",), {"user": None}, "host=foo"), + ], +) +def test_connect_args(conn_cls, monkeypatch, pgconn, args, kwargs, want): + the_conninfo: str + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + conn = conn_cls.connect(*args, **kwargs) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + conn.close() + + +@pytest.mark.parametrize( + "args, kwargs, exctype", + [ + (("host=foo", "host=bar"), {}, TypeError), + (("", ""), {}, TypeError), + ((), {"nosuchparam": 42}, psycopg.ProgrammingError), + ], +) +def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + with pytest.raises(exctype): + conn_cls.connect(*args, **kwargs) + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_broken_connection(conn): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.execute("select pg_terminate_backend(pg_backend_pid())") + assert conn.closed + + +@pytest.mark.crdb_skip("do") +def test_notice_handlers(conn, caplog): + caplog.set_level(logging.WARNING, 
logger="psycopg") + messages = [] + severities = [] + + def cb1(diag): + messages.append(diag.message_primary) + + def cb2(res): + raise Exception("hello from cb2") + + conn.add_notice_handler(cb1) + conn.add_notice_handler(cb2) + conn.add_notice_handler("the wrong thing") + conn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized)) + + conn.pgconn.exec_(b"set client_min_messages to notice") + cur = conn.cursor() + cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql") + assert messages == ["hello notice"] + assert severities == ["NOTICE"] + + assert len(caplog.records) == 2 + rec = caplog.records[0] + assert rec.levelno == logging.ERROR + assert "hello from cb2" in rec.message + rec = caplog.records[1] + assert rec.levelno == logging.ERROR + assert "the wrong thing" in rec.message + + conn.remove_notice_handler(cb1) + conn.remove_notice_handler("the wrong thing") + cur.execute("do $$begin raise warning 'hello warning'; end$$ language plpgsql") + assert len(caplog.records) == 3 + assert messages == ["hello notice"] + assert severities == ["NOTICE", "WARNING"] + + with pytest.raises(ValueError): + conn.remove_notice_handler(cb1) + + +@pytest.mark.crdb_skip("notify") +def test_notify_handlers(conn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + conn.add_notify_handler(cb1) + conn.add_notify_handler(lambda n: nots2.append(n)) + + conn.autocommit = True + cur = conn.cursor() + cur.execute("listen foo") + cur.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == conn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + conn.remove_notify_handler(cb1) + cur.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert isinstance(n, Notify) + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == conn.pgconn.backend_pid + assert hash(n) + + with pytest.raises(ValueError): + conn.remove_notify_handler(cb1) + + +def test_execute(conn): + cur = conn.execute("select %s, %s", [10, 20]) + assert cur.fetchone() == (10, 20) + assert cur.format == 0 + assert cur.pgresult.fformat(0) == 0 + + cur = conn.execute("select %(a)s, %(b)s", {"a": 11, "b": 21}) + assert cur.fetchone() == (11, 21) + + cur = conn.execute("select 12, 22") + assert cur.fetchone() == (12, 22) + + +def test_execute_binary(conn): + cur = conn.execute("select %s, %s", [10, 20], binary=True) + assert cur.fetchone() == (10, 20) + assert cur.format == 1 + assert cur.pgresult.fformat(0) == 1 + + +def test_row_factory(conn_cls, dsn): + defaultconn = conn_cls.connect(dsn) + assert defaultconn.row_factory is tuple_row + defaultconn.close() + + conn = conn_cls.connect(dsn, row_factory=my_row_factory) + assert conn.row_factory is my_row_factory + + cur = conn.execute("select 'a' as ve") + assert cur.fetchone() == ["Ave"] + + with conn.cursor(row_factory=lambda c: lambda t: set(t)) as cur1: + cur1.execute("select 1, 1, 2") + assert cur1.fetchall() == [{1, 2}] + + with conn.cursor(row_factory=tuple_row) as cur2: + cur2.execute("select 1, 1, 2") + assert cur2.fetchall() == [(1, 1, 2)] + + # TODO: maybe fix something to get rid of 'type: ignore' below. 
+ conn.row_factory = tuple_row + cur3 = conn.execute("select 'vale'") + r = cur3.fetchone() + assert r and r == ("vale",) + conn.close() + + +def test_str(conn): + assert "[IDLE]" in str(conn) + conn.close() + assert "[BAD]" in str(conn) + + +def test_fileno(conn): + assert conn.fileno() == conn.pgconn.socket + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.fileno() + + +def test_cursor_factory(conn): + assert conn.cursor_factory is psycopg.Cursor + + class MyCursor(psycopg.Cursor[psycopg.rows.Row]): + pass + + conn.cursor_factory = MyCursor + with conn.cursor() as cur: + assert isinstance(cur, MyCursor) + + with conn.execute("select 1") as cur: + assert isinstance(cur, MyCursor) + + +def test_cursor_factory_connect(conn_cls, dsn): + class MyCursor(psycopg.Cursor[psycopg.rows.Row]): + pass + + with conn_cls.connect(dsn, cursor_factory=MyCursor) as conn: + assert conn.cursor_factory is MyCursor + cur = conn.cursor() + assert type(cur) is MyCursor + + +def test_server_cursor_factory(conn): + assert conn.server_cursor_factory is psycopg.ServerCursor + + class MyServerCursor(psycopg.ServerCursor[psycopg.rows.Row]): + pass + + conn.server_cursor_factory = MyServerCursor + with conn.cursor(name="n") as cur: + assert isinstance(cur, MyServerCursor) + + +@dataclass +class ParamDef: + name: str + guc: str + values: List[Any] + + +param_isolation = ParamDef( + name="isolation_level", + guc="isolation", + values=list(psycopg.IsolationLevel), +) +param_read_only = ParamDef( + name="read_only", + guc="read_only", + values=[True, False], +) +param_deferrable = ParamDef( + name="deferrable", + guc="deferrable", + values=[True, False], +) + +# Map Python values to Postgres values for the tx_params possible values +tx_values_map = { + v.name.lower().replace("_", " "): v.value for v in psycopg.IsolationLevel +} +tx_values_map["on"] = True +tx_values_map["off"] = False + + +tx_params = [ + param_isolation, + param_read_only, + pytest.param(param_deferrable, marks=pytest.mark.crdb_skip("deferrable")), +] +tx_params_isolation = [ + pytest.param( + param_isolation, + id="isolation_level", + marks=pytest.mark.crdb("skip", reason="transaction isolation"), + ), + pytest.param( + param_read_only, id="read_only", marks=pytest.mark.crdb_skip("begin_read_only") + ), + pytest.param( + param_deferrable, id="deferrable", marks=pytest.mark.crdb_skip("deferrable") + ), +] + + +@pytest.mark.parametrize("param", tx_params) +def test_transaction_param_default(conn, param): + assert getattr(conn, param.name) is None + current, default = conn.execute( + "select current_setting(%s), current_setting(%s)", + [f"transaction_{param.guc}", f"default_transaction_{param.guc}"], + ).fetchone() + assert current == default + + +@pytest.mark.parametrize("autocommit", [True, False]) +@pytest.mark.parametrize("param", tx_params_isolation) +def test_set_transaction_param_implicit(conn, param, autocommit): + conn.autocommit = autocommit + for value in param.values: + setattr(conn, param.name, value) + pgval, default = conn.execute( + "select current_setting(%s), current_setting(%s)", + [f"transaction_{param.guc}", f"default_transaction_{param.guc}"], + ).fetchone() + if autocommit: + assert pgval == default + else: + assert tx_values_map[pgval] == value + conn.rollback() + + +@pytest.mark.parametrize("autocommit", [True, False]) +@pytest.mark.parametrize("param", tx_params_isolation) +def test_set_transaction_param_block(conn, param, autocommit): + conn.autocommit = autocommit + for value in param.values: + setattr(conn, 
param.name, value) + with conn.transaction(): + pgval = conn.execute( + "select current_setting(%s)", [f"transaction_{param.guc}"] + ).fetchone()[0] + assert tx_values_map[pgval] == value + + +@pytest.mark.parametrize("param", tx_params) +def test_set_transaction_param_not_intrans_implicit(conn, param): + conn.execute("select 1") + with pytest.raises(psycopg.ProgrammingError): + setattr(conn, param.name, param.values[0]) + + +@pytest.mark.parametrize("param", tx_params) +def test_set_transaction_param_not_intrans_block(conn, param): + with conn.transaction(): + with pytest.raises(psycopg.ProgrammingError): + setattr(conn, param.name, param.values[0]) + + +@pytest.mark.parametrize("param", tx_params) +def test_set_transaction_param_not_intrans_external(conn, param): + conn.autocommit = True + conn.execute("begin") + with pytest.raises(psycopg.ProgrammingError): + setattr(conn, param.name, param.values[0]) + + +@pytest.mark.crdb("skip", reason="transaction isolation") +def test_set_transaction_param_all(conn): + params: List[Any] = tx_params[:] + params[2] = params[2].values[0] + + for param in params: + value = param.values[0] + setattr(conn, param.name, value) + + for param in params: + pgval = conn.execute( + "select current_setting(%s)", [f"transaction_{param.guc}"] + ).fetchone()[0] + assert tx_values_map[pgval] == value + + +def test_set_transaction_param_strange(conn): + for val in ("asdf", 0, 5): + with pytest.raises(ValueError): + conn.isolation_level = val + + conn.isolation_level = psycopg.IsolationLevel.SERIALIZABLE.value + assert conn.isolation_level is psycopg.IsolationLevel.SERIALIZABLE + + conn.read_only = 1 + assert conn.read_only is True + + conn.deferrable = 0 + assert conn.deferrable is False + + +conninfo_params_timeout = [ + ( + "", + {"dbname": "mydb", "connect_timeout": None}, + ({"dbname": "mydb"}, None), + ), + ( + "", + {"dbname": "mydb", "connect_timeout": 1}, + ({"dbname": "mydb", "connect_timeout": "1"}, 1), + ), + ( + "dbname=postgres", + {}, + ({"dbname": "postgres"}, None), + ), + ( + "dbname=postgres connect_timeout=2", + {}, + ({"dbname": "postgres", "connect_timeout": "2"}, 2), + ), + ( + "postgresql:///postgres?connect_timeout=2", + {"connect_timeout": 10}, + ({"dbname": "postgres", "connect_timeout": "10"}, 10), + ), +] + + +@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) +def test_get_connection_params(conn_cls, dsn, kwargs, exp): + params = conn_cls._get_connection_params(dsn, **kwargs) + conninfo = make_conninfo(**params) + assert conninfo_to_dict(conninfo) == exp[0] + assert params.get("connect_timeout") == exp[1] + + +def test_connect_context(conn_cls, dsn): + ctx = psycopg.adapt.AdaptersMap(psycopg.adapters) + ctx.register_dumper(str, make_bin_dumper("b")) + ctx.register_dumper(str, make_dumper("t")) + + conn = conn_cls.connect(dsn, context=ctx) + + cur = conn.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellot" + cur = conn.execute("select %b", ["hello"]) + assert cur.fetchone()[0] == "hellob" + conn.close() + + +def test_connect_context_copy(conn_cls, dsn, conn): + conn.adapters.register_dumper(str, make_bin_dumper("b")) + conn.adapters.register_dumper(str, make_dumper("t")) + + conn2 = conn_cls.connect(dsn, context=conn) + + cur = conn2.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellot" + cur = conn2.execute("select %b", ["hello"]) + assert cur.fetchone()[0] == "hellob" + conn2.close() + + +def test_cancel_closed(conn): + conn.close() + conn.cancel() diff --git 
a/tests/test_connection_async.py b/tests/test_connection_async.py new file mode 100644 index 0000000..1288a6c --- /dev/null +++ b/tests/test_connection_async.py @@ -0,0 +1,751 @@ +import time +import pytest +import logging +import weakref +from typing import List, Any + +import psycopg +from psycopg import Notify, errors as e +from psycopg.rows import tuple_row +from psycopg.conninfo import conninfo_to_dict, make_conninfo + +from .utils import gc_collect +from .test_cursor import my_row_factory +from .test_connection import tx_params, tx_params_isolation, tx_values_map +from .test_connection import conninfo_params_timeout +from .test_connection import testctx # noqa: F401 # fixture +from .test_adapt import make_bin_dumper, make_dumper +from .test_conninfo import fake_resolve # noqa: F401 + +pytestmark = pytest.mark.asyncio + + +async def test_connect(aconn_cls, dsn): + conn = await aconn_cls.connect(dsn) + assert not conn.closed + assert conn.pgconn.status == conn.ConnStatus.OK + await conn.close() + + +async def test_connect_bad(aconn_cls): + with pytest.raises(psycopg.OperationalError): + await aconn_cls.connect("dbname=nosuchdb") + + +async def test_connect_str_subclass(aconn_cls, dsn): + class MyString(str): + pass + + conn = await aconn_cls.connect(MyString(dsn)) + assert not conn.closed + assert conn.pgconn.status == conn.ConnStatus.OK + await conn.close() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_connect_timeout(aconn_cls, deaf_port): + t0 = time.time() + with pytest.raises(psycopg.OperationalError, match="timeout expired"): + await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1) + elapsed = time.time() - t0 + assert elapsed == pytest.approx(1.0, abs=0.05) + + +async def test_close(aconn): + assert not aconn.closed + assert not aconn.broken + + cur = aconn.cursor() + + await aconn.close() + assert aconn.closed + assert not aconn.broken + assert aconn.pgconn.status == aconn.ConnStatus.BAD + + await aconn.close() + assert aconn.closed + assert aconn.pgconn.status == aconn.ConnStatus.BAD + + with pytest.raises(psycopg.OperationalError): + await cur.execute("select 1") + + +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_broken(aconn): + with pytest.raises(psycopg.OperationalError): + await aconn.execute( + "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid] + ) + assert aconn.closed + assert aconn.broken + await aconn.close() + assert aconn.closed + assert aconn.broken + + +async def test_cursor_closed(aconn): + await aconn.close() + with pytest.raises(psycopg.OperationalError): + async with aconn.cursor("foo"): + pass + aconn.cursor("foo") + with pytest.raises(psycopg.OperationalError): + aconn.cursor() + + +async def test_connection_warn_close(aconn_cls, dsn, recwarn): + conn = await aconn_cls.connect(dsn) + await conn.close() + del conn + assert not recwarn, [str(w.message) for w in recwarn.list] + + conn = await aconn_cls.connect(dsn) + del conn + assert "IDLE" in str(recwarn.pop(ResourceWarning).message) + + conn = await aconn_cls.connect(dsn) + await conn.execute("select 1") + del conn + assert "INTRANS" in str(recwarn.pop(ResourceWarning).message) + + conn = await aconn_cls.connect(dsn) + try: + await conn.execute("select wat") + except Exception: + pass + del conn + assert "INERROR" in str(recwarn.pop(ResourceWarning).message) + + async with await aconn_cls.connect(dsn) as conn: + pass + del conn + assert not recwarn, [str(w.message) for w in recwarn.list] + + +@pytest.mark.usefixtures("testctx") +async def 
test_context_commit(aconn_cls, aconn, dsn): + async with aconn: + async with aconn.cursor() as cur: + await cur.execute("insert into testctx values (42)") + + assert aconn.closed + assert not aconn.broken + + async with await aconn_cls.connect(dsn) as aconn: + async with aconn.cursor() as cur: + await cur.execute("select * from testctx") + assert await cur.fetchall() == [(42,)] + + +@pytest.mark.usefixtures("testctx") +async def test_context_rollback(aconn_cls, aconn, dsn): + with pytest.raises(ZeroDivisionError): + async with aconn: + async with aconn.cursor() as cur: + await cur.execute("insert into testctx values (42)") + 1 / 0 + + assert aconn.closed + assert not aconn.broken + + async with await aconn_cls.connect(dsn) as aconn: + async with aconn.cursor() as cur: + await cur.execute("select * from testctx") + assert await cur.fetchall() == [] + + +async def test_context_close(aconn): + async with aconn: + await aconn.execute("select 1") + await aconn.close() + + +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_context_inerror_rollback_no_clobber(aconn_cls, conn, dsn, caplog): + with pytest.raises(ZeroDivisionError): + async with await aconn_cls.connect(dsn) as conn2: + await conn2.execute("select 1") + conn.execute( + "select pg_terminate_backend(%s::int)", + [conn2.pgconn.backend_pid], + ) + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.crdb_skip("copy") +async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + async with await aconn_cls.connect(dsn) as conn: + conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") + assert not conn.pgconn.error_message + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.slow +async def test_weakref(aconn_cls, dsn): + conn = await aconn_cls.connect(dsn) + w = weakref.ref(conn) + await conn.close() + del conn + gc_collect() + assert w() is None + + +async def test_commit(aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + aconn.pgconn.exec_(b"begin") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + aconn.pgconn.exec_(b"insert into foo values (1)") + await aconn.commit() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + res = aconn.pgconn.exec_(b"select id from foo where id = 1") + assert res.get_value(0, 0) == b"1" + + await aconn.close() + with pytest.raises(psycopg.OperationalError): + await aconn.commit() + + +@pytest.mark.crdb_skip("deferrable") +async def test_commit_error(aconn): + await aconn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + await aconn.commit() + + await aconn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + await aconn.commit() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + cur = await aconn.execute("select 1") + assert await cur.fetchone() == (1,) + + +async def test_rollback(aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + 
aconn.pgconn.exec_(b"create table foo (id int primary key)") + aconn.pgconn.exec_(b"begin") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + aconn.pgconn.exec_(b"insert into foo values (1)") + await aconn.rollback() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + res = aconn.pgconn.exec_(b"select id from foo where id = 1") + assert res.ntuples == 0 + + await aconn.close() + with pytest.raises(psycopg.OperationalError): + await aconn.rollback() + + +async def test_auto_transaction(aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = aconn.cursor() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + + await cur.execute("insert into foo values (1)") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + await aconn.commit() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + await cur.execute("select * from foo") + assert await cur.fetchone() == (1,) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_auto_transaction_fail(aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = aconn.cursor() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + + await cur.execute("insert into foo values (1)") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + with pytest.raises(psycopg.DatabaseError): + await cur.execute("meh") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + + await aconn.commit() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + await cur.execute("select * from foo") + assert await cur.fetchone() is None + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_autocommit(aconn): + assert aconn.autocommit is False + with pytest.raises(AttributeError): + aconn.autocommit = True + assert not aconn.autocommit + + await aconn.set_autocommit(True) + assert aconn.autocommit + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + + await aconn.set_autocommit("") + assert aconn.autocommit is False + await aconn.set_autocommit("yeah") + assert aconn.autocommit is True + + +async def test_autocommit_connect(aconn_cls, dsn): + aconn = await aconn_cls.connect(dsn, autocommit=True) + assert aconn.autocommit + await aconn.close() + + +async def test_autocommit_intrans(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + with pytest.raises(psycopg.ProgrammingError): + await aconn.set_autocommit(True) + assert not aconn.autocommit + + +async def test_autocommit_inerror(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.execute("meh") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + with pytest.raises(psycopg.ProgrammingError): + await aconn.set_autocommit(True) + assert not aconn.autocommit + + +async def test_autocommit_unknown(aconn): + await aconn.close() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.UNKNOWN + with pytest.raises(psycopg.OperationalError): + await aconn.set_autocommit(True) + assert not aconn.autocommit + + 
+@pytest.mark.parametrize( + "args, kwargs, want", + [ + ((), {}, ""), + (("",), {}, ""), + (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), + (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), + ( + ("dbname=foo port=5432",), + {"dbname": "qux", "user": "joe"}, + "dbname=qux user=joe port=5432", + ), + (("dbname=foo",), {"user": None}, "dbname=foo"), + ], +) +async def test_connect_args( + aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want +): + the_conninfo: str + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + setpgenv({}) + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + conn = await aconn_cls.connect(*args, **kwargs) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + await conn.close() + + +@pytest.mark.parametrize( + "args, kwargs, exctype", + [ + (("host=foo", "host=bar"), {}, TypeError), + (("", ""), {}, TypeError), + ((), {"nosuchparam": 42}, psycopg.ProgrammingError), + ], +) +async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + with pytest.raises(exctype): + await aconn_cls.connect(*args, **kwargs) + + +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_broken_connection(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.execute("select pg_terminate_backend(pg_backend_pid())") + assert aconn.closed + + +@pytest.mark.crdb_skip("do") +async def test_notice_handlers(aconn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + messages = [] + severities = [] + + def cb1(diag): + messages.append(diag.message_primary) + + def cb2(res): + raise Exception("hello from cb2") + + aconn.add_notice_handler(cb1) + aconn.add_notice_handler(cb2) + aconn.add_notice_handler("the wrong thing") + aconn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized)) + + aconn.pgconn.exec_(b"set client_min_messages to notice") + cur = aconn.cursor() + await cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql") + assert messages == ["hello notice"] + assert severities == ["NOTICE"] + + assert len(caplog.records) == 2 + rec = caplog.records[0] + assert rec.levelno == logging.ERROR + assert "hello from cb2" in rec.message + rec = caplog.records[1] + assert rec.levelno == logging.ERROR + assert "the wrong thing" in rec.message + + aconn.remove_notice_handler(cb1) + aconn.remove_notice_handler("the wrong thing") + await cur.execute( + "do $$begin raise warning 'hello warning'; end$$ language plpgsql" + ) + assert len(caplog.records) == 3 + assert messages == ["hello notice"] + assert severities == ["NOTICE", "WARNING"] + + with pytest.raises(ValueError): + aconn.remove_notice_handler(cb1) + + +@pytest.mark.crdb_skip("notify") +async def test_notify_handlers(aconn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + aconn.add_notify_handler(cb1) + aconn.add_notify_handler(lambda n: nots2.append(n)) + + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("listen foo") + await cur.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == aconn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + aconn.remove_notify_handler(cb1) + await cur.execute("notify foo, 'n2'") + + assert len(nots1) == 
1 + assert len(nots2) == 2 + n = nots2[1] + assert isinstance(n, Notify) + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == aconn.pgconn.backend_pid + + with pytest.raises(ValueError): + aconn.remove_notify_handler(cb1) + + +async def test_execute(aconn): + cur = await aconn.execute("select %s, %s", [10, 20]) + assert await cur.fetchone() == (10, 20) + assert cur.format == 0 + assert cur.pgresult.fformat(0) == 0 + + cur = await aconn.execute("select %(a)s, %(b)s", {"a": 11, "b": 21}) + assert await cur.fetchone() == (11, 21) + + cur = await aconn.execute("select 12, 22") + assert await cur.fetchone() == (12, 22) + + +async def test_execute_binary(aconn): + cur = await aconn.execute("select %s, %s", [10, 20], binary=True) + assert await cur.fetchone() == (10, 20) + assert cur.format == 1 + assert cur.pgresult.fformat(0) == 1 + + +async def test_row_factory(aconn_cls, dsn): + defaultconn = await aconn_cls.connect(dsn) + assert defaultconn.row_factory is tuple_row + await defaultconn.close() + + conn = await aconn_cls.connect(dsn, row_factory=my_row_factory) + assert conn.row_factory is my_row_factory + + cur = await conn.execute("select 'a' as ve") + assert await cur.fetchone() == ["Ave"] + + async with conn.cursor(row_factory=lambda c: lambda t: set(t)) as cur1: + await cur1.execute("select 1, 1, 2") + assert await cur1.fetchall() == [{1, 2}] + + async with conn.cursor(row_factory=tuple_row) as cur2: + await cur2.execute("select 1, 1, 2") + assert await cur2.fetchall() == [(1, 1, 2)] + + # TODO: maybe fix something to get rid of 'type: ignore' below. + conn.row_factory = tuple_row + cur3 = await conn.execute("select 'vale'") + r = await cur3.fetchone() + assert r and r == ("vale",) + await conn.close() + + +async def test_str(aconn): + assert "[IDLE]" in str(aconn) + await aconn.close() + assert "[BAD]" in str(aconn) + + +async def test_fileno(aconn): + assert aconn.fileno() == aconn.pgconn.socket + await aconn.close() + with pytest.raises(psycopg.OperationalError): + aconn.fileno() + + +async def test_cursor_factory(aconn): + assert aconn.cursor_factory is psycopg.AsyncCursor + + class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]): + pass + + aconn.cursor_factory = MyCursor + async with aconn.cursor() as cur: + assert isinstance(cur, MyCursor) + + async with (await aconn.execute("select 1")) as cur: + assert isinstance(cur, MyCursor) + + +async def test_cursor_factory_connect(aconn_cls, dsn): + class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]): + pass + + async with await aconn_cls.connect(dsn, cursor_factory=MyCursor) as conn: + assert conn.cursor_factory is MyCursor + cur = conn.cursor() + assert type(cur) is MyCursor + + +async def test_server_cursor_factory(aconn): + assert aconn.server_cursor_factory is psycopg.AsyncServerCursor + + class MyServerCursor(psycopg.AsyncServerCursor[psycopg.rows.Row]): + pass + + aconn.server_cursor_factory = MyServerCursor + async with aconn.cursor(name="n") as cur: + assert isinstance(cur, MyServerCursor) + + +@pytest.mark.parametrize("param", tx_params) +async def test_transaction_param_default(aconn, param): + assert getattr(aconn, param.name) is None + cur = await aconn.execute( + "select current_setting(%s), current_setting(%s)", + [f"transaction_{param.guc}", f"default_transaction_{param.guc}"], + ) + current, default = await cur.fetchone() + assert current == default + + +@pytest.mark.parametrize("param", tx_params) +async def test_transaction_param_readonly_property(aconn, param): + with pytest.raises(AttributeError): 
+ setattr(aconn, param.name, None) + + +@pytest.mark.parametrize("autocommit", [True, False]) +@pytest.mark.parametrize("param", tx_params_isolation) +async def test_set_transaction_param_implicit(aconn, param, autocommit): + await aconn.set_autocommit(autocommit) + for value in param.values: + await getattr(aconn, f"set_{param.name}")(value) + cur = await aconn.execute( + "select current_setting(%s), current_setting(%s)", + [f"transaction_{param.guc}", f"default_transaction_{param.guc}"], + ) + pgval, default = await cur.fetchone() + if autocommit: + assert pgval == default + else: + assert tx_values_map[pgval] == value + await aconn.rollback() + + +@pytest.mark.parametrize("autocommit", [True, False]) +@pytest.mark.parametrize("param", tx_params_isolation) +async def test_set_transaction_param_block(aconn, param, autocommit): + await aconn.set_autocommit(autocommit) + for value in param.values: + await getattr(aconn, f"set_{param.name}")(value) + async with aconn.transaction(): + cur = await aconn.execute( + "select current_setting(%s)", [f"transaction_{param.guc}"] + ) + pgval = (await cur.fetchone())[0] + assert tx_values_map[pgval] == value + + +@pytest.mark.parametrize("param", tx_params) +async def test_set_transaction_param_not_intrans_implicit(aconn, param): + await aconn.execute("select 1") + value = param.values[0] + with pytest.raises(psycopg.ProgrammingError): + await getattr(aconn, f"set_{param.name}")(value) + + +@pytest.mark.parametrize("param", tx_params) +async def test_set_transaction_param_not_intrans_block(aconn, param): + value = param.values[0] + async with aconn.transaction(): + with pytest.raises(psycopg.ProgrammingError): + await getattr(aconn, f"set_{param.name}")(value) + + +@pytest.mark.parametrize("param", tx_params) +async def test_set_transaction_param_not_intrans_external(aconn, param): + value = param.values[0] + await aconn.set_autocommit(True) + await aconn.execute("begin") + with pytest.raises(psycopg.ProgrammingError): + await getattr(aconn, f"set_{param.name}")(value) + + +@pytest.mark.crdb("skip", reason="transaction isolation") +async def test_set_transaction_param_all(aconn): + params: List[Any] = tx_params[:] + params[2] = params[2].values[0] + + for param in params: + value = param.values[0] + await getattr(aconn, f"set_{param.name}")(value) + + for param in params: + cur = await aconn.execute( + "select current_setting(%s)", [f"transaction_{param.guc}"] + ) + pgval = (await cur.fetchone())[0] + assert tx_values_map[pgval] == value + + +async def test_set_transaction_param_strange(aconn): + for val in ("asdf", 0, 5): + with pytest.raises(ValueError): + await aconn.set_isolation_level(val) + + await aconn.set_isolation_level(psycopg.IsolationLevel.SERIALIZABLE.value) + assert aconn.isolation_level is psycopg.IsolationLevel.SERIALIZABLE + + await aconn.set_read_only(1) + assert aconn.read_only is True + + await aconn.set_deferrable(0) + assert aconn.deferrable is False + + +@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) +async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv): + setpgenv({}) + params = await aconn_cls._get_connection_params(dsn, **kwargs) + conninfo = make_conninfo(**params) + assert conninfo_to_dict(conninfo) == exp[0] + assert params["connect_timeout"] == exp[1] + + +async def test_connect_context_adapters(aconn_cls, dsn): + ctx = psycopg.adapt.AdaptersMap(psycopg.adapters) + ctx.register_dumper(str, make_bin_dumper("b")) + ctx.register_dumper(str, make_dumper("t")) + + conn = await 
aconn_cls.connect(dsn, context=ctx) + + cur = await conn.execute("select %s", ["hello"]) + assert (await cur.fetchone())[0] == "hellot" + cur = await conn.execute("select %b", ["hello"]) + assert (await cur.fetchone())[0] == "hellob" + await conn.close() + + +async def test_connect_context_copy(aconn_cls, dsn, aconn): + aconn.adapters.register_dumper(str, make_bin_dumper("b")) + aconn.adapters.register_dumper(str, make_dumper("t")) + + aconn2 = await aconn_cls.connect(dsn, context=aconn) + + cur = await aconn2.execute("select %s", ["hello"]) + assert (await cur.fetchone())[0] == "hellot" + cur = await aconn2.execute("select %b", ["hello"]) + assert (await cur.fetchone())[0] == "hellob" + await aconn2.close() + + +async def test_cancel_closed(aconn): + await aconn.close() + aconn.cancel() + + +async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): # noqa: F811 + got = [] + + def fake_connect_gen(conninfo, **kwargs): + got.append(conninfo) + 1 / 0 + + monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen) + + with pytest.raises(ZeroDivisionError): + await psycopg.AsyncConnection.connect("host=foo.com") + + assert len(got) == 1 + want = {"host": "foo.com", "hostaddr": "1.1.1.1"} + assert conninfo_to_dict(got[0]) == want diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py new file mode 100644 index 0000000..e2c2c01 --- /dev/null +++ b/tests/test_conninfo.py @@ -0,0 +1,450 @@ +import socket +import asyncio +import datetime as dt + +import pytest + +import psycopg +from psycopg import ProgrammingError +from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from psycopg.conninfo import resolve_hostaddr_async +from psycopg._encodings import pg2pyenc + +from .fix_crdb import crdb_encoding + +snowman = "\u2603" + + +class MyString(str): + pass + + +@pytest.mark.parametrize( + "conninfo, kwargs, exp", + [ + ("", {}, ""), + ("dbname=foo", {}, "dbname=foo"), + ("dbname=foo", {"user": "bar"}, "dbname=foo user=bar"), + ("dbname=sony", {"password": ""}, "dbname=sony password="), + ("dbname=foo", {"dbname": "bar"}, "dbname=bar"), + ("user=bar", {"dbname": "foo bar"}, "dbname='foo bar' user=bar"), + ("", {"dbname": "foo"}, "dbname=foo"), + ("", {"dbname": "foo", "user": None}, "dbname=foo"), + ("", {"dbname": "foo", "port": 15432}, "dbname=foo port=15432"), + ("", {"dbname": "a'b"}, r"dbname='a\'b'"), + (f"dbname={snowman}", {}, f"dbname={snowman}"), + ("", {"dbname": snowman}, f"dbname={snowman}"), + ( + "postgresql://host1/test", + {"host": "host2"}, + "dbname=test host=host2", + ), + (MyString(""), {}, ""), + ], +) +def test_make_conninfo(conninfo, kwargs, exp): + out = make_conninfo(conninfo, **kwargs) + assert conninfo_to_dict(out) == conninfo_to_dict(exp) + + +@pytest.mark.parametrize( + "conninfo, kwargs", + [ + ("hello", {}), + ("dbname=foo bar", {}), + ("foo=bar", {}), + ("dbname=foo", {"bar": "baz"}), + ("postgresql://tester:secret@/test?port=5433=x", {}), + (f"{snowman}={snowman}", {}), + ], +) +def test_make_conninfo_bad(conninfo, kwargs): + with pytest.raises(ProgrammingError): + make_conninfo(conninfo, **kwargs) + + +@pytest.mark.parametrize( + "conninfo, exp", + [ + ("", {}), + ("dbname=foo user=bar", {"dbname": "foo", "user": "bar"}), + ("dbname=sony password=", {"dbname": "sony", "password": ""}), + ("dbname='foo bar'", {"dbname": "foo bar"}), + ("dbname='a\"b'", {"dbname": 'a"b'}), + (r"dbname='a\'b'", {"dbname": "a'b"}), + (r"dbname='a\\b'", {"dbname": r"a\b"}), + (f"dbname={snowman}", {"dbname": snowman}), + ( + 
"postgresql://tester:secret@/test?port=5433", + { + "user": "tester", + "password": "secret", + "dbname": "test", + "port": "5433", + }, + ), + ], +) +def test_conninfo_to_dict(conninfo, exp): + assert conninfo_to_dict(conninfo) == exp + + +def test_no_munging(): + dsnin = "dbname=a host=b user=c password=d" + dsnout = make_conninfo(dsnin) + assert dsnin == dsnout + + +class TestConnectionInfo: + @pytest.mark.parametrize( + "attr", + [("dbname", "db"), "host", "hostaddr", "user", "password", "options"], + ) + def test_attrs(self, conn, attr): + if isinstance(attr, tuple): + info_attr, pgconn_attr = attr + else: + info_attr = pgconn_attr = attr + + if info_attr == "hostaddr" and psycopg.pq.version() < 120000: + pytest.skip("hostaddr not supported on libpq < 12") + + info_val = getattr(conn.info, info_attr) + pgconn_val = getattr(conn.pgconn, pgconn_attr).decode() + assert info_val == pgconn_val + + conn.close() + with pytest.raises(psycopg.OperationalError): + getattr(conn.info, info_attr) + + @pytest.mark.libpq("< 12") + def test_hostaddr_not_supported(self, conn): + with pytest.raises(psycopg.NotSupportedError): + conn.info.hostaddr + + def test_port(self, conn): + assert conn.info.port == int(conn.pgconn.port.decode()) + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.info.port + + def test_get_params(self, conn, dsn): + info = conn.info.get_parameters() + for k, v in conninfo_to_dict(dsn).items(): + if k != "password": + assert info.get(k) == v + else: + assert k not in info + + def test_dsn(self, conn, dsn): + dsn = conn.info.dsn + assert "password" not in dsn + for k, v in conninfo_to_dict(dsn).items(): + if k != "password": + assert f"{k}=" in dsn + + def test_get_params_env(self, conn_cls, dsn, monkeypatch): + dsn = conninfo_to_dict(dsn) + dsn.pop("application_name", None) + + monkeypatch.delenv("PGAPPNAME", raising=False) + with conn_cls.connect(**dsn) as conn: + assert "application_name" not in conn.info.get_parameters() + + monkeypatch.setenv("PGAPPNAME", "hello test") + with conn_cls.connect(**dsn) as conn: + assert conn.info.get_parameters()["application_name"] == "hello test" + + def test_dsn_env(self, conn_cls, dsn, monkeypatch): + dsn = conninfo_to_dict(dsn) + dsn.pop("application_name", None) + + monkeypatch.delenv("PGAPPNAME", raising=False) + with conn_cls.connect(**dsn) as conn: + assert "application_name=" not in conn.info.dsn + + monkeypatch.setenv("PGAPPNAME", "hello test") + with conn_cls.connect(**dsn) as conn: + assert "application_name='hello test'" in conn.info.dsn + + def test_status(self, conn): + assert conn.info.status.name == "OK" + conn.close() + assert conn.info.status.name == "BAD" + + def test_transaction_status(self, conn): + assert conn.info.transaction_status.name == "IDLE" + conn.close() + assert conn.info.transaction_status.name == "UNKNOWN" + + @pytest.mark.pipeline + def test_pipeline_status(self, conn): + assert not conn.info.pipeline_status + assert conn.info.pipeline_status.name == "OFF" + with conn.pipeline(): + assert conn.info.pipeline_status + assert conn.info.pipeline_status.name == "ON" + + @pytest.mark.libpq("< 14") + def test_pipeline_status_no_pipeline(self, conn): + assert not conn.info.pipeline_status + assert conn.info.pipeline_status.name == "OFF" + + def test_no_password(self, dsn): + dsn2 = make_conninfo(dsn, password="the-pass-word") + pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode()) + info = ConnectionInfo(pgconn) + assert info.password == "the-pass-word" + assert "password" not in 
info.get_parameters() + assert info.get_parameters()["dbname"] == info.dbname + + def test_dsn_no_password(self, dsn): + dsn2 = make_conninfo(dsn, password="the-pass-word") + pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode()) + info = ConnectionInfo(pgconn) + assert info.password == "the-pass-word" + assert "password" not in info.dsn + assert f"dbname={info.dbname}" in info.dsn + + def test_parameter_status(self, conn): + assert conn.info.parameter_status("nosuchparam") is None + tz = conn.info.parameter_status("TimeZone") + assert tz and isinstance(tz, str) + assert tz == conn.execute("show timezone").fetchone()[0] + + @pytest.mark.crdb("skip") + def test_server_version(self, conn): + assert conn.info.server_version == conn.pgconn.server_version + + def test_error_message(self, conn): + assert conn.info.error_message == "" + with pytest.raises(psycopg.ProgrammingError) as ex: + conn.execute("wat") + + assert conn.info.error_message + assert str(ex.value) in conn.info.error_message + assert ex.value.diag.severity in conn.info.error_message + + conn.close() + assert "NULL" in conn.info.error_message + + @pytest.mark.crdb_skip("backend pid") + def test_backend_pid(self, conn): + assert conn.info.backend_pid + assert conn.info.backend_pid == conn.pgconn.backend_pid + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.info.backend_pid + + def test_timezone(self, conn): + conn.execute("set timezone to 'Europe/Rome'") + tz = conn.info.timezone + assert isinstance(tz, dt.tzinfo) + offset = tz.utcoffset(dt.datetime(2000, 1, 1)) + assert offset and offset.total_seconds() == 3600 + offset = tz.utcoffset(dt.datetime(2000, 7, 1)) + assert offset and offset.total_seconds() == 7200 + + @pytest.mark.crdb("skip", reason="crdb doesn't allow invalid timezones") + def test_timezone_warn(self, conn, caplog): + conn.execute("set timezone to 'FOOBAR0'") + assert len(caplog.records) == 0 + tz = conn.info.timezone + assert tz == dt.timezone.utc + assert len(caplog.records) == 1 + assert "FOOBAR0" in caplog.records[0].message + + conn.info.timezone + assert len(caplog.records) == 1 + + conn.execute("set timezone to 'FOOBAAR0'") + assert len(caplog.records) == 1 + conn.info.timezone + assert len(caplog.records) == 2 + assert "FOOBAAR0" in caplog.records[1].message + + def test_encoding(self, conn): + enc = conn.execute("show client_encoding").fetchone()[0] + assert conn.info.encoding == pg2pyenc(enc.encode()) + + @pytest.mark.crdb("skip", reason="encoding not normalized") + @pytest.mark.parametrize( + "enc, out, codec", + [ + ("utf8", "UTF8", "utf-8"), + ("utf-8", "UTF8", "utf-8"), + ("utf_8", "UTF8", "utf-8"), + ("eucjp", "EUC_JP", "euc_jp"), + ("euc-jp", "EUC_JP", "euc_jp"), + ("latin9", "LATIN9", "iso8859-15"), + ], + ) + def test_normalize_encoding(self, conn, enc, out, codec): + conn.execute("select set_config('client_encoding', %s, false)", [enc]) + assert conn.info.parameter_status("client_encoding") == out + assert conn.info.encoding == codec + + @pytest.mark.parametrize( + "enc, out, codec", + [ + ("utf8", "UTF8", "utf-8"), + ("utf-8", "UTF8", "utf-8"), + ("utf_8", "UTF8", "utf-8"), + crdb_encoding("eucjp", "EUC_JP", "euc_jp"), + crdb_encoding("euc-jp", "EUC_JP", "euc_jp"), + ], + ) + def test_encoding_env_var(self, conn_cls, dsn, monkeypatch, enc, out, codec): + monkeypatch.setenv("PGCLIENTENCODING", enc) + with conn_cls.connect(dsn) as conn: + clienc = conn.info.parameter_status("client_encoding") + assert clienc + if conn.info.vendor == "PostgreSQL": + assert clienc == out + else: 
+ assert clienc.replace("-", "").replace("_", "").upper() == out + assert conn.info.encoding == codec + + @pytest.mark.crdb_skip("encoding") + def test_set_encoding_unsupported(self, conn): + cur = conn.cursor() + cur.execute("set client_encoding to EUC_TW") + with pytest.raises(psycopg.NotSupportedError): + cur.execute("select 'x'") + + def test_vendor(self, conn): + assert conn.info.vendor + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ("", "", None), + ("host='' user=bar", "host='' user=bar", None), + ( + "host=127.0.0.1 user=bar", + "host=127.0.0.1 user=bar hostaddr=127.0.0.1", + None, + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + "host=foo.com port=5432", + {"PGHOSTADDR": "1.2.3.4"}, + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_no_resolve( + setpgenv, conninfo, want, env, fail_resolve +): + setpgenv(env) + params = conninfo_to_dict(conninfo) + params = await resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "host=foo.com,qux.com port=5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433", + None, + ), + ( + "host=foo.com,qux.com port=5432,5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433", + None, + ), + ( + "host=foo.com,nosuchhost.com", + "host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com, port=5432,5433", + "host=foo.com, hostaddr=1.1.1.1, port=5432,5433", + None, + ), + ( + "host=nosuchhost.com,foo.com", + "host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + {}, + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve): + params = conninfo_to_dict(conninfo) + params = await resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, env", + [ + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.Error): + await resolve_hostaddr_async(params) + + +@pytest.fixture +async def fake_resolve(monkeypatch): + fake_hosts = { + "localhost": "127.0.0.1", + "foo.com": "1.1.1.1", + "qux.com": "2.2.2.2", + } + + async def fake_getaddrinfo(host, port, **kwargs): + assert isinstance(port, int) or (isinstance(port, str) and port.isdigit()) + try: + addr = fake_hosts[host] + except KeyError: + raise OSError(f"unknown test host: {host}") + else: + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))] + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo) + + +@pytest.fixture +async def fail_resolve(monkeypatch): + async def fail_getaddrinfo(host, port, **kwargs): + pytest.fail(f"shouldn't try to resolve {host}") + + monkeypatch.setattr(asyncio.get_running_loop(), 
"getaddrinfo", fail_getaddrinfo) diff --git a/tests/test_copy.py b/tests/test_copy.py new file mode 100644 index 0000000..17cf2fc --- /dev/null +++ b/tests/test_copy.py @@ -0,0 +1,889 @@ +import string +import struct +import hashlib +from io import BytesIO, StringIO +from random import choice, randrange +from itertools import cycle + +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg import errors as e +from psycopg.pq import Format +from psycopg.copy import Copy, LibpqWriter, QueuedLibpqDriver, FileWriter +from psycopg.adapt import PyFormat +from psycopg.types import TypeInfo +from psycopg.types.hstore import register_hstore +from psycopg.types.numeric import Int4 + +from .utils import eur, gc_collect, gc_count + +pytestmark = pytest.mark.crdb_skip("copy") + +sample_records = [(40010, 40020, "hello"), (40040, None, "world")] +sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')" +sample_tabledef = "col1 serial primary key, col2 int, data text" + +sample_text = b"""\ +40010\t40020\thello +40040\t\\N\tworld +""" + +sample_binary_str = """ +5047 434f 5059 0aff 0d0a 00 +00 0000 0000 0000 00 +00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f + +0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64 + +ff ff +""" + +sample_binary_rows = [ + bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n") +] +sample_binary = b"".join(sample_binary_rows) + +special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"} + + +@pytest.mark.parametrize("format", Format) +def test_copy_out_read(conn, format): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + for row in want: + got = copy.read() + assert got == row + assert conn.info.transaction_status == conn.TransactionStatus.ACTIVE + + assert copy.read() == b"" + assert copy.read() == b"" + + assert copy.read() == b"" + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_copy_out_iter(conn, format, row_factory): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + rf = getattr(psycopg.rows, row_factory) + cur = conn.cursor(row_factory=rf) + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + assert list(copy) == want + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_copy_out_no_result(conn, format, row_factory): + rf = getattr(psycopg.rows, row_factory) + cur = conn.cursor(row_factory=rf) + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"): + with pytest.raises(e.ProgrammingError): + cur.fetchone() + + +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +def test_copy_out_param(conn, ph, params): + cur = conn.cursor() + with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert list(copy.rows()) == [(i + 1,) for i in range(10)] + + assert conn.info.transaction_status == 
conn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("typetype", ["names", "oids"]) +def test_read_rows(conn, format, typetype): + cur = conn.cursor() + with cur.copy( + f"""copy ( + select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] + ) to stdout (format {format.name})""" + ) as copy: + copy.set_types(["int4", "text", "float8[]"]) + row = copy.read_row() + assert copy.read_row() is None + + assert row == (10, "hello", [0.0, 1.0]) + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +def test_rows(conn, format): + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + copy.set_types(["int4", "int4", "text"]) + rows = list(copy.rows()) + + assert rows == sample_records + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_set_custom_type(conn, hstore): + command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout""" + cur = conn.cursor() + + with cur.copy(command) as copy: + rows = list(copy.rows()) + + assert rows == [('"a"=>"1", "b"=>"2"',)] + + register_hstore(TypeInfo.fetch(conn, "hstore"), cur) + with cur.copy(command) as copy: + copy.set_types(["hstore"]) + rows = list(copy.rows()) + + assert rows == [({"a": "1", "b": "2"},)] + + +@pytest.mark.parametrize("format", Format) +def test_copy_out_allchars(conn, format): + cur = conn.cursor() + chars = list(map(chr, range(1, 256))) + [eur] + conn.execute("set client_encoding to utf8") + rows = [] + query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format( + chars, sql.SQL(format.name) + ) + with cur.copy(query) as copy: + copy.set_types(["text"]) + while True: + row = copy.read_row() + if not row: + break + assert len(row) == 1 + rows.append(row[0]) + + assert rows == chars + + +@pytest.mark.parametrize("format", Format) +def test_read_row_notypes(conn, format): + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + rows = [] + while True: + row = copy.read_row() + if not row: + break + rows.append(row) + + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("format", Format) +def test_rows_notypes(conn, format): + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + rows = list(copy.rows()) + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("err", [-1, 1]) +@pytest.mark.parametrize("format", Format) +def test_copy_out_badntypes(conn, format, err): + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + copy.set_types([0] * (len(sample_records[0]) + err)) + with pytest.raises(e.ProgrammingError): + copy.read_row() + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +def test_copy_in_buffers(conn, format, buffer): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_copy_in_buffers_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): 
+ with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + copy.write(sample_text) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_bad_result(conn): + conn.autocommit = True + + cur = conn.cursor() + + with pytest.raises(e.SyntaxError): + with cur.copy("wat"): + pass + + with pytest.raises(e.ProgrammingError): + with cur.copy("select 1"): + pass + + with pytest.raises(e.ProgrammingError): + with cur.copy("reset timezone"): + pass + + with pytest.raises(e.ProgrammingError): + with cur.copy("copy (select 1) to stdout; select 1") as copy: + list(copy) + + with pytest.raises(e.ProgrammingError): + with cur.copy("select 1; copy (select 1) to stdout"): + pass + + +def test_copy_in_str(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text.decode()) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_copy_in_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + with cur.copy("copy copy_in from stdin (format binary)") as copy: + copy.write(sample_text.decode()) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin (format {format.name})"): + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +def test_copy_big_size_record(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write_row([data]) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview]) +def test_copy_big_size_block(conn, pytype): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write(copy_data) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +@pytest.mark.parametrize("format", Format) +def test_subclass_adapter(conn, format): + if format == Format.TEXT: + from psycopg.types.string import StrDumper as BaseDumper + else: + from psycopg.types.string import ( # type: ignore[no-redef] + StrBinaryDumper as BaseDumper, + ) + + class MyStrDumper(BaseDumper): + def dump(self, obj): + return super().dump(obj) * 2 + + conn.adapters.register_dumper(str, MyStrDumper) + + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy: + copy.write_row(("hello",)) + + rec = cur.execute("select data from copy_in").fetchone() + assert rec[0] == "hellohello" + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_error_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy(f"copy copy_in from stdin (format {format.name})"): + raise 
Exception("mannaggiamiseria") + + assert "mannaggiamiseria" in str(exc.value) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_buffers_with_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + copy.write(sample_text) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_buffers_with_py_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_out_error_with_copy_finished(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy: + copy.read_row() + 1 / 0 + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_copy_out_error_with_copy_not_finished(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with cur.copy("copy (select generate_series(1, 1000000)) to stdout") as copy: + copy.read_row() + 1 / 0 + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_out_server_error(conn): + cur = conn.cursor() + with pytest.raises(e.DivisionByZero): + with cur.copy( + "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout" + ) as copy: + for block in copy: + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_set_types(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_binary(conn, format): + cur = conn.cursor() + ensure_table(cur, "col1 serial primary key, col2 int, data text") + + with cur.copy( + f"copy copy_in (col2, data) from stdin (format {format.name})" + ) as copy: + for row in sample_records: + copy.write_row((None, row[2])) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == [(1, None, "hello"), (2, None, "world")] + + +def test_copy_in_allchars(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + conn.execute("set client_encoding to utf8") + with cur.copy("copy copy_in from stdin (format text)") as copy: + for i in range(1, 256): + copy.write_row((i, None, chr(i))) + copy.write_row((ord(eur), None, eur)) + + data = cur.execute( + """ 
+select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ).fetchall() + assert data == [(True, True, 1, 256)] + + +def test_copy_in_format(conn): + file = BytesIO() + conn.execute("set client_encoding to utf8") + cur = conn.cursor() + with Copy(cur, writer=FileWriter(file)) as copy: + for i in range(1, 256): + copy.write_row((i, chr(i))) + + file.seek(0) + rows = file.read().split(b"\n") + assert not rows[-1] + del rows[-1] + + for i, row in enumerate(rows, start=1): + fields = row.split(b"\t") + assert len(fields) == 2 + assert int(fields[0].decode()) == i + if i in special_chars: + assert fields[1].decode() == f"\\{special_chars[i]}" + else: + assert fields[1].decode() == chr(i) + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +def test_file_writer(conn, format, buffer): + file = BytesIO() + conn.execute("set client_encoding to utf8") + cur = conn.cursor() + with Copy(cur, binary=format, writer=FileWriter(file)) as copy: + for record in sample_records: + copy.write_row(record) + + file.seek(0) + want = globals()[buffer] + got = file.read() + assert got == want + + +@pytest.mark.slow +def test_copy_from_to(conn): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024) + gen.ensure_table() + cur = conn.cursor() + with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + copy.write(block) + + gen.assert_data() + + f = BytesIO() + with cur.copy("copy copy_in to stdout") as copy: + for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview]) +def test_copy_from_to_bytes(conn, pytype): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024) + gen.ensure_table() + cur = conn.cursor() + with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + copy.write(pytype(block.encode())) + + gen.assert_data() + + f = BytesIO() + with cur.copy("copy copy_in to stdout") as copy: + for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +def test_copy_from_insane_size(conn): + # Trying to trigger a "would block" error + gen = DataGenerator( + conn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024 + ) + gen.ensure_table() + cur = conn.cursor() + with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + copy.write(block) + + gen.assert_data() + + +def test_copy_rowcount(conn): + gen = DataGenerator(conn, nrecs=3, srec=10) + gen.ensure_table() + + cur = conn.cursor() + with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + copy.write(block) + assert cur.rowcount == 3 + + gen = DataGenerator(conn, nrecs=2, srec=10, offset=3) + with cur.copy("copy copy_in from stdin") as copy: + for rec in gen.records(): + copy.write_row(rec) + assert cur.rowcount == 2 + + with cur.copy("copy copy_in to stdout") as copy: + for block in copy: + pass + assert cur.rowcount == 5 + + with pytest.raises(e.BadCopyFileFormat): + with cur.copy("copy copy_in (id) from stdin") as copy: + for rec in gen.records(): + copy.write_row(rec) + assert cur.rowcount == -1 + + +def test_copy_query(conn): + cur = conn.cursor() + with cur.copy("copy (select 1) to stdout") as copy: + assert cur._query.query == b"copy (select 1) to stdout" + assert not 
cur._query.params + list(copy) + + +def test_cant_reenter(conn): + cur = conn.cursor() + with cur.copy("copy (select 1) to stdout") as copy: + list(copy) + + with pytest.raises(TypeError): + with copy: + list(copy) + + +def test_str(conn): + cur = conn.cursor() + with cur.copy("copy (select 1) to stdout") as copy: + assert "[ACTIVE]" in str(copy) + list(copy) + + assert "[INTRANS]" in str(copy) + + +def test_description(conn): + with conn.cursor() as cur: + with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy: + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + list(copy.rows()) + + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +def test_worker_life(conn, format, buffer): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=QueuedLibpqWriter(cur) + ) as copy: + assert not copy.writer._worker + copy.write(globals()[buffer]) + assert copy.writer._worker + + assert not copy.writer._worker + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_worker_error_propagated(conn, monkeypatch): + def copy_to_broken(pgconn, buffer): + raise ZeroDivisionError + yield + + monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken) + cur = conn.cursor() + cur.execute("create temp table wat (a text, b text)") + with pytest.raises(ZeroDivisionError): + with cur.copy("copy wat from stdin", writer=QueuedLibpqWriter(cur)) as copy: + copy.write("a,b") + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +def test_connection_writer(conn, format, buffer): + cur = conn.cursor() + writer = LibpqWriter(cur) + + ensure_table(cur, sample_tabledef) + with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=writer + ) as copy: + assert copy.writer is writer + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) +def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + def work(): + with conn_cls.connect(dsn) as conn: + with conn.cursor(binary=fmt) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + stmt = sql.SQL( + "copy (select {} from {} order by id) to stdout (format {})" + ).format( + sql.SQL(", ").join(faker.fields_names), + faker.table_name, + sql.SQL(fmt.name), + ) + + with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + + if method == "read": + while True: + tmp = copy.read() + if not tmp: + break + elif method == "iter": + list(copy) + elif method == "row": + while True: + tmp = copy.read_row() + if tmp is None: + break + elif method == "rows": + list(copy.rows()) + + gc_collect() + n = [] + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + + assert n[0]
== n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + def work(): + with conn_cls.connect(dsn) as conn: + with conn.cursor(binary=fmt) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin (format {})").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL(fmt.name), + ) + with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + copy.write_row(row) + + cur.execute(faker.select_stmt) + recs = cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["row", "block", "binary"]) +def test_copy_table_across(conn_cls, dsn, faker, mode): + faker.choose_schema(ncols=20) + faker.make_records(20) + + with conn_cls.connect(dsn) as conn1, conn_cls.connect(dsn) as conn2: + faker.table_name = sql.Identifier("copy_src") + conn1.execute(faker.drop_stmt) + conn1.execute(faker.create_stmt) + conn1.cursor().executemany(faker.insert_stmt, faker.records) + + faker.table_name = sql.Identifier("copy_tgt") + conn2.execute(faker.drop_stmt) + conn2.execute(faker.create_stmt) + + fmt = "(format binary)" if mode == "binary" else "" + with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1: + with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2: + if mode == "row": + for row in copy1.rows(): + copy2.write_row(row) + else: + for data in copy1: + copy2.write(data) + + recs = conn2.execute(faker.select_stmt).fetchall() + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + +def py_to_raw(item, fmt): + """Convert from Python type to the expected result from the db""" + if fmt == Format.TEXT: + if isinstance(item, int): + return str(item) + else: + if isinstance(item, int): + # Assume int4 + return struct.pack("!i", item) + elif isinstance(item, str): + return item.encode() + return item + + +def ensure_table(cur, tabledef, name="copy_in"): + cur.execute(f"drop table if exists {name}") + cur.execute(f"create table {name} ({tabledef})") + + +class DataGenerator: + def __init__(self, conn, nrecs, srec, offset=0, block_size=8192): + self.conn = conn + self.nrecs = nrecs + self.srec = srec + self.offset = offset + self.block_size = block_size + + def ensure_table(self): + cur = self.conn.cursor() + ensure_table(cur, "id integer primary key, data text") + + def records(self): + for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)): + s = c * self.srec + yield (i + self.offset, s) + + def file(self): + f = StringIO() + for i, s in self.records(): + f.write("%s\t%s\n" % (i, s)) + + f.seek(0) + return f + + def blocks(self): + f = self.file() + while True: + block = f.read(self.block_size) + if not block: + break + yield block + + def assert_data(self): + cur = self.conn.cursor() + cur.execute("select id, data from copy_in order by id") + for record in self.records(): + assert record == cur.fetchone() + + assert cur.fetchone() is None + + def 
sha(self, f): + m = hashlib.sha256() + while True: + block = f.read() + if not block: + break + if isinstance(block, str): + block = block.encode() + m.update(block) + return m.hexdigest() diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py new file mode 100644 index 0000000..59389dd --- /dev/null +++ b/tests/test_copy_async.py @@ -0,0 +1,892 @@ +import string +import hashlib +from io import BytesIO, StringIO +from random import choice, randrange +from itertools import cycle + +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg import errors as e +from psycopg.pq import Format +from psycopg.copy import AsyncCopy +from psycopg.copy import AsyncWriter, AsyncLibpqWriter, AsyncQueuedLibpqWriter +from psycopg.types import TypeInfo +from psycopg.adapt import PyFormat +from psycopg.types.hstore import register_hstore +from psycopg.types.numeric import Int4 + +from .utils import alist, eur, gc_collect, gc_count +from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa +from .test_copy import sample_values, sample_records, sample_tabledef +from .test_copy import py_to_raw, special_chars + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb_skip("copy"), +] + + +@pytest.mark.parametrize("format", Format) +async def test_copy_out_read(aconn, format): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + for row in want: + got = await copy.read() + assert got == row + assert aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE + + assert await copy.read() == b"" + assert await copy.read() == b"" + + assert await copy.read() == b"" + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_copy_out_iter(aconn, format, row_factory): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + rf = getattr(psycopg.rows, row_factory) + cur = aconn.cursor(row_factory=rf) + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + assert await alist(copy) == want + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_copy_out_no_result(aconn, format, row_factory): + rf = getattr(psycopg.rows, row_factory) + cur = aconn.cursor(row_factory=rf) + async with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"): + with pytest.raises(e.ProgrammingError): + await cur.fetchone() + + +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +async def test_copy_out_param(aconn, ph, params): + cur = aconn.cursor() + async with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert await alist(copy.rows()) == [(i + 1,) for i in range(10)] + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("typetype", ["names", "oids"]) +async def test_read_rows(aconn, format, typetype): + cur = aconn.cursor() 
+ async with cur.copy( + f"""copy ( + select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] + ) to stdout (format {format.name})""" + ) as copy: + copy.set_types(["int4", "text", "float8[]"]) + row = await copy.read_row() + assert (await copy.read_row()) is None + + assert row == (10, "hello", [0.0, 1.0]) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +async def test_rows(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + copy.set_types("int4 int4 text".split()) + rows = await alist(copy.rows()) + + assert rows == sample_records + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_set_custom_type(aconn, hstore): + command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout""" + cur = aconn.cursor() + + async with cur.copy(command) as copy: + rows = await alist(copy.rows()) + + assert rows == [('"a"=>"1", "b"=>"2"',)] + + register_hstore(await TypeInfo.fetch(aconn, "hstore"), cur) + async with cur.copy(command) as copy: + copy.set_types(["hstore"]) + rows = await alist(copy.rows()) + + assert rows == [({"a": "1", "b": "2"},)] + + +@pytest.mark.parametrize("format", Format) +async def test_copy_out_allchars(aconn, format): + cur = aconn.cursor() + chars = list(map(chr, range(1, 256))) + [eur] + await aconn.execute("set client_encoding to utf8") + rows = [] + query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format( + chars, sql.SQL(format.name) + ) + async with cur.copy(query) as copy: + copy.set_types(["text"]) + while True: + row = await copy.read_row() + if not row: + break + assert len(row) == 1 + rows.append(row[0]) + + assert rows == chars + + +@pytest.mark.parametrize("format", Format) +async def test_read_row_notypes(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + rows = [] + while True: + row = await copy.read_row() + if not row: + break + rows.append(row) + + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("format", Format) +async def test_rows_notypes(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + rows = await alist(copy.rows()) + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("err", [-1, 1]) +@pytest.mark.parametrize("format", Format) +async def test_copy_out_badntypes(aconn, format, err): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + copy.set_types([0] * (len(sample_records[0]) + err)) + with pytest.raises(e.ProgrammingError): + await copy.read_row() + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_copy_in_buffers(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_buffers_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, 
sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_bad_result(aconn): + await aconn.set_autocommit(True) + + cur = aconn.cursor() + + with pytest.raises(e.SyntaxError): + async with cur.copy("wat"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("select 1"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("reset timezone"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("copy (select 1) to stdout; select 1") as copy: + await alist(copy) + + with pytest.raises(e.ProgrammingError): + async with cur.copy("select 1; copy (select 1) to stdout"): + pass + + +async def test_copy_in_str(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text.decode()) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + async with cur.copy("copy copy_in from stdin (format binary)") as copy: + await copy.write(sample_text.decode()) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin (format {format.name})"): + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +async def test_copy_big_size_record(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write_row([data]) + + await cur.execute("select data from copy_in limit 1") + assert await cur.fetchone() == (data,) + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview]) +async def test_copy_big_size_block(aconn, pytype): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write(copy_data) + + await cur.execute("select data from copy_in limit 1") + assert await cur.fetchone() == (data,) + + +@pytest.mark.parametrize("format", Format) +async def test_subclass_adapter(aconn, format): + if format == Format.TEXT: + from psycopg.types.string import StrDumper as BaseDumper + else: + from psycopg.types.string import ( # type: ignore[no-redef] + StrBinaryDumper as BaseDumper, + ) + + class MyStrDumper(BaseDumper): + def dump(self, obj): + return super().dump(obj) * 2 + + aconn.adapters.register_dumper(str, MyStrDumper) + + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy( + f"copy copy_in (data) from stdin (format {format.name})" + ) as copy: + await copy.write_row(("hello",)) + + await cur.execute("select data from copy_in") + rec = 
await cur.fetchone() + assert rec[0] == "hellohello" + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_error_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy(f"copy copy_in from stdin (format {format.name})"): + raise Exception("mannaggiamiseria") + + assert "mannaggiamiseria" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_buffers_with_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_buffers_with_py_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_out_error_with_copy_finished(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + async with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy: + await copy.read_row() + 1 / 0 + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_copy_out_error_with_copy_not_finished(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + async with cur.copy( + "copy (select generate_series(1, 1000000)) to stdout" + ) as copy: + await copy.read_row() + 1 / 0 + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_out_server_error(aconn): + cur = aconn.cursor() + with pytest.raises(e.DivisionByZero): + async with cur.copy( + "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout" + ) as copy: + async for block in copy: + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_set_types(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_binary(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, "col1 serial primary key, col2 int, data text") + + async with cur.copy( + f"copy copy_in (col2, 
data) from stdin (format {format.name})" + ) as copy: + for row in sample_records: + await copy.write_row((None, row[2])) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == [(1, None, "hello"), (2, None, "world")] + + +async def test_copy_in_allchars(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + await aconn.execute("set client_encoding to utf8") + async with cur.copy("copy copy_in from stdin (format text)") as copy: + for i in range(1, 256): + await copy.write_row((i, None, chr(i))) + await copy.write_row((ord(eur), None, eur)) + + await cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ) + data = await cur.fetchall() + assert data == [(True, True, 1, 256)] + + +async def test_copy_in_format(aconn): + file = BytesIO() + await aconn.execute("set client_encoding to utf8") + cur = aconn.cursor() + async with AsyncCopy(cur, writer=AsyncFileWriter(file)) as copy: + for i in range(1, 256): + await copy.write_row((i, chr(i))) + + file.seek(0) + rows = file.read().split(b"\n") + assert not rows[-1] + del rows[-1] + + for i, row in enumerate(rows, start=1): + fields = row.split(b"\t") + assert len(fields) == 2 + assert int(fields[0].decode()) == i + if i in special_chars: + assert fields[1].decode() == f"\\{special_chars[i]}" + else: + assert fields[1].decode() == chr(i) + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_file_writer(aconn, format, buffer): + file = BytesIO() + await aconn.execute("set client_encoding to utf8") + cur = aconn.cursor() + async with AsyncCopy(cur, binary=format, writer=AsyncFileWriter(file)) as copy: + for record in sample_records: + await copy.write_row(record) + + file.seek(0) + want = globals()[buffer] + got = file.read() + assert got == want + + +@pytest.mark.slow +async def test_copy_from_to(aconn): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + + await gen.assert_data() + + f = BytesIO() + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview]) +async def test_copy_from_to_bytes(aconn, pytype): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(pytype(block.encode())) + + await gen.assert_data() + + f = BytesIO() + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +async def test_copy_from_insane_size(aconn): + # Trying to trigger a "would block" error + gen = DataGenerator( + aconn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024 + ) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + + await gen.assert_data() + + +async def test_copy_rowcount(aconn): 
+ gen = DataGenerator(aconn, nrecs=3, srec=10) + await gen.ensure_table() + + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + assert cur.rowcount == 3 + + gen = DataGenerator(aconn, nrecs=2, srec=10, offset=3) + async with cur.copy("copy copy_in from stdin") as copy: + for rec in gen.records(): + await copy.write_row(rec) + assert cur.rowcount == 2 + + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + pass + assert cur.rowcount == 5 + + with pytest.raises(e.BadCopyFileFormat): + async with cur.copy("copy copy_in (id) from stdin") as copy: + for rec in gen.records(): + await copy.write_row(rec) + assert cur.rowcount == -1 + + +async def test_copy_query(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + assert cur._query.query == b"copy (select 1) to stdout" + assert not cur._query.params + await alist(copy) + + +async def test_cant_reenter(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + await alist(copy) + + with pytest.raises(TypeError): + async with copy: + await alist(copy) + + +async def test_str(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + assert "[ACTIVE]" in str(copy) + await alist(copy) + + assert "[INTRANS]" in str(copy) + + +async def test_description(aconn): + async with aconn.cursor() as cur: + async with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy: + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + await alist(copy.rows()) + + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_worker_life(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy( + f"copy copy_in from stdin (format {format.name})", + writer=AsyncQueuedLibpqWriter(cur), + ) as copy: + assert not copy.writer._worker + await copy.write(globals()[buffer]) + assert copy.writer._worker + + assert not copy.writer._worker + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_worker_error_propagated(aconn, monkeypatch): + def copy_to_broken(pgconn, buffer): + raise ZeroDivisionError + yield + + monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken) + cur = aconn.cursor() + await cur.execute("create temp table wat (a text, b text)") + with pytest.raises(ZeroDivisionError): + async with cur.copy( + "copy wat from stdin", writer=AsyncQueuedLibpqWriter(cur) + ) as copy: + await copy.write("a,b") + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_connection_writer(aconn, format, buffer): + cur = aconn.cursor() + writer = AsyncLibpqWriter(cur) + + await ensure_table(cur, sample_tabledef) + async with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=writer + ) as copy: + assert copy.writer is writer + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + 
[(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) +async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + async with faker.find_insert_problem_async(conn): + await cur.executemany(faker.insert_stmt, faker.records) + + stmt = sql.SQL( + "copy (select {} from {} order by id) to stdout (format {})" + ).format( + sql.SQL(", ").join(faker.fields_names), + faker.table_name, + sql.SQL(fmt.name), + ) + + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + + if method == "read": + while True: + tmp = await copy.read() + if not tmp: + break + elif method == "iter": + await alist(copy) + elif method == "row": + while True: + tmp = await copy.read_row() + if tmp is None: + break + elif method == "rows": + await alist(copy.rows()) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin (format {})").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL(fmt.name), + ) + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + await copy.write_row(row) + + await cur.execute(faker.select_stmt) + recs = await cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["row", "block", "binary"]) +async def test_copy_table_across(aconn_cls, dsn, faker, mode): + faker.choose_schema(ncols=20) + faker.make_records(20) + + connect = aconn_cls.connect + async with await connect(dsn) as conn1, await connect(dsn) as conn2: + faker.table_name = sql.Identifier("copy_src") + await conn1.execute(faker.drop_stmt) + await conn1.execute(faker.create_stmt) + await conn1.cursor().executemany(faker.insert_stmt, faker.records) + + faker.table_name = sql.Identifier("copy_tgt") + await conn2.execute(faker.drop_stmt) + await conn2.execute(faker.create_stmt) + + fmt = "(format binary)" if mode == "binary" else "" + async with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1: + async with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2: + if mode == "row": + async for row in copy1.rows(): + await copy2.write_row(row) + else: + async for data in copy1: + await copy2.write(data) + + cur = await conn2.execute(faker.select_stmt) + recs = await cur.fetchall() 
+ for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + +async def ensure_table(cur, tabledef, name="copy_in"): + await cur.execute(f"drop table if exists {name}") + await cur.execute(f"create table {name} ({tabledef})") + + +class DataGenerator: + def __init__(self, conn, nrecs, srec, offset=0, block_size=8192): + self.conn = conn + self.nrecs = nrecs + self.srec = srec + self.offset = offset + self.block_size = block_size + + async def ensure_table(self): + cur = self.conn.cursor() + await ensure_table(cur, "id integer primary key, data text") + + def records(self): + for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)): + s = c * self.srec + yield (i + self.offset, s) + + def file(self): + f = StringIO() + for i, s in self.records(): + f.write("%s\t%s\n" % (i, s)) + + f.seek(0) + return f + + def blocks(self): + f = self.file() + while True: + block = f.read(self.block_size) + if not block: + break + yield block + + async def assert_data(self): + cur = self.conn.cursor() + await cur.execute("select id, data from copy_in order by id") + for record in self.records(): + assert record == await cur.fetchone() + + assert await cur.fetchone() is None + + def sha(self, f): + m = hashlib.sha256() + while True: + block = f.read() + if not block: + break + if isinstance(block, str): + block = block.encode() + m.update(block) + return m.hexdigest() + + +class AsyncFileWriter(AsyncWriter): + def __init__(self, file): + self.file = file + + async def write(self, data): + self.file.write(data) diff --git a/tests/test_cursor.py b/tests/test_cursor.py new file mode 100644 index 0000000..a667f4f --- /dev/null +++ b/tests/test_cursor.py @@ -0,0 +1,942 @@ +import pickle +import weakref +import datetime as dt +from typing import List, Union +from contextlib import closing + +import pytest + +import psycopg +from psycopg import pq, sql, rows +from psycopg.adapt import PyFormat +from psycopg.postgres import types as builtins +from psycopg.rows import RowMaker + +from .utils import gc_collect, gc_count +from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision + + +def test_init(conn): + cur = psycopg.Cursor(conn) + cur.execute("select 1") + assert cur.fetchone() == (1,) + + conn.row_factory = rows.dict_row + cur = psycopg.Cursor(conn) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_init_factory(conn): + cur = psycopg.Cursor(conn, row_factory=rows.dict_row) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_close(conn): + cur = conn.cursor() + assert not cur.closed + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.execute("select 'foo'") + + cur.close() + assert cur.closed + + +def test_cursor_close_fetchone(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + for _ in range(5): + cur.fetchone() + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchone() + + +def test_cursor_close_fetchmany(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchmany(2)) == 2 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchmany(2) + + +def test_cursor_close_fetchall(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchall()) == 10 + + 
cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchall() + + +def test_context(conn): + with conn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +def test_weakref(conn): + cur = conn.cursor() + w = weakref.ref(cur) + cur.close() + del cur + gc_collect() + assert w() is None + + +def test_pgresult(conn): + cur = conn.cursor() + cur.execute("select 1") + assert cur.pgresult + cur.close() + assert not cur.pgresult + + +def test_statusmessage(conn): + cur = conn.cursor() + assert cur.statusmessage is None + + cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + cur.execute("wat") + assert cur.statusmessage is None + + +def test_execute_many_results(conn): + cur = conn.cursor() + assert cur.nextset() is None + + rv = cur.execute("select 'foo'; select generate_series(1,3)") + assert rv is cur + assert cur.fetchall() == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.nextset() is None + + cur.close() + assert cur.nextset() is None + + +def test_execute_sequence(conn): + cur = conn.cursor() + rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +def test_execute_empty_query(conn, query): + cur = conn.cursor() + cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + +def test_execute_type_change(conn): + # issue #112 + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.execute(sql, (1,)) + cur.execute(sql, (100_000,)) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +def test_executemany_type_change(conn): + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.executemany(sql, [(1,), (100_000,)]) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +def test_execute_copy(conn, query): + cur = conn.cursor() + cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + cur.execute(query) + + +def test_fetchone(conn): + cur = conn.cursor() + cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = cur.fetchone() + assert row == (1, "foo", None) + row = cur.fetchone() + assert row is None + + +def test_binary_cursor_execute(conn): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None]) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x01" + + +def test_execute_binary(conn): + cur = conn.cursor() + cur.execute("select %s, %s", [1, None], binary=True) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == 
b"\x00\x01" + + +def test_binary_cursor_text_override(conn): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None], binary=False) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +def test_query_encode(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("select '\u20ac'").fetchone() + assert res == "\u20ac" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +def test_query_badenc(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + with pytest.raises(UnicodeEncodeError): + cur.execute("select '\u20ac'") + + +@pytest.fixture(scope="session") +def _execmany(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop table if exists execmany; + create table execmany (id serial primary key, num integer, data text) + """ + ) + + +@pytest.fixture(scope="function") +def execmany(svcconn, _execmany): + cur = svcconn.cursor() + cur.execute("truncate table execmany") + + +def test_executemany(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(10, "hello"), (20, "world")] + + +def test_executemany_name(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(11, "hello"), (21, "world")] + + +def test_executemany_no_data(conn, execmany): + cur = conn.cursor() + cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +def test_executemany_rowcount(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +def test_executemany_returning(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.fetchone() == (10,) + assert cur.nextset() + assert cur.fetchone() == (20,) + assert cur.nextset() is None + + +def test_executemany_returning_discard(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + assert cur.nextset() is None + + +def test_executemany_no_result(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +def test_executemany_rowcount_no_hit(conn, execmany): + cur = conn.cursor() + cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert 
cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)]) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +def test_executemany_badquery(conn, query): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_executemany_null_first(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table testmany (a bigint, b bigint)") + cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, ""], [3, 4]], + ) + + +def test_rowcount(conn): + cur = conn.cursor() + + cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + cur.execute("show timezone") + assert cur.rowcount == 1 + + cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)") + assert cur.rowcount == 42 + + +def test_rownumber(conn): + cur = conn.cursor() + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +@pytest.mark.parametrize("query", ["", "set timezone to utc"]) +def test_rownumber_none(conn, query): + cur = conn.cursor() + cur.execute(query) + assert cur.rownumber is None + + +def test_rownumber_mixed(conn): + cur = conn.cursor() + cur.execute( + """ +select x from generate_series(1, 3) x; +set timezone to utc; +select x from generate_series(4, 6) x; +""" + ) + assert cur.rownumber == 0 + assert cur.fetchone() == (1,) + assert cur.rownumber == 1 + assert cur.fetchone() == (2,) + assert cur.rownumber == 2 + cur.nextset() + assert cur.rownumber is None + cur.nextset() + assert cur.rownumber == 0 + assert cur.fetchone() == (4,) + assert cur.rownumber == 1 + + +def test_iter(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + assert list(cur) == [(1,), (2,), (3,)] + + +def test_iter_stop(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + for rec in cur: + assert rec == (1,) + break + + for rec in cur: + assert rec == (2,) + break + + assert cur.fetchone() == (3,) + assert list(cur) == [] + + +def test_row_factory(conn): + cur = conn.cursor(row_factory=my_row_factory) + + cur.execute("reset search_path") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + cur.execute("select 'foo' as bar") + (r,) = cur.fetchone() + assert r == "FOObar" + + cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert cur.fetchall() == [["Yy", "Zz"]] + + cur.scroll(-1) + 
cur.row_factory = rows.dict_row + assert cur.fetchone() == {"y": "y", "z": "z"} + + +def test_row_factory_none(conn): + cur = conn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + r = cur.execute("select 1 as a, 2 as b").fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +def test_bad_row_factory(conn): + def broken_factory(cur): + 1 / 0 + + cur = conn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = conn.cursor(row_factory=broken_maker) + cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + cur.fetchone() + + +def test_scroll(conn): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + cur.scroll(0) + + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(-1) + assert cur.fetchone() == (8,) + cur.scroll(-2) + assert cur.fetchone() == (7,) + cur.scroll(2, mode="absolute") + assert cur.fetchone() == (2,) + + # on the boundary + cur.scroll(0, mode="absolute") + assert cur.fetchone() == (0,) + with pytest.raises(IndexError): + cur.scroll(-1, mode="absolute") + + cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(-1) + + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + with pytest.raises(IndexError): + cur.scroll(10, mode="absolute") + + cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(1) + + with pytest.raises(ValueError): + cur.scroll(1, "wat") + + +def test_query_params_execute(conn): + cur = conn.cursor() + assert cur._query is None + + cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select $1, $2::text" + assert cur._query.params == [b"1", None] + + cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select $1::int" + assert cur._query.params == [b"wat"] + + +def test_query_params_executemany(conn): + cur = conn.cursor() + + cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select $1, $2" + assert cur._query.params == [b"3", b"4"] + + +def test_stream(conn): + cur = conn.cursor() + recs = [] + for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +def test_stream_sql(conn): + cur = conn.cursor() + recs = list( + cur.stream( + sql.SQL( + "select i, '2021-01-01'::date + i from generate_series(1, {}) as i" + ).format(2) + ) + ) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +def test_stream_row_factory(conn): + cur = conn.cursor(row_factory=rows.dict_row) + it = iter(cur.stream("select generate_series(1,2) as a")) + assert next(it)["a"] == 1 + cur.row_factory = rows.namedtuple_row + assert next(it).a == 2 + + +def test_stream_no_row(conn): + cur = conn.cursor() + recs = list(cur.stream("select generate_series(2,1) as a")) + assert recs == [] + + +@pytest.mark.crdb_skip("no col query") +def test_stream_no_col(conn): + cur = conn.cursor() + recs = list(cur.stream("select")) + assert recs == [()] + + +@pytest.mark.parametrize( + "query", + [ 
+ "create table test_stream_badq ()", + "copy (select 1) to stdout", + "wat?", + ], +) +def test_stream_badquery(conn, query): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + for rec in cur.stream(query): + pass + + +def test_stream_error_tx(conn): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + for rec in cur.stream("wat"): + pass + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_stream_error_notx(conn): + conn.autocommit = True + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + for rec in cur.stream("wat"): + pass + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + + +def test_stream_error_python_to_consume(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with closing(cur.stream("select generate_series(1, 10000)")) as gen: + for rec in gen: + 1 / 0 + assert conn.info.transaction_status in ( + conn.TransactionStatus.INTRANS, + conn.TransactionStatus.INERROR, + ) + + +def test_stream_error_python_consumed(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + gen = cur.stream("select 1") + for rec in gen: + 1 / 0 + gen.close() + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_stream_close(conn): + cur = conn.cursor() + with pytest.raises(psycopg.OperationalError): + for rec in cur.stream("select generate_series(1, 3)"): + if rec[0] == 1: + conn.close() + else: + assert False + + assert conn.closed + + +def test_stream_binary_cursor(conn): + cur = conn.cursor(binary=True) + recs = [] + for rec in cur.stream("select x::int4 from generate_series(1, 2) x"): + recs.append(rec) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]]) + + assert recs == [(1,), (2,)] + + +def test_stream_execute_binary(conn): + cur = conn.cursor() + recs = [] + for rec in cur.stream("select x::int4 from generate_series(1, 2) x", binary=True): + recs.append(rec) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]]) + + assert recs == [(1,), (2,)] + + +def test_stream_binary_cursor_text_override(conn): + cur = conn.cursor(binary=True) + recs = [] + for rec in cur.stream("select generate_series(1, 2)", binary=False): + recs.append(rec) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == str(rec[0]).encode() + + assert recs == [(1,), (2,)] + + +class TestColumn: + def test_description_attribs(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + assert len(curs.description) == 3 + for c in curs.description: + len(c) == 7 # DBAPI happy + for i, a in enumerate( + """ + name type_code display_size internal_size precision scale null_ok + """.split() + ): + assert c[i] == getattr(c, a) + + # Won't fill them up + assert c.null_ok is None + + c = curs.description[0] + assert c.name == "pi" + assert c.type_code == builtins["numeric"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision == 10 + assert c.scale == 2 + + c = curs.description[1] + assert c.name == "hi" + assert c.type_code == builtins["text"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision is None + assert c.scale is None + + c = curs.description[2] + assert c.name == "now" + assert c.type_code == builtins["date"].oid + assert c.display_size is None + if is_crdb(conn): 
+ assert c.internal_size == 16 + else: + assert c.internal_size == 4 + assert c.precision is None + assert c.scale is None + + def test_description_slice(self, conn): + curs = conn.cursor() + curs.execute("select 1::int as a") + curs.description[0][0:2] == ("a", 23) + + @pytest.mark.parametrize( + "type, precision, scale, dsize, isize", + [ + ("text", None, None, None, None), + ("varchar", None, None, None, None), + ("varchar(42)", None, None, 42, None), + ("int4", None, None, None, 4), + ("numeric", None, None, None, None), + ("numeric(10)", 10, 0, None, None), + ("numeric(10, 3)", 10, 3, None, None), + ("time", None, None, None, 8), + crdb_time_precision("time(4)", 4, None, None, 8), + crdb_time_precision("time(10)", 6, None, None, 8), + ], + ) + def test_details(self, conn, type, precision, scale, dsize, isize): + cur = conn.cursor() + cur.execute(f"select null::{type}") + col = cur.description[0] + repr(col) + assert col.precision == precision + assert col.scale == scale + assert col.display_size == dsize + assert col.internal_size == isize + + def test_pickle(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + description = curs.description + pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) + unpickled = pickle.loads(pickled) + assert [tuple(d) for d in description] == [tuple(d) for d in unpickled] + + @pytest.mark.crdb_skip("no col query") + def test_no_col_query(self, conn): + cur = conn.execute("select") + assert cur.description == [] + assert cur.fetchall() == [()] + + def test_description_closed_connection(self, conn): + # If we have reasons to break this test we will (e.g. we really need + # the connection). In #172 it fails just by accident. 
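+        # (The description is built from the query result already held by the
+        # cursor, not read from the live connection, so it is still available
+        # after close(); the assertions below rely on that.)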
+ cur = conn.execute("select 1::int4 as foo") + conn.close() + assert len(cur.description) == 1 + col = cur.description[0] + assert col.name == "foo" + assert col.type_code == 23 + + def test_name_not_a_name(self, conn): + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone() + assert res == "x" + assert cur.description[0].name == "foo-bar" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_name_encode(self, conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone() + assert res == "x" + assert cur.description[0].name == "\u20ac" + + +def test_str(conn): + cur = conn.cursor() + assert "psycopg.Cursor" in str(cur) + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory): + faker.format = fmt + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + def work(): + with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True): + with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + cur.execute(faker.select_stmt) + + if fetch == "one": + while True: + tmp = cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while True: + tmp = cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + cur.fetchall() + elif fetch == "iter": + for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +def my_row_factory( + cursor: Union[psycopg.Cursor[List[str]], psycopg.AsyncCursor[List[str]]] +) -> RowMaker[List[str]]: + if cursor.description is not None: + titles = [c.name for c in cursor.description] + + def mkrow(values): + return [f"{value.upper()}{title}" for title, value in zip(titles, values)] + + return mkrow + else: + return rows.no_result diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py new file mode 100644 index 0000000..ac3fdeb --- /dev/null +++ b/tests/test_cursor_async.py @@ -0,0 +1,802 @@ +import pytest +import weakref +import datetime as dt +from typing import List + +import psycopg +from psycopg import pq, sql, rows +from psycopg.adapt import PyFormat + +from .utils import gc_collect, gc_count +from .test_cursor import my_row_factory +from .test_cursor import execmany, _execmany # noqa: F401 +from .fix_crdb import crdb_encoding + +execmany = execmany # avoid F811 underneath +pytestmark = pytest.mark.asyncio + + +async def test_init(aconn): + cur = psycopg.AsyncCursor(aconn) + await cur.execute("select 1") + assert (await cur.fetchone()) == (1,) + + aconn.row_factory = rows.dict_row + cur = 
psycopg.AsyncCursor(aconn) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_init_factory(aconn): + cur = psycopg.AsyncCursor(aconn, row_factory=rows.dict_row) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_close(aconn): + cur = aconn.cursor() + assert not cur.closed + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.execute("select 'foo'") + + await cur.close() + assert cur.closed + + +async def test_cursor_close_fetchone(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + for _ in range(5): + await cur.fetchone() + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchone() + + +async def test_cursor_close_fetchmany(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchmany(2)) == 2 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchmany(2) + + +async def test_cursor_close_fetchall(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchall()) == 10 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchall() + + +async def test_context(aconn): + async with aconn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +async def test_weakref(aconn): + cur = aconn.cursor() + w = weakref.ref(cur) + await cur.close() + del cur + gc_collect() + assert w() is None + + +async def test_pgresult(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert cur.pgresult + await cur.close() + assert not cur.pgresult + + +async def test_statusmessage(aconn): + cur = aconn.cursor() + assert cur.statusmessage is None + + await cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + await cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + await cur.execute("wat") + assert cur.statusmessage is None + + +async def test_execute_many_results(aconn): + cur = aconn.cursor() + assert cur.nextset() is None + + rv = await cur.execute("select 'foo'; select generate_series(1,3)") + assert rv is cur + assert (await cur.fetchall()) == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert (await cur.fetchall()) == [(1,), (2,), (3,)] + assert cur.rowcount == 3 + assert cur.nextset() is None + + await cur.close() + assert cur.nextset() is None + + +async def test_execute_sequence(aconn): + cur = aconn.cursor() + rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +async def test_execute_empty_query(aconn, query): + cur = aconn.cursor() + await cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + + +async def test_execute_type_change(aconn): + # 
issue #112 + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.execute(sql, (1,)) + await cur.execute(sql, (100_000,)) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +async def test_executemany_type_change(aconn): + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.executemany(sql, [(1,), (100_000,)]) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +async def test_execute_copy(aconn, query): + cur = aconn.cursor() + await cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + await cur.execute(query) + + +async def test_fetchone(aconn): + cur = aconn.cursor() + await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = await cur.fetchone() + assert row == (1, "foo", None) + row = await cur.fetchone() + assert row is None + + +async def test_binary_cursor_execute(aconn): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None]) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x01" + + +async def test_execute_binary(aconn): + cur = aconn.cursor() + await cur.execute("select %s, %s", [1, None], binary=True) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x01" + + +async def test_binary_cursor_text_override(aconn): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None], binary=False) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +async def test_query_encode(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + await cur.execute("select '\u20ac'") + (res,) = await cur.fetchone() + assert res == "\u20ac" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +async def test_query_badenc(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + with pytest.raises(UnicodeEncodeError): + await cur.execute("select '\u20ac'") + + +async def test_executemany(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(10, "hello"), (20, "world")] + + +async def test_executemany_name(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(11, "hello"), (21, "world")] + + +async def test_executemany_no_data(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + 
+async def test_executemany_rowcount(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +async def test_executemany_returning(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert (await cur.fetchone()) == (10,) + assert cur.nextset() + assert (await cur.fetchone()) == (20,) + assert cur.nextset() is None + + +async def test_executemany_returning_discard(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + assert cur.nextset() is None + + +async def test_executemany_no_result(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +async def test_executemany_rowcount_no_hit(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + await cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + await cur.executemany( + "delete from execmany where id = %s returning num", [(-1,), (-2,)] + ) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +async def test_executemany_badquery(aconn, query): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +async def test_executemany_null_first(aconn, fmt_in): + cur = aconn.cursor() + await cur.execute("create table testmany (a bigint, b bigint)") + await cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + await cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, ""], [3, 4]], + ) + + +async def test_rowcount(aconn): + cur = aconn.cursor() + + await cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + await cur.execute("show timezone") + assert cur.rowcount == 1 + + await cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + await cur.execute( + "insert into test_rowcount_notuples select generate_series(1, 42)" + ) + assert cur.rowcount == 42 + + +async def test_rownumber(aconn): + cur = aconn.cursor() + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert 
cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + async for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(await cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +@pytest.mark.parametrize("query", ["", "set timezone to utc"]) +async def test_rownumber_none(aconn, query): + cur = aconn.cursor() + await cur.execute(query) + assert cur.rownumber is None + + +async def test_rownumber_mixed(aconn): + cur = aconn.cursor() + await cur.execute( + """ +select x from generate_series(1, 3) x; +set timezone to utc; +select x from generate_series(4, 6) x; +""" + ) + assert cur.rownumber == 0 + assert await cur.fetchone() == (1,) + assert cur.rownumber == 1 + assert await cur.fetchone() == (2,) + assert cur.rownumber == 2 + cur.nextset() + assert cur.rownumber is None + cur.nextset() + assert cur.rownumber == 0 + assert await cur.fetchone() == (4,) + assert cur.rownumber == 1 + + +async def test_iter(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + res = [] + async for rec in cur: + res.append(rec) + assert res == [(1,), (2,), (3,)] + + +async def test_iter_stop(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + async for rec in cur: + assert rec == (1,) + break + + async for rec in cur: + assert rec == (2,) + break + + assert (await cur.fetchone()) == (3,) + async for rec in cur: + assert False + + +async def test_row_factory(aconn): + cur = aconn.cursor(row_factory=my_row_factory) + await cur.execute("select 'foo' as bar") + (r,) = await cur.fetchone() + assert r == "FOObar" + + await cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert await cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert await cur.fetchall() == [["Yy", "Zz"]] + + await cur.scroll(-1) + cur.row_factory = rows.dict_row + assert await cur.fetchone() == {"y": "y", "z": "z"} + + +async def test_row_factory_none(aconn): + cur = aconn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + await cur.execute("select 1 as a, 2 as b") + r = await cur.fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +async def test_bad_row_factory(aconn): + def broken_factory(cur): + 1 / 0 + + cur = aconn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + await cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = aconn.cursor(row_factory=broken_maker) + await cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + await cur.fetchone() + + +async def test_scroll(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + await cur.scroll(0) + + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(-1) + assert await cur.fetchone() == (8,) + await cur.scroll(-2) + assert await cur.fetchone() == (7,) + await cur.scroll(2, mode="absolute") + assert await cur.fetchone() == (2,) + + # on the boundary + await cur.scroll(0, mode="absolute") + assert await cur.fetchone() == (0,) + with pytest.raises(IndexError): + await cur.scroll(-1, mode="absolute") + + await cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(-1) + + await cur.scroll(9, 
mode="absolute") + assert await cur.fetchone() == (9,) + with pytest.raises(IndexError): + await cur.scroll(10, mode="absolute") + + await cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(1) + + with pytest.raises(ValueError): + await cur.scroll(1, "wat") + + +async def test_query_params_execute(aconn): + cur = aconn.cursor() + assert cur._query is None + + await cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select $1, $2::text" + assert cur._query.params == [b"1", None] + + await cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + await cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select $1::int" + assert cur._query.params == [b"wat"] + + +async def test_query_params_executemany(aconn): + cur = aconn.cursor() + + await cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select $1, $2" + assert cur._query.params == [b"3", b"4"] + + +async def test_stream(aconn): + cur = aconn.cursor() + recs = [] + async for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +async def test_stream_sql(aconn): + cur = aconn.cursor() + recs = [] + async for rec in cur.stream( + sql.SQL( + "select i, '2021-01-01'::date + i from generate_series(1, {}) as i" + ).format(2) + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +async def test_stream_row_factory(aconn): + cur = aconn.cursor(row_factory=rows.dict_row) + ait = cur.stream("select generate_series(1,2) as a") + assert (await ait.__anext__())["a"] == 1 + cur.row_factory = rows.namedtuple_row + assert (await ait.__anext__()).a == 2 + + +async def test_stream_no_row(aconn): + cur = aconn.cursor() + recs = [rec async for rec in cur.stream("select generate_series(2,1) as a")] + assert recs == [] + + +@pytest.mark.crdb_skip("no col query") +async def test_stream_no_col(aconn): + cur = aconn.cursor() + recs = [rec async for rec in cur.stream("select")] + assert recs == [()] + + +@pytest.mark.parametrize( + "query", + [ + "create table test_stream_badq ()", + "copy (select 1) to stdout", + "wat?", + ], +) +async def test_stream_badquery(aconn, query): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + async for rec in cur.stream(query): + pass + + +async def test_stream_error_tx(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + async for rec in cur.stream("wat"): + pass + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_stream_error_notx(aconn): + await aconn.set_autocommit(True) + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + async for rec in cur.stream("wat"): + pass + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + +async def test_stream_error_python_to_consume(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + gen = cur.stream("select generate_series(1, 10000)") + async for rec in gen: + 1 / 0 + + await gen.aclose() + assert aconn.info.transaction_status in ( + aconn.TransactionStatus.INTRANS, + aconn.TransactionStatus.INERROR, + ) + + +async def test_stream_error_python_consumed(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + gen = 
cur.stream("select 1") + async for rec in gen: + 1 / 0 + + await gen.aclose() + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_stream_close(aconn): + await aconn.set_autocommit(True) + cur = aconn.cursor() + with pytest.raises(psycopg.OperationalError): + async for rec in cur.stream("select generate_series(1, 3)"): + if rec[0] == 1: + await aconn.close() + else: + assert False + + assert aconn.closed + + +async def test_stream_binary_cursor(aconn): + cur = aconn.cursor(binary=True) + recs = [] + async for rec in cur.stream("select x::int4 from generate_series(1, 2) x"): + recs.append(rec) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]]) + + assert recs == [(1,), (2,)] + + +async def test_stream_execute_binary(aconn): + cur = aconn.cursor() + recs = [] + async for rec in cur.stream( + "select x::int4 from generate_series(1, 2) x", binary=True + ): + recs.append(rec) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]]) + + assert recs == [(1,), (2,)] + + +async def test_stream_binary_cursor_text_override(aconn): + cur = aconn.cursor(binary=True) + recs = [] + async for rec in cur.stream("select generate_series(1, 2)", binary=False): + recs.append(rec) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == str(rec[0]).encode() + + assert recs == [(1,), (2,)] + + +async def test_str(aconn): + cur = aconn.cursor() + assert "psycopg.AsyncCursor" in str(cur) + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + await cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + await cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory): + faker.format = fmt + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + async def work(): + async with await aconn_cls.connect(dsn) as conn, conn.transaction( + force_rollback=True + ): + async with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + async with faker.find_insert_problem_async(conn): + await cur.executemany(faker.insert_stmt, faker.records) + await cur.execute(faker.select_stmt) + + if fetch == "one": + while True: + tmp = await cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while True: + tmp = await cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + await cur.fetchall() + elif fetch == "iter": + async for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" diff --git a/tests/test_dns.py b/tests/test_dns.py new file mode 100644 index 0000000..f50092f --- /dev/null +++ b/tests/test_dns.py @@ -0,0 +1,27 @@ +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict + +pytestmark = [pytest.mark.dns] + + +@pytest.mark.asyncio +async 
def test_resolve_hostaddr_async_warning(recwarn): + import_dnspython() + conninfo = "dbname=foo" + params = conninfo_to_dict(conninfo) + params = await psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined] + params + ) + assert conninfo_to_dict(conninfo) == params + assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message) + + +def import_dnspython(): + try: + import dns.rdtypes.IN.A # noqa: F401 + except ImportError: + pytest.skip("dnspython package not available") + + import psycopg._dns # noqa: F401 diff --git a/tests/test_dns_srv.py b/tests/test_dns_srv.py new file mode 100644 index 0000000..15b3706 --- /dev/null +++ b/tests/test_dns_srv.py @@ -0,0 +1,149 @@ +from typing import List, Union + +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict + +from .test_dns import import_dnspython + +pytestmark = [pytest.mark.dns] + +samples_ok = [ + ("", "", None), + ("host=_pg._tcp.foo.com", "host=db1.example.com port=5432", None), + ("", "host=db1.example.com port=5432", {"PGHOST": "_pg._tcp.foo.com"}), + ( + "host=foo.com,_pg._tcp.foo.com", + "host=foo.com,db1.example.com port=,5432", + None, + ), + ( + "host=_pg._tcp.dot.com,foo.com,_pg._tcp.foo.com", + "host=foo.com,db1.example.com port=,5432", + None, + ), + ( + "host=_pg._tcp.bar.com", + "host=db1.example.com,db4.example.com,db3.example.com,db2.example.com" + " port=5432,5432,5433,5432", + None, + ), + ( + "host=service.foo.com port=srv", + "host=service.example.com port=15432", + None, + ), + # No resolution + ( + "host=_pg._tcp.foo.com hostaddr=1.1.1.1", + "host=_pg._tcp.foo.com hostaddr=1.1.1.1", + None, + ), +] + + +@pytest.mark.flakey("random weight order, might cause wrong order") +@pytest.mark.parametrize("conninfo, want, env", samples_ok) +def test_srv(conninfo, want, env, fake_srv, setpgenv): + setpgenv(env) + params = conninfo_to_dict(conninfo) + params = psycopg._dns.resolve_srv(params) # type: ignore[attr-defined] + assert conninfo_to_dict(want) == params + + +@pytest.mark.asyncio +@pytest.mark.parametrize("conninfo, want, env", samples_ok) +async def test_srv_async(conninfo, want, env, afake_srv, setpgenv): + setpgenv(env) + params = conninfo_to_dict(conninfo) + params = await ( + psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined] + ) + assert conninfo_to_dict(want) == params + + +samples_bad = [ + ("host=_pg._tcp.dot.com", None), + ("host=_pg._tcp.foo.com port=1,2", None), +] + + +@pytest.mark.parametrize("conninfo, env", samples_bad) +def test_srv_bad(conninfo, env, fake_srv, setpgenv): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.OperationalError): + psycopg._dns.resolve_srv(params) # type: ignore[attr-defined] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("conninfo, env", samples_bad) +async def test_srv_bad_async(conninfo, env, afake_srv, setpgenv): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.OperationalError): + await psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined] + + +@pytest.fixture +def fake_srv(monkeypatch): + f = get_fake_srv_function(monkeypatch) + monkeypatch.setattr( + psycopg._dns.resolver, # type: ignore[attr-defined] + "resolve", + f, + ) + + +@pytest.fixture +def afake_srv(monkeypatch): + f = get_fake_srv_function(monkeypatch) + + async def af(qname, rdtype): + return f(qname, rdtype) + + monkeypatch.setattr( + psycopg._dns.async_resolver, # type: ignore[attr-defined] + "resolve", + af, + ) + + +def 
get_fake_srv_function(monkeypatch): + import_dnspython() + + from dns.rdtypes.IN.A import A + from dns.rdtypes.IN.SRV import SRV + from dns.exception import DNSException + + fake_hosts = { + ("_pg._tcp.dot.com", "SRV"): ["0 0 5432 ."], + ("_pg._tcp.foo.com", "SRV"): ["0 0 5432 db1.example.com."], + ("_pg._tcp.bar.com", "SRV"): [ + "1 0 5432 db2.example.com.", + "1 255 5433 db3.example.com.", + "0 0 5432 db1.example.com.", + "1 65535 5432 db4.example.com.", + ], + ("service.foo.com", "SRV"): ["0 0 15432 service.example.com."], + } + + def fake_srv_(qname, rdtype): + try: + ans = fake_hosts[qname, rdtype] + except KeyError: + raise DNSException(f"unknown test host: {qname} {rdtype}") + rv: List[Union[A, SRV]] = [] + + if rdtype == "A": + for entry in ans: + rv.append(A("IN", "A", entry)) + else: + for entry in ans: + pri, w, port, target = entry.split() + rv.append(SRV("IN", "SRV", int(pri), int(w), int(port), target)) + + return rv + + return fake_srv_ diff --git a/tests/test_encodings.py b/tests/test_encodings.py new file mode 100644 index 0000000..113f0e3 --- /dev/null +++ b/tests/test_encodings.py @@ -0,0 +1,57 @@ +import codecs +import pytest + +import psycopg +from psycopg import _encodings as encodings + + +def test_names_normalised(): + for name in encodings._py_codecs.values(): + assert codecs.lookup(name).name == name + + +@pytest.mark.parametrize( + "pyenc, pgenc", + [ + ("ascii", "SQL_ASCII"), + ("utf8", "UTF8"), + ("utf-8", "UTF8"), + ("uTf-8", "UTF8"), + ("latin9", "LATIN9"), + ("iso8859-15", "LATIN9"), + ], +) +def test_py2pg(pyenc, pgenc): + assert encodings.py2pgenc(pyenc) == pgenc.encode() + + +@pytest.mark.parametrize( + "pyenc, pgenc", + [ + ("ascii", "SQL_ASCII"), + ("utf-8", "UTF8"), + ("iso8859-15", "LATIN9"), + ], +) +def test_pg2py(pyenc, pgenc): + assert encodings.pg2pyenc(pgenc.encode()) == pyenc + + +@pytest.mark.parametrize("pgenc", ["MULE_INTERNAL", "EUC_TW"]) +def test_pg2py_missing(pgenc): + with pytest.raises(psycopg.NotSupportedError): + encodings.pg2pyenc(pgenc.encode()) + + +@pytest.mark.parametrize( + "conninfo, pyenc", + [ + ("", "utf-8"), + ("user=foo, dbname=bar", "utf-8"), + ("user=foo, dbname=bar, client_encoding=EUC_JP", "euc_jp"), + ("user=foo, dbname=bar, client_encoding=euc-jp", "euc_jp"), + ("user=foo, dbname=bar, client_encoding=WAT", "utf-8"), + ], +) +def test_conninfo_encoding(conninfo, pyenc): + assert encodings.conninfo_encoding(conninfo) == pyenc diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..23ad314 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,309 @@ +import pickle +from typing import List +from weakref import ref + +import pytest + +import psycopg +from psycopg import pq +from psycopg import errors as e + +from .utils import eur, gc_collect +from .fix_crdb import is_crdb + + +@pytest.mark.crdb_skip("severity_nonlocalized") +def test_error_diag(conn): + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select 1 from wat") + + exc = excinfo.value + diag = exc.diag + assert diag.sqlstate == "42P01" + assert diag.severity_nonlocalized == "ERROR" + + +def test_diag_all_attrs(pgconn): + res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR) + diag = e.Diagnostic(res) + for d in pq.DiagnosticField: + val = getattr(diag, d.name.lower()) + assert val is None or isinstance(val, str) + + +def test_diag_right_attr(pgconn, monkeypatch): + res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR) + diag = e.Diagnostic(res) + + to_check: pq.DiagnosticField 
+ checked: List[pq.DiagnosticField] = [] + + def check_val(self, v): + nonlocal to_check + assert to_check == v + checked.append(v) + return None + + monkeypatch.setattr(e.Diagnostic, "_error_message", check_val) + + for to_check in pq.DiagnosticField: + getattr(diag, to_check.name.lower()) + + assert len(checked) == len(pq.DiagnosticField) + + +def test_diag_attr_values(conn): + if is_crdb(conn): + conn.execute("set experimental_enable_temp_tables = 'on'") + conn.execute( + """ + create temp table test_exc ( + data int constraint chk_eq1 check (data = 1) + )""" + ) + with pytest.raises(e.Error) as exc: + conn.execute("insert into test_exc values(2)") + diag = exc.value.diag + assert diag.sqlstate == "23514" + assert diag.constraint_name == "chk_eq1" + if not is_crdb(conn): + assert diag.table_name == "test_exc" + assert diag.schema_name and diag.schema_name[:7] == "pg_temp" + assert diag.severity_nonlocalized == "ERROR" + + +@pytest.mark.crdb_skip("do") +@pytest.mark.parametrize("enc", ["utf8", "latin9"]) +def test_diag_encoding(conn, enc): + msgs = [] + conn.pgconn.exec_(b"set client_min_messages to notice") + conn.add_notice_handler(lambda diag: msgs.append(diag.message_primary)) + conn.execute(f"set client_encoding to {enc}") + cur = conn.cursor() + cur.execute("do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql") + assert msgs == [f"hello {eur}"] + + +@pytest.mark.crdb_skip("do") +@pytest.mark.parametrize("enc", ["utf8", "latin9"]) +def test_error_encoding(conn, enc): + with conn.transaction(): + conn.execute(f"set client_encoding to {enc}") + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute( + """ + do $$begin + execute format('insert into "%s" values (1)', chr(8364)); + end$$ language plpgsql; + """ + ) + + diag = excinfo.value.diag + assert diag.message_primary and f'"{eur}"' in diag.message_primary + assert diag.sqlstate == "42P01" + + +def test_exception_class(conn): + cur = conn.cursor() + + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select * from nonexist") + + assert isinstance(excinfo.value, e.UndefinedTable) + assert isinstance(excinfo.value, conn.ProgrammingError) + + +def test_exception_class_fallback(conn): + cur = conn.cursor() + + x = e._sqlcodes.pop("42P01") + try: + with pytest.raises(e.Error) as excinfo: + cur.execute("select * from nonexist") + finally: + e._sqlcodes["42P01"] = x + + assert type(excinfo.value) is conn.ProgrammingError + + +def test_lookup(): + assert e.lookup("42P01") is e.UndefinedTable + assert e.lookup("42p01") is e.UndefinedTable + assert e.lookup("UNDEFINED_TABLE") is e.UndefinedTable + assert e.lookup("undefined_table") is e.UndefinedTable + + with pytest.raises(KeyError): + e.lookup("XXXXX") + + +def test_error_sqlstate(): + assert e.Error.sqlstate is None + assert e.ProgrammingError.sqlstate is None + assert e.UndefinedTable.sqlstate == "42P01" + + +def test_error_pickle(conn): + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select 1 from wat") + + exc = pickle.loads(pickle.dumps(excinfo.value)) + assert isinstance(exc, e.UndefinedTable) + assert exc.diag.sqlstate == "42P01" + + +def test_diag_pickle(conn): + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select 1 from wat") + + diag1 = excinfo.value.diag + diag2 = pickle.loads(pickle.dumps(diag1)) + + assert isinstance(diag2, type(diag1)) + for f in pq.DiagnosticField: + assert getattr(diag1, f.name.lower()) == getattr(diag2, f.name.lower()) + 
+ assert diag2.sqlstate == "42P01" + + +@pytest.mark.slow +def test_diag_survives_cursor(conn): + cur = conn.cursor() + with pytest.raises(e.Error) as exc: + cur.execute("select * from nosuchtable") + + diag = exc.value.diag + del exc + w = ref(cur) + del cur + gc_collect() + assert w() is None + assert diag.sqlstate == "42P01" + + +def test_diag_independent(conn): + conn.autocommit = True + cur = conn.cursor() + + with pytest.raises(e.Error) as exc1: + cur.execute("l'acqua e' poca e 'a papera nun galleggia") + + with pytest.raises(e.Error) as exc2: + cur.execute("select level from water where ducks > 1") + + assert exc1.value.diag.sqlstate == "42601" + assert exc2.value.diag.sqlstate == "42P01" + + +@pytest.mark.crdb_skip("deferrable") +def test_diag_from_commit(conn): + cur = conn.cursor() + cur.execute( + """ + create temp table test_deferred ( + data int primary key, + ref int references test_deferred (data) + deferrable initially deferred) + """ + ) + cur.execute("insert into test_deferred values (1,2)") + with pytest.raises(e.Error) as exc: + conn.commit() + + assert exc.value.diag.sqlstate == "23503" + + +@pytest.mark.asyncio +@pytest.mark.crdb_skip("deferrable") +async def test_diag_from_commit_async(aconn): + cur = aconn.cursor() + await cur.execute( + """ + create temp table test_deferred ( + data int primary key, + ref int references test_deferred (data) + deferrable initially deferred) + """ + ) + await cur.execute("insert into test_deferred values (1,2)") + with pytest.raises(e.Error) as exc: + await aconn.commit() + + assert exc.value.diag.sqlstate == "23503" + + +def test_query_context(conn): + with pytest.raises(e.Error) as exc: + conn.execute("select * from wat") + + s = str(exc.value) + if not is_crdb(conn): + assert "from wat" in s, s + assert exc.value.diag.message_primary + assert exc.value.diag.message_primary in s + assert "ERROR" not in s + assert not s.endswith("\n") + + +@pytest.mark.crdb_skip("do") +def test_unknown_sqlstate(conn): + code = "PXX99" + with pytest.raises(KeyError): + e.lookup(code) + + with pytest.raises(e.ProgrammingError) as excinfo: + conn.execute( + f""" + do $$begin + raise exception 'made up code' using errcode = '{code}'; + end$$ language plpgsql + """ + ) + exc = excinfo.value + assert exc.diag.sqlstate == code + assert exc.sqlstate == code + # Survives pickling too + pexc = pickle.loads(pickle.dumps(exc)) + assert pexc.sqlstate == code + + +def test_pgconn_error(conn_cls): + with pytest.raises(psycopg.OperationalError) as excinfo: + conn_cls.connect("dbname=nosuchdb") + + exc = excinfo.value + assert exc.pgconn + assert exc.pgconn.db == b"nosuchdb" + + +def test_pgconn_error_pickle(conn_cls): + with pytest.raises(psycopg.OperationalError) as excinfo: + conn_cls.connect("dbname=nosuchdb") + + exc = pickle.loads(pickle.dumps(excinfo.value)) + assert exc.pgconn is None + + +def test_pgresult(conn): + with pytest.raises(e.DatabaseError) as excinfo: + conn.execute("select 1 from wat") + + exc = excinfo.value + assert exc.pgresult + assert exc.pgresult.error_field(pq.DiagnosticField.SQLSTATE) == b"42P01" + + +def test_pgresult_pickle(conn): + with pytest.raises(e.DatabaseError) as excinfo: + conn.execute("select 1 from wat") + + exc = pickle.loads(pickle.dumps(excinfo.value)) + assert exc.pgresult is None + assert exc.diag.sqlstate == "42P01" + + +def test_blank_sqlstate(conn): + assert e.get_base_exception("") is e.DatabaseError diff --git a/tests/test_generators.py b/tests/test_generators.py new file mode 100644 index 0000000..8aba73f --- 
/dev/null +++ b/tests/test_generators.py @@ -0,0 +1,156 @@ +from collections import deque +from functools import partial +from typing import List + +import pytest + +import psycopg +from psycopg import waiting +from psycopg import pq + + +@pytest.fixture +def pipeline(pgconn): + nb, pgconn.nonblocking = pgconn.nonblocking, True + assert pgconn.nonblocking + pgconn.enter_pipeline_mode() + yield + if pgconn.pipeline_status: + pgconn.exit_pipeline_mode() + pgconn.nonblocking = nb + + +def _run_pipeline_communicate(pgconn, generators, commands, expected_statuses): + actual_statuses: List[pq.ExecStatus] = [] + while len(actual_statuses) != len(expected_statuses): + if commands: + gen = generators.pipeline_communicate(pgconn, commands) + results = waiting.wait(gen, pgconn.socket) + for (result,) in results: + actual_statuses.append(result.status) + else: + gen = generators.fetch_many(pgconn) + results = waiting.wait(gen, pgconn.socket) + for result in results: + actual_statuses.append(result.status) + + assert actual_statuses == expected_statuses + + +@pytest.mark.pipeline +def test_pipeline_communicate_multi_pipeline(pgconn, pipeline, generators): + commands = deque( + [ + partial(pgconn.send_query_params, b"select 1", None), + pgconn.pipeline_sync, + partial(pgconn.send_query_params, b"select 2", None), + pgconn.pipeline_sync, + ] + ) + expected_statuses = [ + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.PIPELINE_SYNC, + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.PIPELINE_SYNC, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + + +@pytest.mark.pipeline +def test_pipeline_communicate_no_sync(pgconn, pipeline, generators): + numqueries = 10 + commands = deque( + [partial(pgconn.send_query_params, b"select repeat('xyzxz', 12)", None)] + * numqueries + + [pgconn.send_flush_request] + ) + expected_statuses = [pq.ExecStatus.TUPLES_OK] * numqueries + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + + +@pytest.fixture +def pipeline_demo(pgconn): + assert pgconn.pipeline_status == 0 + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_( + b"CREATE UNLOGGED TABLE pg_pipeline(" b" id serial primary key, itemno integer)" + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + yield "pg_pipeline" + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + +# TODOCRDB: 1 doesn't get rolled back. Open a ticket? 
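+# On PostgreSQL the commands queued before a sync run in a single implicit
+# transaction: when the second statement fails, the already-executed insert of
+# value 1 is rolled back, the following command reports PIPELINE_ABORTED, and
+# only the insert issued after the sync point (value 3) survives, as the final
+# assertions below verify.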
+@pytest.mark.pipeline +@pytest.mark.crdb("skip", reason="pipeline aborted") +def test_pipeline_communicate_abort(pgconn, pipeline_demo, pipeline, generators): + insert_sql = b"insert into pg_pipeline(itemno) values ($1)" + commands = deque( + [ + partial(pgconn.send_query_params, insert_sql, [b"1"]), + partial(pgconn.send_query_params, b"select no_such_function(1)", None), + partial(pgconn.send_query_params, insert_sql, [b"2"]), + pgconn.pipeline_sync, + partial(pgconn.send_query_params, insert_sql, [b"3"]), + pgconn.pipeline_sync, + ] + ) + expected_statuses = [ + pq.ExecStatus.COMMAND_OK, + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_SYNC, + pq.ExecStatus.COMMAND_OK, + pq.ExecStatus.PIPELINE_SYNC, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + pgconn.exit_pipeline_mode() + res = pgconn.exec_(b"select itemno from pg_pipeline order by itemno") + assert res.ntuples == 1 + assert res.get_value(0, 0) == b"3" + + +@pytest.fixture +def pipeline_uniqviol(pgconn): + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + assert pgconn.pipeline_status == 0 + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline_uniqviol") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_( + b"CREATE UNLOGGED TABLE pg_pipeline_uniqviol(" + b" id bigint primary key, idata bigint)" + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_(b"BEGIN") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.prepare( + b"insertion", + b"insert into pg_pipeline_uniqviol values ($1, $2) returning id", + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + return "pg_pipeline_uniqviol" + + +def test_pipeline_communicate_uniqviol(pgconn, pipeline_uniqviol, pipeline, generators): + commands = deque( + [ + partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"2", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"3", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"4", b"2"]), + partial(pgconn.send_query_params, b"commit", None), + ] + ) + expected_statuses = [ + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_ABORTED, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) diff --git a/tests/test_module.py b/tests/test_module.py new file mode 100644 index 0000000..794ef0f --- /dev/null +++ b/tests/test_module.py @@ -0,0 +1,57 @@ +import pytest + +from psycopg._cmodule import _psycopg + + +@pytest.mark.parametrize( + "args, kwargs, want_conninfo", + [ + ((), {}, ""), + (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"), + ((), {"port": 15432}, "port=15432"), + ((), {"user": "foo", "dbname": None}, "user=foo"), + ], +) +def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo): + # Check the main args passing from psycopg.connect to the conn generator + # Details of the params manipulation are in test_conninfo. 
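+    # Replace the low-level connection generator with a stub that records the
+    # conninfo string it receives, so the assembled DSN can be checked below.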
+ import psycopg.connection + + orig_connect = psycopg.connection.connect # type: ignore + + got_conninfo = None + + def mock_connect(conninfo): + nonlocal got_conninfo + got_conninfo = conninfo + return orig_connect(dsn) + + monkeypatch.setattr(psycopg.connection, "connect", mock_connect) + + conn = psycopg.connect(*args, **kwargs) + assert got_conninfo == want_conninfo + conn.close() + + +def test_version(mypy): + cp = mypy.run_on_source( + """\ +from psycopg import __version__ +assert __version__ +""" + ) + assert not cp.stdout + + +@pytest.mark.skipif(_psycopg is None, reason="C module test") +def test_version_c(mypy): + # can be psycopg_c, psycopg_binary + cpackage = _psycopg.__name__.split(".")[0] + + cp = mypy.run_on_source( + f"""\ +from {cpackage} import __version__ +assert __version__ +""" + ) + assert not cp.stdout diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..56fe598 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,577 @@ +import logging +import concurrent.futures +from typing import Any +from operator import attrgetter +from itertools import groupby + +import pytest + +import psycopg +from psycopg import pq +from psycopg import errors as e + +pytestmark = [ + pytest.mark.pipeline, + pytest.mark.skipif("not psycopg.Pipeline.is_supported()"), +] + +pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted") + + +def test_repr(conn): + with conn.pipeline() as p: + assert "psycopg.Pipeline" in repr(p) + assert "[IDLE, pipeline=ON]" in repr(p) + + conn.close() + assert "[BAD]" in repr(p) + + +def test_connection_closed(conn): + conn.close() + with pytest.raises(e.OperationalError): + with conn.pipeline(): + pass + + +def test_pipeline_status(conn: psycopg.Connection[Any]) -> None: + assert conn._pipeline is None + with conn.pipeline() as p: + assert conn._pipeline is p + assert p.status == pq.PipelineStatus.ON + assert p.status == pq.PipelineStatus.OFF + assert not conn._pipeline + + +def test_pipeline_reenter(conn: psycopg.Connection[Any]) -> None: + with conn.pipeline() as p1: + with conn.pipeline() as p2: + assert p2 is p1 + assert p1.status == pq.PipelineStatus.ON + assert p2 is p1 + assert p2.status == pq.PipelineStatus.ON + assert conn._pipeline is None + assert p1.status == pq.PipelineStatus.OFF + + +def test_pipeline_broken_conn_exit(conn: psycopg.Connection[Any]) -> None: + with pytest.raises(e.OperationalError): + with conn.pipeline(): + conn.execute("select 1") + conn.close() + closed = True + + assert closed + + +def test_pipeline_exit_error_noclobber(conn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + with pytest.raises(ZeroDivisionError): + with conn.pipeline(): + conn.close() + 1 / 0 + + assert len(caplog.records) == 1 + + +def test_pipeline_exit_error_noclobber_nested(conn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + with pytest.raises(ZeroDivisionError): + with conn.pipeline(): + with conn.pipeline(): + conn.close() + 1 / 0 + + assert len(caplog.records) == 2 + + +def test_pipeline_exit_sync_trace(conn, trace): + t = trace.trace(conn) + with conn.pipeline(): + pass + conn.close() + assert len([i for i in t if i.type == "Sync"]) == 1 + + +def test_pipeline_nested_sync_trace(conn, trace): + t = trace.trace(conn) + with conn.pipeline(): + with conn.pipeline(): + pass + conn.close() + assert len([i for i in t if i.type == "Sync"]) == 2 + + +def test_cursor_stream(conn): + with conn.pipeline(), conn.cursor() as cur: + with 
pytest.raises(psycopg.ProgrammingError): + cur.stream("select 1").__next__() + + +def test_server_cursor(conn): + with conn.cursor(name="pipeline") as cur, conn.pipeline(): + with pytest.raises(psycopg.NotSupportedError): + cur.execute("select 1") + + +def test_cannot_insert_multiple_commands(conn): + with pytest.raises((e.SyntaxError, e.InvalidPreparedStatementDefinition)): + with conn.pipeline(): + conn.execute("select 1; select 2") + + +def test_copy(conn): + with conn.pipeline(): + cur = conn.cursor() + with pytest.raises(e.NotSupportedError): + with cur.copy("copy (select 1) to stdout"): + pass + + +def test_pipeline_processed_at_exit(conn): + with conn.cursor() as cur: + with conn.pipeline() as p: + cur.execute("select 1") + + assert len(p.result_queue) == 1 + + assert cur.fetchone() == (1,) + + +def test_pipeline_errors_processed_at_exit(conn): + conn.autocommit = True + with pytest.raises(e.UndefinedTable): + with conn.pipeline(): + conn.execute("select * from nosuchtable") + conn.execute("create table voila ()") + cur = conn.execute( + "select count(*) from pg_tables where tablename = %s", ("voila",) + ) + (count,) = cur.fetchone() + assert count == 0 + + +def test_pipeline(conn): + with conn.pipeline() as p: + c1 = conn.cursor() + c2 = conn.cursor() + c1.execute("select 1") + c2.execute("select 2") + + assert len(p.result_queue) == 2 + + (r1,) = c1.fetchone() + assert r1 == 1 + + (r2,) = c2.fetchone() + assert r2 == 2 + + +def test_autocommit(conn): + conn.autocommit = True + with conn.pipeline(), conn.cursor() as c: + c.execute("select 1") + + (r,) = c.fetchone() + assert r == 1 + + +def test_pipeline_aborted(conn): + conn.autocommit = True + with conn.pipeline() as p: + c1 = conn.execute("select 1") + with pytest.raises(e.UndefinedTable): + conn.execute("select * from doesnotexist").fetchone() + with pytest.raises(e.PipelineAborted): + conn.execute("select 'aborted'").fetchone() + # Sync restore the connection in usable state. 
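+        # (until the sync point is processed, queued statements fail with
+        # PipelineAborted, as asserted above; afterwards queries run normally)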
+ p.sync() + c2 = conn.execute("select 2") + + (r,) = c1.fetchone() + assert r == 1 + + (r,) = c2.fetchone() + assert r == 2 + + +def test_pipeline_commit_aborted(conn): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): + with conn.pipeline(): + conn.execute("select error") + conn.execute("create table voila ()") + conn.commit() + + +def test_sync_syncs_results(conn): + with conn.pipeline() as p: + cur = conn.execute("select 1") + assert cur.statusmessage is None + p.sync() + assert cur.statusmessage == "SELECT 1" + + +def test_sync_syncs_errors(conn): + conn.autocommit = True + with conn.pipeline() as p: + conn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + p.sync() + + +@pipeline_aborted +def test_errors_raised_on_commit(conn): + with conn.pipeline(): + conn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + conn.commit() + conn.rollback() + cur1 = conn.execute("select 1") + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +@pytest.mark.flakey("assert fails randomly in CI blocking release") +def test_errors_raised_on_transaction_exit(conn): + here = False + with conn.pipeline(): + with pytest.raises(e.UndefinedTable): + with conn.transaction(): + conn.execute("select 1 from nosuchtable") + here = True + cur1 = conn.execute("select 1") + assert here + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +@pytest.mark.flakey("assert fails randomly in CI blocking release") +def test_errors_raised_on_nested_transaction_exit(conn): + here = False + with conn.pipeline(): + with conn.transaction(): + with pytest.raises(e.UndefinedTable): + with conn.transaction(): + conn.execute("select 1 from nosuchtable") + here = True + cur1 = conn.execute("select 1") + assert here + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +def test_implicit_transaction(conn): + conn.autocommit = True + with conn.pipeline(): + assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE + conn.execute("select 'before'") + # Transaction is ACTIVE because previous command is not completed + # since we have not fetched its results. + assert conn.pgconn.transaction_status == pq.TransactionStatus.ACTIVE + # Upon entering the nested pipeline through "with transaction():", a + # sync() is emitted to restore the transaction state to IDLE, as + # expected to emit a BEGIN. 
+ with conn.transaction(): + conn.execute("select 'tx'") + cur = conn.execute("select 'after'") + assert cur.fetchone() == ("after",) + + +@pytest.mark.crdb_skip("deferrable") +def test_error_on_commit(conn): + conn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + conn.commit() + + with conn.pipeline(): + conn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + conn.commit() + cur1 = conn.execute("select 1") + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +def test_fetch_no_result(conn): + with conn.pipeline(): + cur = conn.cursor() + with pytest.raises(e.ProgrammingError): + cur.fetchone() + + +def test_executemany(conn): + conn.autocommit = True + conn.execute("drop table if exists execmanypipeline") + conn.execute( + "create unlogged table execmanypipeline (" + " id serial primary key, num integer)" + ) + with conn.pipeline(), conn.cursor() as cur: + cur.executemany( + "insert into execmanypipeline(num) values (%s) returning num", + [(10,), (20,)], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.fetchone() == (10,) + assert cur.nextset() + assert cur.fetchone() == (20,) + assert cur.nextset() is None + + +def test_executemany_no_returning(conn): + conn.autocommit = True + conn.execute("drop table if exists execmanypipelinenoreturning") + conn.execute( + "create unlogged table execmanypipelinenoreturning (" + " id serial primary key, num integer)" + ) + with conn.pipeline(), conn.cursor() as cur: + cur.executemany( + "insert into execmanypipelinenoreturning(num) values (%s)", + [(10,), (20,)], + returning=False, + ) + with pytest.raises(e.ProgrammingError, match="no result available"): + cur.fetchone() + assert cur.nextset() is None + with pytest.raises(e.ProgrammingError, match="no result available"): + cur.fetchone() + assert cur.nextset() is None + + +@pytest.mark.crdb("skip", reason="temp tables") +def test_executemany_trace(conn, trace): + conn.autocommit = True + cur = conn.cursor() + cur.execute("create temp table trace (id int)") + t = trace.trace(conn) + with conn.pipeline(): + cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)]) + cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)]) + conn.close() + items = list(t) + assert items[-1].type == "Terminate" + del items[-1] + roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + assert roundtrips == ["F", "B"] + assert len([i for i in items if i.type == "Sync"]) == 1 + + +@pytest.mark.crdb("skip", reason="temp tables") +def test_executemany_trace_returning(conn, trace): + conn.autocommit = True + cur = conn.cursor() + cur.execute("create temp table trace (id int)") + t = trace.trace(conn) + with conn.pipeline(): + cur.executemany( + "insert into trace (id) values (%s)", [(10,), (20,)], returning=True + ) + cur.executemany( + "insert into trace (id) values (%s)", [(10,), (20,)], returning=True + ) + conn.close() + items = list(t) + assert items[-1].type == "Terminate" + del items[-1] + roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + assert roundtrips == ["F", "B"] * 3 + assert items[-2].direction == "F" # last 2 items are F B + assert len([i for i in items if i.type == "Sync"]) == 1 + + +def test_prepared(conn): + conn.autocommit = True + with conn.pipeline(): + c1 = conn.execute("select %s::int", [10], prepare=True) + c2 
= conn.execute( + "select count(*) from pg_prepared_statements where name != ''" + ) + + (r,) = c1.fetchone() + assert r == 10 + + (r,) = c2.fetchone() + assert r == 1 + + +def test_auto_prepare(conn): + conn.autocommit = True + conn.prepared_threshold = 5 + with conn.pipeline(): + cursors = [ + conn.execute("select count(*) from pg_prepared_statements where name != ''") + for i in range(10) + ] + + assert len(conn._prepared._names) == 1 + + res = [c.fetchone()[0] for c in cursors] + assert res == [0] * 5 + [1] * 5 + + +def test_transaction(conn): + notices = [] + conn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + with conn.pipeline(): + with conn.transaction(): + cur = conn.execute("select 'tx'") + + (r,) = cur.fetchone() + assert r == "tx" + + with conn.transaction(): + cur = conn.execute("select 'rb'") + raise psycopg.Rollback() + + (r,) = cur.fetchone() + assert r == "rb" + + assert not notices + + +def test_transaction_nested(conn): + with conn.pipeline(): + with conn.transaction(): + outer = conn.execute("select 'outer'") + with pytest.raises(ZeroDivisionError): + with conn.transaction(): + inner = conn.execute("select 'inner'") + 1 / 0 + + (r,) = outer.fetchone() + assert r == "outer" + (r,) = inner.fetchone() + assert r == "inner" + + +def test_transaction_nested_no_statement(conn): + with conn.pipeline(): + with conn.transaction(): + with conn.transaction(): + cur = conn.execute("select 1") + + (r,) = cur.fetchone() + assert r == 1 + + +def test_outer_transaction(conn): + with conn.transaction(): + conn.execute("drop table if exists outertx") + with conn.transaction(): + with conn.pipeline(): + conn.execute("create table outertx as (select 1)") + cur = conn.execute("select * from outertx") + (r,) = cur.fetchone() + assert r == 1 + cur = conn.execute("select count(*) from pg_tables where tablename = 'outertx'") + assert cur.fetchone()[0] == 1 + + +def test_outer_transaction_error(conn): + with conn.transaction(): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): + with conn.pipeline(): + conn.execute("select error") + conn.execute("create table voila ()") + + +def test_rollback_explicit(conn): + conn.autocommit = True + with conn.pipeline(): + with pytest.raises(e.DivisionByZero): + cur = conn.execute("select 1 / %s", [0]) + cur.fetchone() + conn.rollback() + conn.execute("select 1") + + +def test_rollback_transaction(conn): + conn.autocommit = True + with pytest.raises(e.DivisionByZero): + with conn.pipeline(): + with conn.transaction(): + cur = conn.execute("select 1 / %s", [0]) + cur.fetchone() + conn.execute("select 1") + + +def test_message_0x33(conn): + # https://github.com/psycopg/psycopg/issues/314 + notices = [] + conn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + conn.autocommit = True + with conn.pipeline(): + cur = conn.execute("select 'test'") + assert cur.fetchone() == ("test",) + + assert not notices + + +def test_transaction_state_implicit_begin(conn, trace): + # Regression test to ensure that the transaction state is correct after + # the implicit BEGIN statement (in non-autocommit mode). 
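+    # The trace should contain exactly one parsed BEGIN for the whole pipeline
+    # and the server should emit no notices.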
+ notices = [] + conn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + t = trace.trace(conn) + with conn.pipeline(): + conn.execute("select 'x'").fetchone() + conn.execute("select 'y'") + assert not notices + assert [ + e.content[0] for e in t if e.type == "Parse" and b"BEGIN" in e.content[0] + ] == [b' "" "BEGIN" 0'] + + +def test_concurrency(conn): + with conn.transaction(): + conn.execute("drop table if exists pipeline_concurrency") + conn.execute("drop table if exists accessed") + with conn.transaction(): + conn.execute( + "create unlogged table pipeline_concurrency (" + " id serial primary key," + " value integer" + ")" + ) + conn.execute("create unlogged table accessed as (select now() as value)") + + def update(value): + cur = conn.execute( + "insert into pipeline_concurrency(value) values (%s) returning value", + (value,), + ) + conn.execute("update accessed set value = now()") + return cur + + conn.autocommit = True + + (before,) = conn.execute("select value from accessed").fetchone() + + values = range(1, 10) + with conn.pipeline(): + with concurrent.futures.ThreadPoolExecutor() as e: + cursors = e.map(update, values, timeout=len(values)) + assert sum(cur.fetchone()[0] for cur in cursors) == sum(values) + + (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone() + assert s == sum(values) + (after,) = conn.execute("select value from accessed").fetchone() + assert after > before diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py new file mode 100644 index 0000000..2e743cf --- /dev/null +++ b/tests/test_pipeline_async.py @@ -0,0 +1,586 @@ +import asyncio +import logging +from typing import Any +from operator import attrgetter +from itertools import groupby + +import pytest + +import psycopg +from psycopg import pq +from psycopg import errors as e + +from .test_pipeline import pipeline_aborted + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.pipeline, + pytest.mark.skipif("not psycopg.AsyncPipeline.is_supported()"), +] + + +async def test_repr(aconn): + async with aconn.pipeline() as p: + assert "psycopg.AsyncPipeline" in repr(p) + assert "[IDLE, pipeline=ON]" in repr(p) + + await aconn.close() + assert "[BAD]" in repr(p) + + +async def test_connection_closed(aconn): + await aconn.close() + with pytest.raises(e.OperationalError): + async with aconn.pipeline(): + pass + + +async def test_pipeline_status(aconn: psycopg.AsyncConnection[Any]) -> None: + assert aconn._pipeline is None + async with aconn.pipeline() as p: + assert aconn._pipeline is p + assert p.status == pq.PipelineStatus.ON + assert p.status == pq.PipelineStatus.OFF + assert not aconn._pipeline + + +async def test_pipeline_reenter(aconn: psycopg.AsyncConnection[Any]) -> None: + async with aconn.pipeline() as p1: + async with aconn.pipeline() as p2: + assert p2 is p1 + assert p1.status == pq.PipelineStatus.ON + assert p2 is p1 + assert p2.status == pq.PipelineStatus.ON + assert aconn._pipeline is None + assert p1.status == pq.PipelineStatus.OFF + + +async def test_pipeline_broken_conn_exit(aconn: psycopg.AsyncConnection[Any]) -> None: + with pytest.raises(e.OperationalError): + async with aconn.pipeline(): + await aconn.execute("select 1") + await aconn.close() + closed = True + + assert closed + + +async def test_pipeline_exit_error_noclobber(aconn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + with pytest.raises(ZeroDivisionError): + async with aconn.pipeline(): + await aconn.close() + 1 / 0 + + assert len(caplog.records) == 1 + + 
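+# As in the sync counterpart, one warning is expected to be logged for each
+# pipeline level that fails to exit cleanly on the closed connection, while
+# the original ZeroDivisionError is still propagated.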
+async def test_pipeline_exit_error_noclobber_nested(aconn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + with pytest.raises(ZeroDivisionError): + async with aconn.pipeline(): + async with aconn.pipeline(): + await aconn.close() + 1 / 0 + + assert len(caplog.records) == 2 + + +async def test_pipeline_exit_sync_trace(aconn, trace): + t = trace.trace(aconn) + async with aconn.pipeline(): + pass + await aconn.close() + assert len([i for i in t if i.type == "Sync"]) == 1 + + +async def test_pipeline_nested_sync_trace(aconn, trace): + t = trace.trace(aconn) + async with aconn.pipeline(): + async with aconn.pipeline(): + pass + await aconn.close() + assert len([i for i in t if i.type == "Sync"]) == 2 + + +async def test_cursor_stream(aconn): + async with aconn.pipeline(), aconn.cursor() as cur: + with pytest.raises(psycopg.ProgrammingError): + await cur.stream("select 1").__anext__() + + +async def test_server_cursor(aconn): + async with aconn.cursor(name="pipeline") as cur, aconn.pipeline(): + with pytest.raises(psycopg.NotSupportedError): + await cur.execute("select 1") + + +async def test_cannot_insert_multiple_commands(aconn): + with pytest.raises((e.SyntaxError, e.InvalidPreparedStatementDefinition)): + async with aconn.pipeline(): + await aconn.execute("select 1; select 2") + + +async def test_copy(aconn): + async with aconn.pipeline(): + cur = aconn.cursor() + with pytest.raises(e.NotSupportedError): + async with cur.copy("copy (select 1) to stdout") as copy: + await copy.read() + + +async def test_pipeline_processed_at_exit(aconn): + async with aconn.cursor() as cur: + async with aconn.pipeline() as p: + await cur.execute("select 1") + + assert len(p.result_queue) == 1 + + assert await cur.fetchone() == (1,) + + +async def test_pipeline_errors_processed_at_exit(aconn): + await aconn.set_autocommit(True) + with pytest.raises(e.UndefinedTable): + async with aconn.pipeline(): + await aconn.execute("select * from nosuchtable") + await aconn.execute("create table voila ()") + cur = await aconn.execute( + "select count(*) from pg_tables where tablename = %s", ("voila",) + ) + (count,) = await cur.fetchone() + assert count == 0 + + +async def test_pipeline(aconn): + async with aconn.pipeline() as p: + c1 = aconn.cursor() + c2 = aconn.cursor() + await c1.execute("select 1") + await c2.execute("select 2") + + assert len(p.result_queue) == 2 + + (r1,) = await c1.fetchone() + assert r1 == 1 + + (r2,) = await c2.fetchone() + assert r2 == 2 + + +async def test_autocommit(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(), aconn.cursor() as c: + await c.execute("select 1") + + (r,) = await c.fetchone() + assert r == 1 + + +async def test_pipeline_aborted(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline() as p: + c1 = await aconn.execute("select 1") + with pytest.raises(e.UndefinedTable): + await (await aconn.execute("select * from doesnotexist")).fetchone() + with pytest.raises(e.PipelineAborted): + await (await aconn.execute("select 'aborted'")).fetchone() + # Sync restore the connection in usable state. 
+ await p.sync() + c2 = await aconn.execute("select 2") + + (r,) = await c1.fetchone() + assert r == 1 + + (r,) = await c2.fetchone() + assert r == 2 + + +async def test_pipeline_commit_aborted(aconn): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): + async with aconn.pipeline(): + await aconn.execute("select error") + await aconn.execute("create table voila ()") + await aconn.commit() + + +async def test_sync_syncs_results(aconn): + async with aconn.pipeline() as p: + cur = await aconn.execute("select 1") + assert cur.statusmessage is None + await p.sync() + assert cur.statusmessage == "SELECT 1" + + +async def test_sync_syncs_errors(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline() as p: + await aconn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + await p.sync() + + +@pipeline_aborted +async def test_errors_raised_on_commit(aconn): + async with aconn.pipeline(): + await aconn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + await aconn.commit() + await aconn.rollback() + cur1 = await aconn.execute("select 1") + cur2 = await aconn.execute("select 2") + + assert await cur1.fetchone() == (1,) + assert await cur2.fetchone() == (2,) + + +@pytest.mark.flakey("assert fails randomly in CI blocking release") +async def test_errors_raised_on_transaction_exit(aconn): + here = False + async with aconn.pipeline(): + with pytest.raises(e.UndefinedTable): + async with aconn.transaction(): + await aconn.execute("select 1 from nosuchtable") + here = True + cur1 = await aconn.execute("select 1") + assert here + cur2 = await aconn.execute("select 2") + + assert await cur1.fetchone() == (1,) + assert await cur2.fetchone() == (2,) + + +@pytest.mark.flakey("assert fails randomly in CI blocking release") +async def test_errors_raised_on_nested_transaction_exit(aconn): + here = False + async with aconn.pipeline(): + async with aconn.transaction(): + with pytest.raises(e.UndefinedTable): + async with aconn.transaction(): + await aconn.execute("select 1 from nosuchtable") + here = True + cur1 = await aconn.execute("select 1") + assert here + cur2 = await aconn.execute("select 2") + + assert await cur1.fetchone() == (1,) + assert await cur2.fetchone() == (2,) + + +async def test_implicit_transaction(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(): + assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE + await aconn.execute("select 'before'") + # Transaction is ACTIVE because previous command is not completed + # since we have not fetched its results. + assert aconn.pgconn.transaction_status == pq.TransactionStatus.ACTIVE + # Upon entering the nested pipeline through "with transaction():", a + # sync() is emitted to restore the transaction state to IDLE, as + # expected to emit a BEGIN. 
+ async with aconn.transaction(): + await aconn.execute("select 'tx'") + cur = await aconn.execute("select 'after'") + assert await cur.fetchone() == ("after",) + + +@pytest.mark.crdb_skip("deferrable") +async def test_error_on_commit(aconn): + await aconn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + await aconn.commit() + + async with aconn.pipeline(): + await aconn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + await aconn.commit() + cur1 = await aconn.execute("select 1") + cur2 = await aconn.execute("select 2") + + assert (await cur1.fetchone()) == (1,) + assert (await cur2.fetchone()) == (2,) + + +async def test_fetch_no_result(aconn): + async with aconn.pipeline(): + cur = aconn.cursor() + with pytest.raises(e.ProgrammingError): + await cur.fetchone() + + +async def test_executemany(aconn): + await aconn.set_autocommit(True) + await aconn.execute("drop table if exists execmanypipeline") + await aconn.execute( + "create unlogged table execmanypipeline (" + " id serial primary key, num integer)" + ) + async with aconn.pipeline(), aconn.cursor() as cur: + await cur.executemany( + "insert into execmanypipeline(num) values (%s) returning num", + [(10,), (20,)], + returning=True, + ) + assert cur.rowcount == 2 + assert (await cur.fetchone()) == (10,) + assert cur.nextset() + assert (await cur.fetchone()) == (20,) + assert cur.nextset() is None + + +async def test_executemany_no_returning(aconn): + await aconn.set_autocommit(True) + await aconn.execute("drop table if exists execmanypipelinenoreturning") + await aconn.execute( + "create unlogged table execmanypipelinenoreturning (" + " id serial primary key, num integer)" + ) + async with aconn.pipeline(), aconn.cursor() as cur: + await cur.executemany( + "insert into execmanypipelinenoreturning(num) values (%s)", + [(10,), (20,)], + returning=False, + ) + with pytest.raises(e.ProgrammingError, match="no result available"): + await cur.fetchone() + assert cur.nextset() is None + with pytest.raises(e.ProgrammingError, match="no result available"): + await cur.fetchone() + assert cur.nextset() is None + + +@pytest.mark.crdb("skip", reason="temp tables") +async def test_executemany_trace(aconn, trace): + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("create temp table trace (id int)") + t = trace.trace(aconn) + async with aconn.pipeline(): + await cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)]) + await cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)]) + await aconn.close() + items = list(t) + assert items[-1].type == "Terminate" + del items[-1] + roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + assert roundtrips == ["F", "B"] + assert len([i for i in items if i.type == "Sync"]) == 1 + + +@pytest.mark.crdb("skip", reason="temp tables") +async def test_executemany_trace_returning(aconn, trace): + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("create temp table trace (id int)") + t = trace.trace(aconn) + async with aconn.pipeline(): + await cur.executemany( + "insert into trace (id) values (%s)", [(10,), (20,)], returning=True + ) + await cur.executemany( + "insert into trace (id) values (%s)", [(10,), (20,)], returning=True + ) + await aconn.close() + items = list(t) + assert items[-1].type == "Terminate" + del items[-1] + roundtrips = [k for k, g in 
groupby(items, key=attrgetter("direction"))] + assert roundtrips == ["F", "B"] * 3 + assert items[-2].direction == "F" # last 2 items are F B + assert len([i for i in items if i.type == "Sync"]) == 1 + + +async def test_prepared(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(): + c1 = await aconn.execute("select %s::int", [10], prepare=True) + c2 = await aconn.execute( + "select count(*) from pg_prepared_statements where name != ''" + ) + + (r,) = await c1.fetchone() + assert r == 10 + + (r,) = await c2.fetchone() + assert r == 1 + + +async def test_auto_prepare(aconn): + aconn.prepared_threshold = 5 + async with aconn.pipeline(): + cursors = [ + await aconn.execute( + "select count(*) from pg_prepared_statements where name != ''" + ) + for i in range(10) + ] + + assert len(aconn._prepared._names) == 1 + + res = [(await c.fetchone())[0] for c in cursors] + assert res == [0] * 5 + [1] * 5 + + +async def test_transaction(aconn): + notices = [] + aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + async with aconn.pipeline(): + async with aconn.transaction(): + cur = await aconn.execute("select 'tx'") + + (r,) = await cur.fetchone() + assert r == "tx" + + async with aconn.transaction(): + cur = await aconn.execute("select 'rb'") + raise psycopg.Rollback() + + (r,) = await cur.fetchone() + assert r == "rb" + + assert not notices + + +async def test_transaction_nested(aconn): + async with aconn.pipeline(): + async with aconn.transaction(): + outer = await aconn.execute("select 'outer'") + with pytest.raises(ZeroDivisionError): + async with aconn.transaction(): + inner = await aconn.execute("select 'inner'") + 1 / 0 + + (r,) = await outer.fetchone() + assert r == "outer" + (r,) = await inner.fetchone() + assert r == "inner" + + +async def test_transaction_nested_no_statement(aconn): + async with aconn.pipeline(): + async with aconn.transaction(): + async with aconn.transaction(): + cur = await aconn.execute("select 1") + + (r,) = await cur.fetchone() + assert r == 1 + + +async def test_outer_transaction(aconn): + async with aconn.transaction(): + await aconn.execute("drop table if exists outertx") + async with aconn.transaction(): + async with aconn.pipeline(): + await aconn.execute("create table outertx as (select 1)") + cur = await aconn.execute("select * from outertx") + (r,) = await cur.fetchone() + assert r == 1 + cur = await aconn.execute( + "select count(*) from pg_tables where tablename = 'outertx'" + ) + assert (await cur.fetchone())[0] == 1 + + +async def test_outer_transaction_error(aconn): + async with aconn.transaction(): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): + async with aconn.pipeline(): + await aconn.execute("select error") + await aconn.execute("create table voila ()") + + +async def test_rollback_explicit(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(): + with pytest.raises(e.DivisionByZero): + cur = await aconn.execute("select 1 / %s", [0]) + await cur.fetchone() + await aconn.rollback() + await aconn.execute("select 1") + + +async def test_rollback_transaction(aconn): + await aconn.set_autocommit(True) + with pytest.raises(e.DivisionByZero): + async with aconn.pipeline(): + async with aconn.transaction(): + cur = await aconn.execute("select 1 / %s", [0]) + await cur.fetchone() + await aconn.execute("select 1") + + +async def test_message_0x33(aconn): + # https://github.com/psycopg/psycopg/issues/314 + notices = [] + aconn.add_notice_handler(lambda diag: 
notices.append(diag.message_primary)) + + await aconn.set_autocommit(True) + async with aconn.pipeline(): + cur = await aconn.execute("select 'test'") + assert (await cur.fetchone()) == ("test",) + + assert not notices + + +async def test_transaction_state_implicit_begin(aconn, trace): + # Regression test to ensure that the transaction state is correct after + # the implicit BEGIN statement (in non-autocommit mode). + notices = [] + aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + t = trace.trace(aconn) + async with aconn.pipeline(): + await (await aconn.execute("select 'x'")).fetchone() + await aconn.execute("select 'y'") + assert not notices + assert [ + e.content[0] for e in t if e.type == "Parse" and b"BEGIN" in e.content[0] + ] == [b' "" "BEGIN" 0'] + + +async def test_concurrency(aconn): + async with aconn.transaction(): + await aconn.execute("drop table if exists pipeline_concurrency") + await aconn.execute("drop table if exists accessed") + async with aconn.transaction(): + await aconn.execute( + "create unlogged table pipeline_concurrency (" + " id serial primary key," + " value integer" + ")" + ) + await aconn.execute("create unlogged table accessed as (select now() as value)") + + async def update(value): + cur = await aconn.execute( + "insert into pipeline_concurrency(value) values (%s) returning value", + (value,), + ) + await aconn.execute("update accessed set value = now()") + return cur + + await aconn.set_autocommit(True) + + (before,) = await (await aconn.execute("select value from accessed")).fetchone() + + values = range(1, 10) + async with aconn.pipeline(): + cursors = await asyncio.wait_for( + asyncio.gather(*[update(value) for value in values]), + timeout=len(values), + ) + + assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values) + + (s,) = await ( + await aconn.execute("select sum(value) from pipeline_concurrency") + ).fetchone() + assert s == sum(values) + (after,) = await (await aconn.execute("select value from accessed")).fetchone() + assert after > before diff --git a/tests/test_prepared.py b/tests/test_prepared.py new file mode 100644 index 0000000..56c580a --- /dev/null +++ b/tests/test_prepared.py @@ -0,0 +1,277 @@ +""" +Prepared statements tests +""" + +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg.rows import namedtuple_row + + +@pytest.mark.parametrize("value", [None, 0, 3]) +def test_prepare_threshold_init(conn_cls, dsn, value): + with conn_cls.connect(dsn, prepare_threshold=value) as conn: + assert conn.prepare_threshold == value + + +def test_dont_prepare(conn): + cur = conn.cursor() + for i in range(10): + cur.execute("select %s::int", [i], prepare=False) + + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 + + +def test_do_prepare(conn): + cur = conn.cursor() + cur.execute("select %s::int", [10], prepare=True) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + + +def test_auto_prepare(conn): + res = [] + for i in range(10): + conn.execute("select %s::int", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +def test_dont_prepare_conn(conn): + for i in range(10): + conn.execute("select %s::int", [i], prepare=False) + + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 + + +def test_do_prepare_conn(conn): + conn.execute("select %s::int", [10], prepare=True) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + + +def test_auto_prepare_conn(conn): + res = 
[] + for i in range(10): + conn.execute("select %s", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +def test_prepare_disable(conn): + conn.prepare_threshold = None + res = [] + for i in range(10): + conn.execute("select %s", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) + + assert res == [0] * 10 + assert not conn._prepared._names + assert not conn._prepared._counts + + +def test_no_prepare_multi(conn): + res = [] + for i in range(10): + conn.execute("select 1; select 2") + stmts = get_prepared_statements(conn) + res.append(len(stmts)) + + assert res == [0] * 10 + + +def test_no_prepare_multi_with_drop(conn): + conn.execute("select 1", prepare=True) + + for i in range(10): + conn.execute("drop table if exists noprep; create table noprep()") + + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 + + +def test_no_prepare_error(conn): + conn.autocommit = True + for i in range(10): + with pytest.raises(conn.ProgrammingError): + conn.execute("select wat") + + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 + + +@pytest.mark.parametrize( + "query", + [ + "create table test_no_prepare ()", + pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")), + "set timezone = utc", + "select num from prepared_test", + "insert into prepared_test (num) values (1)", + "update prepared_test set num = num * 2", + "delete from prepared_test where num > 10", + ], +) +def test_misc_statement(conn, query): + conn.execute("create table prepared_test (num int)", prepare=False) + conn.prepare_threshold = 0 + conn.execute(query) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + + +def test_params_types(conn): + conn.execute( + "select %s, %s, %s", + [dt.date(2020, 12, 10), 42, Decimal(42)], + prepare=True, + ) + stmts = get_prepared_statements(conn) + want = [stmt.parameter_types for stmt in stmts] + assert want == [["date", "smallint", "numeric"]] + + +def test_evict_lru(conn): + conn.prepared_max = 5 + for i in range(10): + conn.execute("select 'a'") + conn.execute(f"select {i}") + + assert len(conn._prepared._names) == 1 + assert conn._prepared._names[b"select 'a'", ()] == b"_pg3_0" + for i in [9, 8, 7, 6]: + assert conn._prepared._counts[f"select {i}".encode(), ()] == 1 + + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + assert stmts[0].statement == "select 'a'" + + +def test_evict_lru_deallocate(conn): + conn.prepared_max = 5 + conn.prepare_threshold = 0 + for i in range(10): + conn.execute("select 'a'") + conn.execute(f"select {i}") + + assert len(conn._prepared._names) == 5 + for j in [9, 8, 7, 6, "'a'"]: + name = conn._prepared._names[f"select {j}".encode(), ()] + assert name.startswith(b"_pg3_") + + stmts = get_prepared_statements(conn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.statement for stmt in stmts] + assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]] + + +def test_different_types(conn): + conn.prepare_threshold = 0 + conn.execute("select %s", [None]) + conn.execute("select %s", [dt.date(2000, 1, 1)]) + conn.execute("select %s", [42]) + conn.execute("select %s", [41]) + conn.execute("select %s", [dt.date(2000, 1, 2)]) + + stmts = get_prepared_statements(conn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["text"], ["date"], ["smallint"]] + + +def test_untyped_json(conn): + conn.prepare_threshold = 1 + conn.execute("create table testjson(data 
jsonb)") + + for i in range(2): + conn.execute("insert into testjson (data) values (%s)", ["{}"]) + + stmts = get_prepared_statements(conn) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["jsonb"]] + + +def test_change_type_execute(conn): + conn.prepare_threshold = 0 + for i in range(3): + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().execute( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + {"enum_col": ["foo"]}, + ) + conn.rollback() + + +def test_change_type_executemany(conn): + for i in range(3): + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().executemany( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}], + ) + conn.rollback() + + +@pytest.mark.crdb("skip", reason="can't re-create a type") +def test_change_type(conn): + conn.prepare_threshold = 0 + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().execute( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + {"enum_col": ["foo"]}, + ) + conn.execute("DROP TABLE preptable") + conn.execute("DROP TYPE prepenum") + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().execute( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + {"enum_col": ["foo"]}, + ) + + stmts = get_prepared_statements(conn) + assert len(stmts) == 3 + + +def test_change_type_savepoint(conn): + conn.prepare_threshold = 0 + with conn.transaction(): + for i in range(3): + with pytest.raises(ZeroDivisionError): + with conn.transaction(): + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().execute( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + {"enum_col": ["foo"]}, + ) + raise ZeroDivisionError() + + +def get_prepared_statements(conn): + cur = conn.cursor(row_factory=namedtuple_row) + cur.execute( + # CRDB has 'PREPARE name AS' in the statement. 
+ r""" +select name, + regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement, + prepare_time, + parameter_types +from pg_prepared_statements +where name != '' + """, + prepare=False, + ) + return cur.fetchall() diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py new file mode 100644 index 0000000..84d948f --- /dev/null +++ b/tests/test_prepared_async.py @@ -0,0 +1,207 @@ +""" +Prepared statements tests on async connections +""" + +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg.rows import namedtuple_row + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.parametrize("value", [None, 0, 3]) +async def test_prepare_threshold_init(aconn_cls, dsn, value): + async with await aconn_cls.connect(dsn, prepare_threshold=value) as conn: + assert conn.prepare_threshold == value + + +async def test_dont_prepare(aconn): + cur = aconn.cursor() + for i in range(10): + await cur.execute("select %s::int", [i], prepare=False) + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +async def test_do_prepare(aconn): + cur = aconn.cursor() + await cur.execute("select %s::int", [10], prepare=True) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def test_auto_prepare(aconn): + res = [] + for i in range(10): + await aconn.execute("select %s::int", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +async def test_dont_prepare_conn(aconn): + for i in range(10): + await aconn.execute("select %s::int", [i], prepare=False) + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +async def test_do_prepare_conn(aconn): + await aconn.execute("select %s::int", [10], prepare=True) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def test_auto_prepare_conn(aconn): + res = [] + for i in range(10): + await aconn.execute("select %s", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +async def test_prepare_disable(aconn): + aconn.prepare_threshold = None + res = [] + for i in range(10): + await aconn.execute("select %s", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 10 + assert not aconn._prepared._names + assert not aconn._prepared._counts + + +async def test_no_prepare_multi(aconn): + res = [] + for i in range(10): + await aconn.execute("select 1; select 2") + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 10 + + +async def test_no_prepare_error(aconn): + await aconn.set_autocommit(True) + for i in range(10): + with pytest.raises(aconn.ProgrammingError): + await aconn.execute("select wat") + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +@pytest.mark.parametrize( + "query", + [ + "create table test_no_prepare ()", + pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")), + "set timezone = utc", + "select num from prepared_test", + "insert into prepared_test (num) values (1)", + "update prepared_test set num = num * 2", + "delete from prepared_test where num > 10", + ], +) +async def test_misc_statement(aconn, query): + await aconn.execute("create table prepared_test (num int)", prepare=False) + aconn.prepare_threshold = 0 + await aconn.execute(query) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def 
test_params_types(aconn): + await aconn.execute( + "select %s, %s, %s", + [dt.date(2020, 12, 10), 42, Decimal(42)], + prepare=True, + ) + stmts = await get_prepared_statements(aconn) + want = [stmt.parameter_types for stmt in stmts] + assert want == [["date", "smallint", "numeric"]] + + +async def test_evict_lru(aconn): + aconn.prepared_max = 5 + for i in range(10): + await aconn.execute("select 'a'") + await aconn.execute(f"select {i}") + + assert len(aconn._prepared._names) == 1 + assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0" + for i in [9, 8, 7, 6]: + assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1 + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + assert stmts[0].statement == "select 'a'" + + +async def test_evict_lru_deallocate(aconn): + aconn.prepared_max = 5 + aconn.prepare_threshold = 0 + for i in range(10): + await aconn.execute("select 'a'") + await aconn.execute(f"select {i}") + + assert len(aconn._prepared._names) == 5 + for j in [9, 8, 7, 6, "'a'"]: + name = aconn._prepared._names[f"select {j}".encode(), ()] + assert name.startswith(b"_pg3_") + + stmts = await get_prepared_statements(aconn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.statement for stmt in stmts] + assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]] + + +async def test_different_types(aconn): + aconn.prepare_threshold = 0 + await aconn.execute("select %s", [None]) + await aconn.execute("select %s", [dt.date(2000, 1, 1)]) + await aconn.execute("select %s", [42]) + await aconn.execute("select %s", [41]) + await aconn.execute("select %s", [dt.date(2000, 1, 2)]) + + stmts = await get_prepared_statements(aconn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["text"], ["date"], ["smallint"]] + + +async def test_untyped_json(aconn): + aconn.prepare_threshold = 1 + await aconn.execute("create table testjson(data jsonb)") + for i in range(2): + await aconn.execute("insert into testjson (data) values (%s)", ["{}"]) + + stmts = await get_prepared_statements(aconn) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["jsonb"]] + + +async def get_prepared_statements(aconn): + cur = aconn.cursor(row_factory=namedtuple_row) + await cur.execute( + r""" +select name, + regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement, + prepare_time, + parameter_types +from pg_prepared_statements +where name != '' + """, + prepare=False, + ) + return await cur.fetchall() diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py new file mode 100644 index 0000000..82a5d73 --- /dev/null +++ b/tests/test_psycopg_dbapi20.py @@ -0,0 +1,164 @@ +import pytest +import datetime as dt +from typing import Any, Dict + +import psycopg +from psycopg.conninfo import conninfo_to_dict + +from . import dbapi20 +from . 
import dbapi20_tpc + + +@pytest.fixture(scope="class") +def with_dsn(request, session_dsn): + request.cls.connect_args = (session_dsn,) + + +@pytest.mark.usefixtures("with_dsn") +class PsycopgTests(dbapi20.DatabaseAPI20Test): + driver = psycopg + # connect_args = () # set by the fixture + connect_kw_args: Dict[str, Any] = {} + + def test_nextset(self): + # tested elsewhere + pass + + def test_setoutputsize(self): + # no-op + pass + + +@pytest.mark.usefixtures("tpc") +@pytest.mark.usefixtures("with_dsn") +class PsycopgTPCTests(dbapi20_tpc.TwoPhaseCommitTests): + driver = psycopg + connect_args = () # set by the fixture + + def connect(self): + return psycopg.connect(*self.connect_args) + + +# Shut up warnings +PsycopgTests.failUnless = PsycopgTests.assertTrue +PsycopgTPCTests.assertEquals = PsycopgTPCTests.assertEqual + + +@pytest.mark.parametrize( + "typename, singleton", + [ + ("bytea", "BINARY"), + ("date", "DATETIME"), + ("timestamp without time zone", "DATETIME"), + ("timestamp with time zone", "DATETIME"), + ("time without time zone", "DATETIME"), + ("time with time zone", "DATETIME"), + ("interval", "DATETIME"), + ("integer", "NUMBER"), + ("smallint", "NUMBER"), + ("bigint", "NUMBER"), + ("real", "NUMBER"), + ("double precision", "NUMBER"), + ("numeric", "NUMBER"), + ("decimal", "NUMBER"), + ("oid", "ROWID"), + ("varchar", "STRING"), + ("char", "STRING"), + ("text", "STRING"), + ], +) +def test_singletons(conn, typename, singleton): + singleton = getattr(psycopg, singleton) + cur = conn.cursor() + cur.execute(f"select null::{typename}") + oid = cur.description[0].type_code + assert singleton == oid + assert oid == singleton + assert singleton != oid + 10000 + assert oid + 10000 != singleton + + +@pytest.mark.parametrize( + "ticks, want", + [ + (0, "1970-01-01T00:00:00.000000+0000"), + (1273173119.99992, "2010-05-06T14:11:59.999920-0500"), + ], +) +def test_timestamp_from_ticks(ticks, want): + s = psycopg.TimestampFromTicks(ticks) + want = dt.datetime.strptime(want, "%Y-%m-%dT%H:%M:%S.%f%z") + assert s == want + + +@pytest.mark.parametrize( + "ticks, want", + [ + (0, "1970-01-01"), + # Returned date is local + (1273173119.99992, ["2010-05-06", "2010-05-07"]), + ], +) +def test_date_from_ticks(ticks, want): + s = psycopg.DateFromTicks(ticks) + if isinstance(want, str): + want = [want] + want = [dt.datetime.strptime(w, "%Y-%m-%d").date() for w in want] + assert s in want + + +@pytest.mark.parametrize( + "ticks, want", + [(0, "00:00:00.000000"), (1273173119.99992, "00:11:59.999920")], +) +def test_time_from_ticks(ticks, want): + s = psycopg.TimeFromTicks(ticks) + want = dt.datetime.strptime(want, "%H:%M:%S.%f").time() + assert s.replace(hour=0) == want + + +@pytest.mark.parametrize( + "args, kwargs, want", + [ + ((), {}, ""), + (("",), {}, ""), + (("host=foo user=bar",), {}, "host=foo user=bar"), + (("host=foo",), {"user": "baz"}, "host=foo user=baz"), + ( + ("host=foo port=5432",), + {"host": "qux", "user": "joe"}, + "host=qux user=joe port=5432", + ), + (("host=foo",), {"user": None}, "host=foo"), + ], +) +def test_connect_args(monkeypatch, pgconn, args, kwargs, want): + the_conninfo: str + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + conn = psycopg.connect(*args, **kwargs) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + conn.close() + + +@pytest.mark.parametrize( + "args, kwargs, exctype", + [ + (("host=foo", "host=bar"), {}, TypeError), 
+ (("", ""), {}, TypeError), + ((), {"nosuchparam": 42}, psycopg.ProgrammingError), + ], +) +def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): + return pgconn + yield + + with pytest.raises(exctype): + psycopg.connect(*args, **kwargs) diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..7263a80 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,162 @@ +import pytest + +import psycopg +from psycopg import pq +from psycopg.adapt import Transformer, PyFormat +from psycopg._queries import PostgresQuery, _split_query + + +@pytest.mark.parametrize( + "input, want", + [ + (b"", [(b"", 0, PyFormat.AUTO)]), + (b"foo bar", [(b"foo bar", 0, PyFormat.AUTO)]), + (b"foo %% bar", [(b"foo % bar", 0, PyFormat.AUTO)]), + (b"%s", [(b"", 0, PyFormat.AUTO), (b"", 0, PyFormat.AUTO)]), + (b"%s foo", [(b"", 0, PyFormat.AUTO), (b" foo", 0, PyFormat.AUTO)]), + (b"%b foo", [(b"", 0, PyFormat.BINARY), (b" foo", 0, PyFormat.AUTO)]), + (b"foo %s", [(b"foo ", 0, PyFormat.AUTO), (b"", 0, PyFormat.AUTO)]), + ( + b"foo %%%s bar", + [(b"foo %", 0, PyFormat.AUTO), (b" bar", 0, PyFormat.AUTO)], + ), + ( + b"foo %(name)s bar", + [(b"foo ", "name", PyFormat.AUTO), (b" bar", 0, PyFormat.AUTO)], + ), + ( + b"foo %(name)s %(name)b bar", + [ + (b"foo ", "name", PyFormat.AUTO), + (b" ", "name", PyFormat.BINARY), + (b" bar", 0, PyFormat.AUTO), + ], + ), + ( + b"foo %s%b bar %s baz", + [ + (b"foo ", 0, PyFormat.AUTO), + (b"", 1, PyFormat.BINARY), + (b" bar ", 2, PyFormat.AUTO), + (b" baz", 0, PyFormat.AUTO), + ], + ), + ], +) +def test_split_query(input, want): + assert _split_query(input) == want + + +@pytest.mark.parametrize( + "input", + [ + b"foo %d bar", + b"foo % bar", + b"foo %%% bar", + b"foo %(foo)d bar", + b"foo %(foo)s bar %s baz", + b"foo %(foo) bar", + b"foo %(foo bar", + b"3%2", + ], +) +def test_split_query_bad(input): + with pytest.raises(psycopg.ProgrammingError): + _split_query(input) + + +@pytest.mark.parametrize( + "query, params, want, wformats, wparams", + [ + (b"", None, b"", None, None), + (b"", [], b"", [], []), + (b"%%", [], b"%", [], []), + (b"select %t", (1,), b"select $1", [pq.Format.TEXT], [b"1"]), + ( + b"%t %% %t", + (1, 2), + b"$1 % $2", + [pq.Format.TEXT, pq.Format.TEXT], + [b"1", b"2"], + ), + ( + b"%t %% %t", + ("a", 2), + b"$1 % $2", + [pq.Format.TEXT, pq.Format.TEXT], + [b"a", b"2"], + ), + ], +) +def test_pg_query_seq(query, params, want, wformats, wparams): + pq = PostgresQuery(Transformer()) + pq.convert(query, params) + assert pq.query == want + assert pq.formats == wformats + assert pq.params == wparams + + +@pytest.mark.parametrize( + "query, params, want, wformats, wparams", + [ + (b"", {}, b"", [], []), + (b"hello %%", {"a": 1}, b"hello %", [], []), + ( + b"select %(hello)t", + {"hello": 1, "world": 2}, + b"select $1", + [pq.Format.TEXT], + [b"1"], + ), + ( + b"select %(hi)s %(there)s %(hi)s", + {"hi": 0, "there": "a"}, + b"select $1 $2 $1", + [pq.Format.BINARY, pq.Format.TEXT], + [b"\x00" * 2, b"a"], + ), + ], +) +def test_pg_query_map(query, params, want, wformats, wparams): + pq = PostgresQuery(Transformer()) + pq.convert(query, params) + assert pq.query == want + assert pq.formats == wformats + assert pq.params == wparams + + +@pytest.mark.parametrize( + "query, params", + [ + (b"select %s", {"a": 1}), + (b"select %(name)s", [1]), + (b"select %s", "a"), + (b"select %s", 1), + (b"select %s", b"a"), + (b"select %s", set()), + ], +) +def test_pq_query_badtype(query, params): + pq = 
PostgresQuery(Transformer()) + with pytest.raises(TypeError): + pq.convert(query, params) + + +@pytest.mark.parametrize( + "query, params", + [ + (b"", [1]), + (b"%s", []), + (b"%%", [1]), + (b"$1", [1]), + (b"select %(", {"a": 1}), + (b"select %(a", {"a": 1}), + (b"select %(a)", {"a": 1}), + (b"select %s %(hi)s", [1]), + (b"select %(hi)s %(hi)b", {"hi": 1}), + ], +) +def test_pq_query_badprog(query, params): + pq = PostgresQuery(Transformer()) + with pytest.raises(psycopg.ProgrammingError): + pq.convert(query, params) diff --git a/tests/test_rows.py b/tests/test_rows.py new file mode 100644 index 0000000..5165b80 --- /dev/null +++ b/tests/test_rows.py @@ -0,0 +1,167 @@ +import pytest + +import psycopg +from psycopg import rows + +from .utils import eur + + +def test_tuple_row(conn): + conn.row_factory = rows.dict_row + assert conn.execute("select 1 as a").fetchone() == {"a": 1} + cur = conn.cursor(row_factory=rows.tuple_row) + row = cur.execute("select 1 as a").fetchone() + assert row == (1,) + assert type(row) is tuple + assert cur._make_row is tuple + + +def test_dict_row(conn): + cur = conn.cursor(row_factory=rows.dict_row) + cur.execute("select 'bob' as name, 3 as id") + assert cur.fetchall() == [{"name": "bob", "id": 3}] + + cur.execute("select 'a' as letter; select 1 as number") + assert cur.fetchall() == [{"letter": "a"}] + assert cur.nextset() + assert cur.fetchall() == [{"number": 1}] + assert not cur.nextset() + + +def test_namedtuple_row(conn): + rows._make_nt.cache_clear() + cur = conn.cursor(row_factory=rows.namedtuple_row) + cur.execute("select 'bob' as name, 3 as id") + (person1,) = cur.fetchall() + assert f"{person1.name} {person1.id}" == "bob 3" + + ci1 = rows._make_nt.cache_info() + assert ci1.hits == 0 and ci1.misses == 1 + + cur.execute("select 'alice' as name, 1 as id") + (person2,) = cur.fetchall() + assert type(person2) is type(person1) + + ci2 = rows._make_nt.cache_info() + assert ci2.hits == 1 and ci2.misses == 1 + + cur.execute("select 'foo', 1 as id") + (r0,) = cur.fetchall() + assert r0.f_column_ == "foo" + assert r0.id == 1 + + cur.execute("select 'a' as letter; select 1 as number") + (r1,) = cur.fetchall() + assert r1.letter == "a" + assert cur.nextset() + (r2,) = cur.fetchall() + assert r2.number == 1 + assert not cur.nextset() + assert type(r1) is not type(r2) + + cur.execute(f'select 1 as üåäö, 2 as _, 3 as "123", 4 as "a-b", 5 as "{eur}eur"') + (r3,) = cur.fetchall() + assert r3.üåäö == 1 + assert r3.f_ == 2 + assert r3.f123 == 3 + assert r3.a_b == 4 + assert r3.f_eur == 5 + + +def test_class_row(conn): + cur = conn.cursor(row_factory=rows.class_row(Person)) + cur.execute("select 'John' as first, 'Doe' as last") + (p,) = cur.fetchall() + assert isinstance(p, Person) + assert p.first == "John" + assert p.last == "Doe" + assert p.age is None + + for query in ( + "select 'John' as first", + "select 'John' as first, 'Doe' as last, 42 as wat", + ): + cur.execute(query) + with pytest.raises(TypeError): + cur.fetchone() + + +def test_args_row(conn): + cur = conn.cursor(row_factory=rows.args_row(argf)) + cur.execute("select 'John' as first, 'Doe' as last") + assert cur.fetchone() == "JohnDoe" + + +def test_kwargs_row(conn): + cur = conn.cursor(row_factory=rows.kwargs_row(kwargf)) + cur.execute("select 'John' as first, 'Doe' as last") + (p,) = cur.fetchall() + assert isinstance(p, Person) + assert p.first == "John" + assert p.last == "Doe" + assert p.age == 42 + + +@pytest.mark.parametrize( + "factory", + "tuple_row dict_row namedtuple_row class_row args_row 
kwargs_row".split(), +) +def test_no_result(factory, conn): + cur = conn.cursor(row_factory=factory_from_name(factory)) + cur.execute("reset search_path") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + +@pytest.mark.crdb_skip("no col query") +@pytest.mark.parametrize( + "factory", "tuple_row dict_row namedtuple_row args_row".split() +) +def test_no_column(factory, conn): + cur = conn.cursor(row_factory=factory_from_name(factory)) + cur.execute("select") + recs = cur.fetchall() + assert len(recs) == 1 + assert not recs[0] + + +@pytest.mark.crdb("skip") +def test_no_column_class_row(conn): + class Empty: + def __init__(self, x=10, y=20): + self.x = x + self.y = y + + cur = conn.cursor(row_factory=rows.class_row(Empty)) + cur.execute("select") + x = cur.fetchone() + assert isinstance(x, Empty) + assert x.x == 10 + assert x.y == 20 + + +def factory_from_name(name): + factory = getattr(rows, name) + if factory is rows.class_row: + factory = factory(Person) + if factory is rows.args_row: + factory = factory(argf) + if factory is rows.kwargs_row: + factory = factory(argf) + + return factory + + +class Person: + def __init__(self, first, last, age=None): + self.first = first + self.last = last + self.age = age + + +def argf(*args): + return "".join(map(str, args)) + + +def kwargf(**kwargs): + return Person(**kwargs, age=42) diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py new file mode 100644 index 0000000..f7b6c8e --- /dev/null +++ b/tests/test_server_cursor.py @@ -0,0 +1,525 @@ +import pytest + +import psycopg +from psycopg import rows, errors as e +from psycopg.pq import Format + +pytestmark = pytest.mark.crdb_skip("server-side cursor") + + +def test_init_row_factory(conn): + with psycopg.ServerCursor(conn, "foo") as cur: + assert cur.name == "foo" + assert cur.connection is conn + assert cur.row_factory is conn.row_factory + + conn.row_factory = rows.dict_row + + with psycopg.ServerCursor(conn, "bar") as cur: + assert cur.name == "bar" + assert cur.row_factory is rows.dict_row # type: ignore + + with psycopg.ServerCursor(conn, "baz", row_factory=rows.namedtuple_row) as cur: + assert cur.name == "baz" + assert cur.row_factory is rows.namedtuple_row # type: ignore + + +def test_init_params(conn): + with psycopg.ServerCursor(conn, "foo") as cur: + assert cur.scrollable is None + assert cur.withhold is False + + with psycopg.ServerCursor(conn, "bar", withhold=True, scrollable=False) as cur: + assert cur.scrollable is False + assert cur.withhold is True + + +@pytest.mark.crdb_skip("cursor invalid name") +def test_funny_name(conn): + cur = conn.cursor("1-2-3") + cur.execute("select generate_series(1, 3) as bar") + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.name == "1-2-3" + cur.close() + + +def test_repr(conn): + cur = conn.cursor("my-name") + assert "psycopg.ServerCursor" in str(cur) + assert "my-name" in repr(cur) + cur.close() + + +def test_connection(conn): + cur = conn.cursor("foo") + assert cur.connection is conn + cur.close() + + +def test_description(conn): + cur = conn.cursor("foo") + assert cur.name == "foo" + cur.execute("select generate_series(1, 10)::int4 as bar") + assert len(cur.description) == 1 + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["int4"].oid + assert cur.pgresult.ntuples == 0 + cur.close() + + +def test_format(conn): + cur = conn.cursor("foo") + assert cur.format == Format.TEXT + cur.close() + + cur = conn.cursor("foo", binary=True) + assert cur.format == 
Format.BINARY + cur.close() + + +def test_query_params(conn): + with conn.cursor("foo") as cur: + assert cur._query is None + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur._query + assert b"declare" in cur._query.query.lower() + assert b"(1, $1)" in cur._query.query.lower() + assert cur._query.params == [bytes([0, 3])] # 3 as binary int2 + + +def test_binary_cursor_execute(conn): + cur = conn.cursor("foo", binary=True) + cur.execute("select generate_series(1, 2)::int4") + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert cur.fetchone() == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + cur.close() + + +def test_execute_binary(conn): + cur = conn.cursor("foo") + cur.execute("select generate_series(1, 2)::int4", binary=True) + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert cur.fetchone() == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + + cur.execute("select generate_series(1, 1)::int4") + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + cur.close() + + +def test_binary_cursor_text_override(conn): + cur = conn.cursor("foo", binary=True) + cur.execute("select generate_series(1, 2)", binary=False) + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.fetchone() == (2,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"2" + + cur.execute("select generate_series(1, 2)::int4") + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + cur.close() + + +def test_close(conn, recwarn): + if conn.info.transaction_status == conn.TransactionStatus.INTRANS: + # connection dirty from previous failure + conn.execute("close foo") + recwarn.clear() + cur = conn.cursor("foo") + cur.execute("select generate_series(1, 10) as bar") + cur.close() + assert cur.closed + + assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +def test_close_idempotent(conn): + cur = conn.cursor("foo") + cur.execute("select 1") + cur.fetchall() + cur.close() + cur.close() + + +def test_close_broken_conn(conn): + cur = conn.cursor("foo") + conn.close() + cur.close() + assert cur.closed + + +def test_cursor_close_fetchone(conn): + cur = conn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + for _ in range(5): + cur.fetchone() + + cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + cur.fetchone() + + +def test_cursor_close_fetchmany(conn): + cur = conn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchmany(2)) == 2 + + cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + cur.fetchmany(2) + + +def test_cursor_close_fetchall(conn): + cur = conn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchall()) == 10 + + cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + cur.fetchall() + + +def 
test_close_noop(conn, recwarn): + recwarn.clear() + cur = conn.cursor("foo") + cur.close() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +def test_close_on_error(conn): + cur = conn.cursor("foo") + cur.execute("select 1") + with pytest.raises(e.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + cur.close() + + +def test_pgresult(conn): + cur = conn.cursor() + cur.execute("select 1") + assert cur.pgresult + cur.close() + assert not cur.pgresult + + +def test_context(conn, recwarn): + recwarn.clear() + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, 10) as bar") + + assert cur.closed + assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +def test_close_no_clobber(conn): + with pytest.raises(e.DivisionByZero): + with conn.cursor("foo") as cur: + cur.execute("select 1 / %s", (0,)) + cur.fetchall() + + +def test_warn_close(conn, recwarn): + recwarn.clear() + cur = conn.cursor("foo") + cur.execute("select generate_series(1, 10) as bar") + del cur + assert ".close()" in str(recwarn.pop(ResourceWarning).message) + + +def test_execute_reuse(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as foo", (3,)) + assert cur.fetchone() == (1,) + + cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world")) + assert cur.fetchone() == ("hello", "world") + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["text"].oid + assert cur.description[1].name == "baz" + + +@pytest.mark.parametrize( + "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"] +) +def test_execute_error(conn, stmt): + cur = conn.cursor("foo") + with pytest.raises(e.ProgrammingError): + cur.execute(stmt) + cur.close() + + +def test_executemany(conn): + cur = conn.cursor("foo") + with pytest.raises(e.NotSupportedError): + cur.executemany("select %s", [(1,), (2,)]) + cur.close() + + +def test_fetchone(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (2,)) + assert cur.fetchone() == (1,) + assert cur.fetchone() == (2,) + assert cur.fetchone() is None + + +def test_fetchmany(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (5,)) + assert cur.fetchmany(3) == [(1,), (2,), (3,)] + assert cur.fetchone() == (4,) + assert cur.fetchmany(3) == [(5,)] + assert cur.fetchmany(3) == [] + + +def test_fetchall(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.fetchall() == [] + + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchone() == (1,) + assert cur.fetchall() == [(2,), (3,)] + assert cur.fetchall() == [] + + +def test_nextset(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert not cur.nextset() + + +def test_no_result(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar where false", (3,)) + assert len(cur.description) == 1 + assert cur.fetchall() == [] + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_standard_row_factory(conn, row_factory): + if row_factory == "tuple_row": + getter = lambda r: r[0] # noqa: E731 + elif row_factory 
== "dict_row": + getter = lambda r: r["bar"] # noqa: E731 + elif row_factory == "namedtuple_row": + getter = lambda r: r.bar # noqa: E731 + else: + assert False, row_factory + + row_factory = getattr(rows, row_factory) + with conn.cursor("foo", row_factory=row_factory) as cur: + cur.execute("select generate_series(1, 5) as bar") + assert getter(cur.fetchone()) == 1 + assert list(map(getter, cur.fetchmany(2))) == [2, 3] + assert list(map(getter, cur.fetchall())) == [4, 5] + + +@pytest.mark.crdb_skip("scroll cursor") +def test_row_factory(conn): + n = 0 + + def my_row_factory(cur): + nonlocal n + n += 1 + return lambda values: [n] + [-v for v in values] + + cur = conn.cursor("foo", row_factory=my_row_factory, scrollable=True) + cur.execute("select generate_series(1, 3) as x") + recs = cur.fetchall() + cur.scroll(0, "absolute") + while True: + rec = cur.fetchone() + if not rec: + break + recs.append(rec) + assert recs == [[1, -1], [1, -2], [1, -3]] * 2 + + cur.scroll(0, "absolute") + cur.row_factory = rows.dict_row + assert cur.fetchone() == {"x": 1} + cur.close() + + +def test_rownumber(conn): + cur = conn.cursor("foo") + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + cur.fetchall() + assert cur.rownumber == 42 + cur.close() + + +def test_iter(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + recs = list(cur) + assert recs == [(1,), (2,), (3,)] + + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchone() == (1,) + recs = list(cur) + assert recs == [(2,), (3,)] + + +def test_iter_rownumber(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + for row in cur: + assert cur.rownumber == row[0] + + +def test_itersize(conn, commands): + with conn.cursor("foo") as cur: + assert cur.itersize == 100 + cur.itersize = 2 + cur.execute("select generate_series(1, %s) as bar", (3,)) + commands.popall() # flush begin and other noise + + list(cur) + cmds = commands.popall() + assert len(cmds) == 2 + for cmd in cmds: + assert "fetch forward 2" in cmd.lower() + + +def test_cant_scroll_by_default(conn): + cur = conn.cursor("tmp") + assert cur.scrollable is None + with pytest.raises(e.ProgrammingError): + cur.scroll(0) + cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +def test_scroll(conn): + cur = conn.cursor("tmp", scrollable=True) + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + + with pytest.raises(ValueError): + cur.scroll(9, mode="wat") + cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +def test_scrollable(conn): + curs = conn.cursor("foo", scrollable=True) + assert curs.scrollable is True + curs.execute("select generate_series(0, 5)") + curs.scroll(5) + for i in range(4, -1, -1): + curs.scroll(-1) + assert i == curs.fetchone()[0] + curs.scroll(-1) + curs.close() + + +def test_non_scrollable(conn): + curs = conn.cursor("foo", scrollable=False) + assert curs.scrollable is False + curs.execute("select generate_series(0, 5)") + curs.scroll(5) + with pytest.raises(e.OperationalError): + curs.scroll(-1) + 
curs.close() + + +@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}]) +def test_no_hold(conn, kwargs): + with conn.cursor("foo", **kwargs) as curs: + assert curs.withhold is False + curs.execute("select generate_series(0, 2)") + assert curs.fetchone() == (0,) + conn.commit() + with pytest.raises(e.InvalidCursorName): + curs.fetchone() + + +@pytest.mark.crdb_skip("cursor with hold") +def test_hold(conn): + with conn.cursor("foo", withhold=True) as curs: + assert curs.withhold is True + curs.execute("select generate_series(0, 5)") + assert curs.fetchone() == (0,) + conn.commit() + assert curs.fetchone() == (1,) + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"]) +def test_steal_cursor(conn, row_factory): + cur1 = conn.cursor() + cur1.execute("declare test cursor for select generate_series(1, 6) as s") + + cur2 = conn.cursor("test", row_factory=getattr(rows, row_factory)) + # can call fetch without execute + rec = cur2.fetchone() + assert rec == (1,) + if row_factory == "namedtuple_row": + assert rec.s == 1 + assert cur2.fetchmany(3) == [(2,), (3,), (4,)] + assert cur2.fetchall() == [(5,), (6,)] + cur2.close() + + +def test_stolen_cursor_close(conn): + cur1 = conn.cursor() + cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = conn.cursor("test") + cur2.close() + + cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = conn.cursor("test") + cur2.close() diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py new file mode 100644 index 0000000..21b4345 --- /dev/null +++ b/tests/test_server_cursor_async.py @@ -0,0 +1,543 @@ +import pytest + +import psycopg +from psycopg import rows, errors as e +from psycopg.pq import Format + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb_skip("server-side cursor"), +] + + +async def test_init_row_factory(aconn): + async with psycopg.AsyncServerCursor(aconn, "foo") as cur: + assert cur.name == "foo" + assert cur.connection is aconn + assert cur.row_factory is aconn.row_factory + + aconn.row_factory = rows.dict_row + + async with psycopg.AsyncServerCursor(aconn, "bar") as cur: + assert cur.name == "bar" + assert cur.row_factory is rows.dict_row # type: ignore + + async with psycopg.AsyncServerCursor( + aconn, "baz", row_factory=rows.namedtuple_row + ) as cur: + assert cur.name == "baz" + assert cur.row_factory is rows.namedtuple_row # type: ignore + + +async def test_init_params(aconn): + async with psycopg.AsyncServerCursor(aconn, "foo") as cur: + assert cur.scrollable is None + assert cur.withhold is False + + async with psycopg.AsyncServerCursor( + aconn, "bar", withhold=True, scrollable=False + ) as cur: + assert cur.scrollable is False + assert cur.withhold is True + + +@pytest.mark.crdb_skip("cursor invalid name") +async def test_funny_name(aconn): + cur = aconn.cursor("1-2-3") + await cur.execute("select generate_series(1, 3) as bar") + assert await cur.fetchall() == [(1,), (2,), (3,)] + assert cur.name == "1-2-3" + await cur.close() + + +async def test_repr(aconn): + cur = aconn.cursor("my-name") + assert "psycopg.AsyncServerCursor" in str(cur) + assert "my-name" in repr(cur) + await cur.close() + + +async def test_connection(aconn): + cur = aconn.cursor("foo") + assert cur.connection is aconn + await cur.close() + + +async def test_description(aconn): + cur = aconn.cursor("foo") + assert cur.name == "foo" + await cur.execute("select generate_series(1, 10)::int4 as bar") + assert len(cur.description) == 1 + assert 
cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["int4"].oid + assert cur.pgresult.ntuples == 0 + await cur.close() + + +async def test_format(aconn): + cur = aconn.cursor("foo") + assert cur.format == Format.TEXT + await cur.close() + + cur = aconn.cursor("foo", binary=True) + assert cur.format == Format.BINARY + await cur.close() + + +async def test_query_params(aconn): + async with aconn.cursor("foo") as cur: + assert cur._query is None + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur._query is not None + assert b"declare" in cur._query.query.lower() + assert b"(1, $1)" in cur._query.query.lower() + assert cur._query.params == [bytes([0, 3])] # 3 as binary int2 + + +async def test_binary_cursor_execute(aconn): + cur = aconn.cursor("foo", binary=True) + await cur.execute("select generate_series(1, 2)::int4") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + await cur.close() + + +async def test_execute_binary(aconn): + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 2)::int4", binary=True) + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + + await cur.execute("select generate_series(1, 1)") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + await cur.close() + + +async def test_binary_cursor_text_override(aconn): + cur = aconn.cursor("foo", binary=True) + await cur.execute("select generate_series(1, 2)", binary=False) + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"2" + + await cur.execute("select generate_series(1, 2)::int4") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + await cur.close() + + +async def test_close(aconn, recwarn): + if aconn.info.transaction_status == aconn.TransactionStatus.INTRANS: + # connection dirty from previous failure + await aconn.execute("close foo") + recwarn.clear() + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 10) as bar") + await cur.close() + assert cur.closed + + assert not await ( + await aconn.execute("select * from pg_cursors where name = 'foo'") + ).fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_idempotent(aconn): + cur = aconn.cursor("foo") + await cur.execute("select 1") + await cur.fetchall() + await cur.close() + await cur.close() + + +async def test_close_broken_conn(aconn): + cur = aconn.cursor("foo") + await aconn.close() + await cur.close() + assert cur.closed + + +async def test_cursor_close_fetchone(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + for _ in range(5): + await cur.fetchone() + + await cur.close() + assert cur.closed + + with 
pytest.raises(e.InterfaceError): + await cur.fetchone() + + +async def test_cursor_close_fetchmany(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchmany(2)) == 2 + + await cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + await cur.fetchmany(2) + + +async def test_cursor_close_fetchall(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchall()) == 10 + + await cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + await cur.fetchall() + + +async def test_close_noop(aconn, recwarn): + recwarn.clear() + cur = aconn.cursor("foo") + await cur.close() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_on_error(aconn): + cur = aconn.cursor("foo") + await cur.execute("select 1") + with pytest.raises(e.ProgrammingError): + await aconn.execute("wat") + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + await cur.close() + + +async def test_pgresult(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert cur.pgresult + await cur.close() + assert not cur.pgresult + + +async def test_context(aconn, recwarn): + recwarn.clear() + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, 10) as bar") + + assert cur.closed + assert not await ( + await aconn.execute("select * from pg_cursors where name = 'foo'") + ).fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_no_clobber(aconn): + with pytest.raises(e.DivisionByZero): + async with aconn.cursor("foo") as cur: + await cur.execute("select 1 / %s", (0,)) + await cur.fetchall() + + +async def test_warn_close(aconn, recwarn): + recwarn.clear() + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 10) as bar") + del cur + assert ".close()" in str(recwarn.pop(ResourceWarning).message) + + +async def test_execute_reuse(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as foo", (3,)) + assert await cur.fetchone() == (1,) + + await cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world")) + assert await cur.fetchone() == ("hello", "world") + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["text"].oid + assert cur.description[1].name == "baz" + + +@pytest.mark.parametrize( + "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"] +) +async def test_execute_error(aconn, stmt): + cur = aconn.cursor("foo") + with pytest.raises(e.ProgrammingError): + await cur.execute(stmt) + await cur.close() + + +async def test_executemany(aconn): + cur = aconn.cursor("foo") + with pytest.raises(e.NotSupportedError): + await cur.executemany("select %s", [(1,), (2,)]) + await cur.close() + + +async def test_fetchone(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (2,)) + assert await cur.fetchone() == (1,) + assert await cur.fetchone() == (2,) + assert await cur.fetchone() is None + + +async def test_fetchmany(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (5,)) + assert await cur.fetchmany(3) == [(1,), (2,), (3,)] + assert await cur.fetchone() == (4,) + assert await cur.fetchmany(3) == 
[(5,)] + assert await cur.fetchmany(3) == [] + + +async def test_fetchall(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchall() == [(1,), (2,), (3,)] + assert await cur.fetchall() == [] + + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchone() == (1,) + assert await cur.fetchall() == [(2,), (3,)] + assert await cur.fetchall() == [] + + +async def test_nextset(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert not cur.nextset() + + +async def test_no_result(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar where false", (3,)) + assert len(cur.description) == 1 + assert (await cur.fetchall()) == [] + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_standard_row_factory(aconn, row_factory): + if row_factory == "tuple_row": + getter = lambda r: r[0] # noqa: E731 + elif row_factory == "dict_row": + getter = lambda r: r["bar"] # noqa: E731 + elif row_factory == "namedtuple_row": + getter = lambda r: r.bar # noqa: E731 + else: + assert False, row_factory + + row_factory = getattr(rows, row_factory) + async with aconn.cursor("foo", row_factory=row_factory) as cur: + await cur.execute("select generate_series(1, 5) as bar") + assert getter(await cur.fetchone()) == 1 + assert list(map(getter, await cur.fetchmany(2))) == [2, 3] + assert list(map(getter, await cur.fetchall())) == [4, 5] + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_row_factory(aconn): + n = 0 + + def my_row_factory(cur): + nonlocal n + n += 1 + return lambda values: [n] + [-v for v in values] + + cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True) + await cur.execute("select generate_series(1, 3) as x") + recs = await cur.fetchall() + await cur.scroll(0, "absolute") + while True: + rec = await cur.fetchone() + if not rec: + break + recs.append(rec) + assert recs == [[1, -1], [1, -2], [1, -3]] * 2 + + await cur.scroll(0, "absolute") + cur.row_factory = rows.dict_row + assert await cur.fetchone() == {"x": 1} + await cur.close() + + +async def test_rownumber(aconn): + cur = aconn.cursor("foo") + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + await cur.fetchall() + assert cur.rownumber == 42 + await cur.close() + + +async def test_iter(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + recs = [] + async for rec in cur: + recs.append(rec) + assert recs == [(1,), (2,), (3,)] + + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchone() == (1,) + recs = [] + async for rec in cur: + recs.append(rec) + assert recs == [(2,), (3,)] + + +async def test_iter_rownumber(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + async for row in cur: + assert cur.rownumber == row[0] + + +async def test_itersize(aconn, acommands): + async with aconn.cursor("foo") as cur: + assert cur.itersize == 100 + cur.itersize = 2 + await 
cur.execute("select generate_series(1, %s) as bar", (3,)) + acommands.popall() # flush begin and other noise + + async for rec in cur: + pass + cmds = acommands.popall() + assert len(cmds) == 2 + for cmd in cmds: + assert "fetch forward 2" in cmd.lower() + + +async def test_cant_scroll_by_default(aconn): + cur = aconn.cursor("tmp") + assert cur.scrollable is None + with pytest.raises(e.ProgrammingError): + await cur.scroll(0) + await cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_scroll(aconn): + cur = aconn.cursor("tmp", scrollable=True) + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + + with pytest.raises(ValueError): + await cur.scroll(9, mode="wat") + await cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_scrollable(aconn): + curs = aconn.cursor("foo", scrollable=True) + assert curs.scrollable is True + await curs.execute("select generate_series(0, 5)") + await curs.scroll(5) + for i in range(4, -1, -1): + await curs.scroll(-1) + assert i == (await curs.fetchone())[0] + await curs.scroll(-1) + await curs.close() + + +async def test_non_scrollable(aconn): + curs = aconn.cursor("foo", scrollable=False) + assert curs.scrollable is False + await curs.execute("select generate_series(0, 5)") + await curs.scroll(5) + with pytest.raises(e.OperationalError): + await curs.scroll(-1) + await curs.close() + + +@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}]) +async def test_no_hold(aconn, kwargs): + async with aconn.cursor("foo", **kwargs) as curs: + assert curs.withhold is False + await curs.execute("select generate_series(0, 2)") + assert await curs.fetchone() == (0,) + await aconn.commit() + with pytest.raises(e.InvalidCursorName): + await curs.fetchone() + + +@pytest.mark.crdb_skip("cursor with hold") +async def test_hold(aconn): + async with aconn.cursor("foo", withhold=True) as curs: + assert curs.withhold is True + await curs.execute("select generate_series(0, 5)") + assert await curs.fetchone() == (0,) + await aconn.commit() + assert await curs.fetchone() == (1,) + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"]) +async def test_steal_cursor(aconn, row_factory): + cur1 = aconn.cursor() + await cur1.execute( + "declare test cursor without hold for select generate_series(1, 6) as s" + ) + + cur2 = aconn.cursor("test", row_factory=getattr(rows, row_factory)) + # can call fetch without execute + rec = await cur2.fetchone() + assert rec == (1,) + if row_factory == "namedtuple_row": + assert rec.s == 1 + assert await cur2.fetchmany(3) == [(2,), (3,), (4,)] + assert await cur2.fetchall() == [(5,), (6,)] + await cur2.close() + + +async def test_stolen_cursor_close(aconn): + cur1 = aconn.cursor() + await cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = aconn.cursor("test") + await cur2.close() + + await cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = aconn.cursor("test") + await cur2.close() diff --git a/tests/test_sql.py b/tests/test_sql.py new file mode 100644 index 0000000..42b6c63 --- /dev/null +++ b/tests/test_sql.py @@ -0,0 +1,604 @@ +# test_sql.py - tests for the psycopg2.sql module + +# Copyright (C) 2020 The Psycopg Team + +import re +import datetime as dt + +import 
pytest + +from psycopg import pq, sql, ProgrammingError +from psycopg.adapt import PyFormat +from psycopg._encodings import py2pgenc +from psycopg.types import TypeInfo +from psycopg.types.string import StrDumper + +from .utils import eur +from .fix_crdb import crdb_encoding, crdb_scs_off + + +@pytest.mark.parametrize( + "obj, quoted", + [ + ("foo\\bar", " E'foo\\\\bar'"), + ("hello", "'hello'"), + (42, "42"), + (True, "true"), + (None, "NULL"), + ], +) +def test_quote(obj, quoted): + assert sql.quote(obj) == quoted + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_quote_roundtrip(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + + for i in range(1, 256): + want = chr(i) + quoted = sql.quote(want) + got = conn.execute(f"select {quoted}::text").fetchone()[0] + assert want == got + + # No "nonstandard use of \\ in a string literal" warning + assert not messages, f"error with {want!r}" + + +@pytest.mark.parametrize("dummy", [crdb_scs_off("off")]) +def test_quote_stable_despite_deranged_libpq(conn, dummy): + # Verify the libpq behaviour of PQescapeString using the last setting seen. + # Check that we are not affected by it. + good_str = " E'\\\\'" + good_bytes = " E'\\\\000'::bytea" + conn.execute("set standard_conforming_strings to on") + assert pq.Escaping().escape_string(b"\\") == b"\\" + assert sql.quote("\\") == good_str + assert pq.Escaping().escape_bytea(b"\x00") == b"\\000" + assert sql.quote(b"\x00") == good_bytes + + conn.execute("set standard_conforming_strings to off") + assert pq.Escaping().escape_string(b"\\") == b"\\\\" + assert sql.quote("\\") == good_str + assert pq.Escaping().escape_bytea(b"\x00") == b"\\\\000" + assert sql.quote(b"\x00") == good_bytes + + # Verify that the good values are actually good + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute("set escape_string_warning to on") + for scs in ("on", "off"): + conn.execute(f"set standard_conforming_strings to {scs}") + cur = conn.execute(f"select {good_str}, {good_bytes}::bytea") + assert cur.fetchone() == ("\\", b"\x00") + + # No "nonstandard use of \\ in a string literal" warning + assert not messages + + +class TestSqlFormat: + def test_pos(self, conn): + s = sql.SQL("select {} from {}").format( + sql.Identifier("field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_pos_spec(self, conn): + s = sql.SQL("select {0} from {1}").format( + sql.Identifier("field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + s = sql.SQL("select {1} from {0}").format( + sql.Identifier("table"), sql.Identifier("field") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_dict(self, conn): + s = sql.SQL("select {f} from {t}").format( + f=sql.Identifier("field"), t=sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_compose_literal(self, conn): + s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31))) + s1 = s.as_string(conn) + assert s1 == "select '2016-12-31'::date;" + + def test_compose_empty(self, conn): + s = sql.SQL("select foo;").format() + s1 = s.as_string(conn) + assert s1 == "select foo;" + + def 
test_percent_escape(self, conn): + s = sql.SQL("42 % {0}").format(sql.Literal(7)) + s1 = s.as_string(conn) + assert s1 == "42 % 7" + + def test_braces_escape(self, conn): + s = sql.SQL("{{{0}}}").format(sql.Literal(7)) + assert s.as_string(conn) == "{7}" + s = sql.SQL("{{1,{0}}}").format(sql.Literal(7)) + assert s.as_string(conn) == "{1,7}" + + def test_compose_badnargs(self): + with pytest.raises(IndexError): + sql.SQL("select {0};").format() + + def test_compose_badnargs_auto(self): + with pytest.raises(IndexError): + sql.SQL("select {};").format() + with pytest.raises(ValueError): + sql.SQL("select {} {1};").format(10, 20) + with pytest.raises(ValueError): + sql.SQL("select {0} {};").format(10, 20) + + def test_compose_bad_args_type(self): + with pytest.raises(IndexError): + sql.SQL("select {0};").format(a=10) + with pytest.raises(KeyError): + sql.SQL("select {x};").format(10) + + def test_no_modifiers(self): + with pytest.raises(ValueError): + sql.SQL("select {a!r};").format(a=10) + with pytest.raises(ValueError): + sql.SQL("select {a:<};").format(a=10) + + def test_must_be_adaptable(self, conn): + class Foo: + pass + + s = sql.SQL("select {0};").format(sql.Literal(Foo())) + with pytest.raises(ProgrammingError): + s.as_string(conn) + + def test_auto_literal(self, conn): + s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1)) + assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'::date" + + def test_execute(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + cur.execute( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier("test_compose"), + sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])), + (sql.Placeholder() * 3).join(", "), + ), + (10, "a", "b", "c"), + ) + + cur.execute("select * from test_compose") + assert cur.fetchall() == [(10, "a", "b", "c")] + + def test_executemany(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + cur.executemany( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier("test_compose"), + sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])), + (sql.Placeholder() * 3).join(", "), + ), + [(10, "a", "b", "c"), (20, "d", "e", "f")], + ) + + cur.execute("select * from test_compose") + assert cur.fetchall() == [(10, "a", "b", "c"), (20, "d", "e", "f")] + + @pytest.mark.crdb_skip("copy") + def test_copy(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + + with cur.copy( + sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") + ), + ) as copy: + copy.write_row((10, "a", "b", "c")) + copy.write_row((20, "d", "e", "f")) + + with cur.copy( + sql.SQL("copy (select {f} from {t} order by id) to stdout").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") + ) + ) as copy: + assert list(copy) == [b"c\n", b"f\n"] + + +class TestIdentifier: + def test_class(self): + assert issubclass(sql.Identifier, sql.Composable) + + def test_init(self): + assert isinstance(sql.Identifier("foo"), sql.Identifier) + assert isinstance(sql.Identifier("foo"), sql.Identifier) + assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier) + with pytest.raises(TypeError): + 
sql.Identifier() + with pytest.raises(TypeError): + sql.Identifier(10) # type: ignore[arg-type] + with pytest.raises(TypeError): + sql.Identifier(dt.date(2016, 12, 31)) # type: ignore[arg-type] + + def test_repr(self): + obj = sql.Identifier("fo'o") + assert repr(obj) == 'Identifier("fo\'o")' + assert repr(obj) == str(obj) + + obj = sql.Identifier("fo'o", 'ba"r') + assert repr(obj) == "Identifier(\"fo'o\", 'ba\"r')" + assert repr(obj) == str(obj) + + def test_eq(self): + assert sql.Identifier("foo") == sql.Identifier("foo") + assert sql.Identifier("foo", "bar") == sql.Identifier("foo", "bar") + assert sql.Identifier("foo") != sql.Identifier("bar") + assert sql.Identifier("foo") != "foo" + assert sql.Identifier("foo") != sql.SQL("foo") + + @pytest.mark.parametrize( + "args, want", + [ + (("foo",), '"foo"'), + (("foo", "bar"), '"foo"."bar"'), + (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'), + ], + ) + def test_as_string(self, conn, args, want): + assert sql.Identifier(*args).as_string(conn) == want + + @pytest.mark.parametrize( + "args, want, enc", + [ + crdb_encoding(("foo",), '"foo"', "ascii"), + crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"), + crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"), + (("foo", eur), f'"foo"."{eur}"', "utf8"), + crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"), + ], + ) + def test_as_bytes(self, conn, args, want, enc): + want = want.encode(enc) + conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}") + assert sql.Identifier(*args).as_bytes(conn) == want + + def test_join(self): + assert not hasattr(sql.Identifier("foo"), "join") + + +class TestLiteral: + def test_class(self): + assert issubclass(sql.Literal, sql.Composable) + + def test_init(self): + assert isinstance(sql.Literal("foo"), sql.Literal) + assert isinstance(sql.Literal("foo"), sql.Literal) + assert isinstance(sql.Literal(b"foo"), sql.Literal) + assert isinstance(sql.Literal(42), sql.Literal) + assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal) + + def test_repr(self): + assert repr(sql.Literal("foo")) == "Literal('foo')" + assert str(sql.Literal("foo")) == "Literal('foo')" + + def test_as_string(self, conn): + assert sql.Literal(None).as_string(conn) == "NULL" + assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'" + assert sql.Literal(42).as_string(conn) == "42" + assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date" + + def test_as_bytes(self, conn): + assert sql.Literal(None).as_bytes(conn) == b"NULL" + assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'" + assert sql.Literal(42).as_bytes(conn) == b"42" + assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes_encoding(self, conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode(encoding) + + def test_eq(self): + assert sql.Literal("foo") == sql.Literal("foo") + assert sql.Literal("foo") != sql.Literal("bar") + assert sql.Literal("foo") != "foo" + assert sql.Literal("foo") != sql.SQL("foo") + + def test_must_be_adaptable(self, conn): + class Foo: + pass + + with pytest.raises(ProgrammingError): + sql.Literal(Foo()).as_string(conn) + + def test_array(self, conn): + assert ( + sql.Literal([dt.date(2000, 1, 1)]).as_string(conn) + == "'{2000-01-01}'::date[]" + ) + + def test_short_name_builtin(self, conn): + assert sql.Literal(dt.time(0, 0)).as_string(conn) == "'00:00:00'::time" 
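+        # datetime values get the short type name too: ::timestamp rather than
+        # the verbose "timestamp without time zone", and the same for arrays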
+ assert ( + sql.Literal(dt.datetime(2000, 1, 1)).as_string(conn) + == "'2000-01-01 00:00:00'::timestamp" + ) + assert ( + sql.Literal([dt.datetime(2000, 1, 1)]).as_string(conn) + == "'{\"2000-01-01 00:00:00\"}'::timestamp[]" + ) + + def test_text_literal(self, conn): + conn.adapters.register_dumper(str, StrDumper) + assert sql.Literal("foo").as_string(conn) == "'foo'" + + @pytest.mark.crdb_skip("composite") # create type, actually + @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"]) + def test_invalid_name(self, conn, name): + conn.execute( + f""" + set client_encoding to utf8; + create type "{name}"; + create function invin(cstring) returns "{name}" + language internal immutable strict as 'textin'; + create function invout("{name}") returns cstring + language internal immutable strict as 'textout'; + create type "{name}" (input=invin, output=invout, like=text); + """ + ) + info = TypeInfo.fetch(conn, f'"{name}"') + + class InvDumper(StrDumper): + oid = info.oid + + def dump(self, obj): + rv = super().dump(obj) + return b"%s-inv" % rv + + info.register(conn) + conn.adapters.register_dumper(str, InvDumper) + + assert sql.Literal("hello").as_string(conn) == f"'hello-inv'::\"{name}\"" + cur = conn.execute(sql.SQL("select {}").format("hello")) + assert cur.fetchone()[0] == "hello-inv" + + assert ( + sql.Literal(["hello"]).as_string(conn) == f"'{{hello-inv}}'::\"{name}\"[]" + ) + cur = conn.execute(sql.SQL("select {}").format(["hello"])) + assert cur.fetchone()[0] == ["hello-inv"] + + +class TestSQL: + def test_class(self): + assert issubclass(sql.SQL, sql.Composable) + + def test_init(self): + assert isinstance(sql.SQL("foo"), sql.SQL) + assert isinstance(sql.SQL("foo"), sql.SQL) + with pytest.raises(TypeError): + sql.SQL(10) # type: ignore[arg-type] + with pytest.raises(TypeError): + sql.SQL(dt.date(2016, 12, 31)) # type: ignore[arg-type] + + def test_repr(self, conn): + assert repr(sql.SQL("foo")) == "SQL('foo')" + assert str(sql.SQL("foo")) == "SQL('foo')" + assert sql.SQL("foo").as_string(conn) == "foo" + + def test_eq(self): + assert sql.SQL("foo") == sql.SQL("foo") + assert sql.SQL("foo") != sql.SQL("bar") + assert sql.SQL("foo") != "foo" + assert sql.SQL("foo") != sql.Literal("foo") + + def test_sum(self, conn): + obj = sql.SQL("foo") + sql.SQL("bar") + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foobar" + + def test_sum_inplace(self, conn): + obj = sql.SQL("f") + sql.SQL("oo") + obj += sql.SQL("bar") + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foobar" + + def test_multiply(self, conn): + obj = sql.SQL("foo") * 3 + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foofoofoo" + + def test_join(self, conn): + obj = sql.SQL(", ").join( + [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)] + ) + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == '"foo", bar, 42' + + obj = sql.SQL(", ").join( + sql.Composed([sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]) + ) + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == '"foo", bar, 42' + + obj = sql.SQL(", ").join([]) + assert obj == sql.Composed([]) + + def test_as_string(self, conn): + assert sql.SQL("foo").as_string(conn) == "foo" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes(self, conn, encoding): + if encoding: + conn.execute(f"set client_encoding to {encoding}") + + assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding) + + +class TestComposed: + 
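+    # Composed wraps a sequence of Composable parts; the tests below cover
+    # equality, joining, concatenation and iteration over the parts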
def test_class(self): + assert issubclass(sql.Composed, sql.Composable) + + def test_repr(self): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + assert repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])""" + assert str(obj) == repr(obj) + + def test_eq(self): + L = [sql.Literal("foo"), sql.Identifier("b'ar")] + l2 = [sql.Literal("foo"), sql.Literal("b'ar")] + assert sql.Composed(L) == sql.Composed(list(L)) + assert sql.Composed(L) != L + assert sql.Composed(L) != sql.Composed(l2) + + def test_join(self, conn): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + obj = obj.join(", ") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "'foo', \"b'ar\"" + + def test_auto_literal(self, conn): + obj = sql.Composed(["fo'o", dt.date(2020, 1, 1)]) + obj = obj.join(", ") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'::date" + + def test_sum(self, conn): + obj = sql.Composed([sql.SQL("foo ")]) + obj = obj + sql.Literal("bar") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + def test_sum_inplace(self, conn): + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Literal("bar") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Composed([sql.Literal("bar")]) + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + def test_iter(self): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + it = iter(obj) + i = next(it) + assert i == sql.SQL("foo") + i = next(it) + assert i == sql.SQL("bar") + with pytest.raises(StopIteration): + next(it) + + def test_as_string(self, conn): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + assert obj.as_string(conn) == "foobar" + + def test_as_bytes(self, conn): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + assert obj.as_bytes(conn) == b"foobar" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes_encoding(self, conn, encoding): + obj = sql.Composed([sql.SQL("foo"), sql.SQL(eur)]) + conn.execute(f"set client_encoding to {encoding}") + assert obj.as_bytes(conn) == ("foo" + eur).encode(encoding) + + +class TestPlaceholder: + def test_class(self): + assert issubclass(sql.Placeholder, sql.Composable) + + @pytest.mark.parametrize("format", PyFormat) + def test_repr_format(self, conn, format): + ph = sql.Placeholder(format=format) + add = f"format={format.name}" if format != PyFormat.AUTO else "" + assert str(ph) == repr(ph) == f"Placeholder({add})" + + @pytest.mark.parametrize("format", PyFormat) + def test_repr_name_format(self, conn, format): + ph = sql.Placeholder("foo", format=format) + add = f", format={format.name}" if format != PyFormat.AUTO else "" + assert str(ph) == repr(ph) == f"Placeholder('foo'{add})" + + def test_bad_name(self): + with pytest.raises(ValueError): + sql.Placeholder(")") + + def test_eq(self): + assert sql.Placeholder("foo") == sql.Placeholder("foo") + assert sql.Placeholder("foo") != sql.Placeholder("bar") + assert sql.Placeholder("foo") != "foo" + assert sql.Placeholder() == sql.Placeholder() + assert sql.Placeholder("foo") != sql.Placeholder() + assert sql.Placeholder("foo") != sql.Literal("foo") + + @pytest.mark.parametrize("format", PyFormat) + def test_as_string(self, conn, format): + ph = sql.Placeholder(format=format) + assert ph.as_string(conn) == f"%{format.value}" + + ph = 
sql.Placeholder(name="foo", format=format) + assert ph.as_string(conn) == f"%(foo){format.value}" + + @pytest.mark.parametrize("format", PyFormat) + def test_as_bytes(self, conn, format): + ph = sql.Placeholder(format=format) + assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii") + + ph = sql.Placeholder(name="foo", format=format) + assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii") + + +class TestValues: + def test_null(self, conn): + assert isinstance(sql.NULL, sql.SQL) + assert sql.NULL.as_string(conn) == "NULL" + + def test_default(self, conn): + assert isinstance(sql.DEFAULT, sql.SQL) + assert sql.DEFAULT.as_string(conn) == "DEFAULT" + + +def no_e(s): + """Drop an eventual E from E'' quotes""" + if isinstance(s, memoryview): + s = bytes(s) + + if isinstance(s, str): + return re.sub(r"\bE'", "'", s) + elif isinstance(s, bytes): + return re.sub(rb"\bE'", b"'", s) + else: + raise TypeError(f"not dealing with {type(s).__name__}: {s}") diff --git a/tests/test_tpc.py b/tests/test_tpc.py new file mode 100644 index 0000000..91a04e0 --- /dev/null +++ b/tests/test_tpc.py @@ -0,0 +1,325 @@ +import pytest + +import psycopg +from psycopg.pq import TransactionStatus + +pytestmark = pytest.mark.crdb_skip("2-phase commit") + + +def test_tpc_disabled(conn, pipeline): + val = int(conn.execute("show max_prepared_transactions").fetchone()[0]) + if val: + pytest.skip("prepared transactions enabled") + + conn.rollback() + conn.tpc_begin("x") + with pytest.raises(psycopg.NotSupportedError): + conn.tpc_prepare() + + +class TestTPC: + def test_tpc_commit(self, conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + conn.tpc_commit() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + def test_tpc_commit_one_phase(self, conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_commit() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + def test_tpc_commit_recovered(self, conn_cls, conn, dsn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_rec')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + conn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + with conn_cls.connect(dsn) as conn: + xid = conn.xid(1, "gtrid", "bqual") + conn.tpc_commit(xid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + assert 
tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + def test_tpc_rollback(self, conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_rollback')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + conn.tpc_rollback() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + def test_tpc_rollback_one_phase(self, conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_rollback() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + def test_tpc_rollback_recovered(self, conn_cls, conn, dsn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_rec')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + conn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + with conn_cls.connect(dsn) as conn: + xid = conn.xid(1, "gtrid", "bqual") + conn.tpc_rollback(xid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + def test_status_after_recover(self, conn, tpc): + assert conn.info.transaction_status == TransactionStatus.IDLE + conn.tpc_recover() + assert conn.info.transaction_status == TransactionStatus.IDLE + + cur = conn.cursor() + cur.execute("select 1") + assert conn.info.transaction_status == TransactionStatus.INTRANS + conn.tpc_recover() + assert conn.info.transaction_status == TransactionStatus.INTRANS + + def test_recovered_xids(self, conn, tpc): + # insert a few test xns + conn.autocommit = True + cur = conn.cursor() + cur.execute("begin; prepare transaction '1-foo'") + cur.execute("begin; prepare transaction '2-bar'") + + # read the values to return + cur.execute( + """ + select gid, prepared, owner, database from pg_prepared_xacts + where database = %s + """, + (conn.info.dbname,), + ) + okvals = cur.fetchall() + okvals.sort() + + xids = conn.tpc_recover() + xids = [xid for xid in xids if xid.database == conn.info.dbname] + xids.sort(key=lambda x: x.gtrid) + + # check the values returned + assert len(okvals) == len(xids) + for (xid, (gid, prepared, owner, database)) in zip(xids, okvals): + assert xid.gtrid == gid + assert xid.prepared == prepared + assert xid.owner == owner + assert xid.database == database + + def test_xid_encoding(self, conn, tpc): + xid = conn.xid(42, "gtrid", "bqual") + conn.tpc_begin(xid) + conn.tpc_prepare() + + cur = conn.cursor() + cur.execute( + "select gid from 
pg_prepared_xacts where database = %s", + (conn.info.dbname,), + ) + assert "42_Z3RyaWQ=_YnF1YWw=" == cur.fetchone()[0] + + @pytest.mark.parametrize( + "fid, gtrid, bqual", + [ + (0, "", ""), + (42, "gtrid", "bqual"), + (0x7FFFFFFF, "x" * 64, "y" * 64), + ], + ) + def test_xid_roundtrip(self, conn_cls, conn, dsn, tpc, fid, gtrid, bqual): + xid = conn.xid(fid, gtrid, bqual) + conn.tpc_begin(xid) + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname] + + assert len(xids) == 1 + xid = xids[0] + conn.tpc_rollback(xid) + + assert xid.format_id == fid + assert xid.gtrid == gtrid + assert xid.bqual == bqual + + @pytest.mark.parametrize( + "tid", + [ + "", + "hello, world!", + "x" * 199, # PostgreSQL's limit in transaction id length + ], + ) + def test_unparsed_roundtrip(self, conn_cls, conn, dsn, tpc, tid): + conn.tpc_begin(tid) + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname] + + assert len(xids) == 1 + xid = xids[0] + conn.tpc_rollback(xid) + + assert xid.format_id is None + assert xid.gtrid == tid + assert xid.bqual is None + + def test_xid_unicode(self, conn_cls, conn, dsn, tpc): + x1 = conn.xid(10, "uni", "code") + conn.tpc_begin(x1) + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0] + assert 10 == xid.format_id + assert "uni" == xid.gtrid + assert "code" == xid.bqual + + def test_xid_unicode_unparsed(self, conn_cls, conn, dsn, tpc): + # We don't expect people shooting snowmen as transaction ids, + # so if something explodes in an encode error I don't mind. + # Let's just check unicode is accepted as type. 
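The gid asserted just above, "42_Z3RyaWQ=_YnF1YWw=", and the Xid.from_string() cases further down follow one convention: the format id plus the base64 of gtrid and bqual, joined by underscores. A minimal sketch of that convention, for illustration only (the helper name is invented; psycopg has its own encoder):

import base64

def xid_to_gid(format_id, gtrid, bqual):
    # "<format_id>_<base64(gtrid)>_<base64(bqual)>", the shape the tests assert on
    return "%d_%s_%s" % (
        format_id,
        base64.b64encode(gtrid.encode()).decode(),
        base64.b64encode(bqual.encode()).decode(),
    )

assert xid_to_gid(42, "gtrid", "bqual") == "42_Z3RyaWQ=_YnF1YWw="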
+ conn.execute("set client_encoding to utf8") + conn.commit() + + conn.tpc_begin("transaction-id") + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0] + + assert xid.format_id is None + assert xid.gtrid == "transaction-id" + assert xid.bqual is None + + def test_cancel_fails_prepared(self, conn, tpc): + conn.tpc_begin("cancel") + conn.tpc_prepare() + with pytest.raises(psycopg.ProgrammingError): + conn.cancel() + + def test_tpc_recover_non_dbapi_connection(self, conn_cls, conn, dsn, tpc): + conn.row_factory = psycopg.rows.dict_row + conn.tpc_begin("dict-connection") + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xids = conn.tpc_recover() + xid = [x for x in xids if x.database == conn.info.dbname][0] + + assert xid.format_id is None + assert xid.gtrid == "dict-connection" + assert xid.bqual is None + + +class TestXidObject: + def test_xid_construction(self): + x1 = psycopg.Xid(74, "foo", "bar") + 74 == x1.format_id + "foo" == x1.gtrid + "bar" == x1.bqual + + def test_xid_from_string(self): + x2 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=") + 42 == x2.format_id + "gtrid" == x2.gtrid + "bqual" == x2.bqual + + x3 = psycopg.Xid.from_string("99_xxx_yyy") + None is x3.format_id + "99_xxx_yyy" == x3.gtrid + None is x3.bqual + + def test_xid_to_string(self): + x1 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=") + str(x1) == "42_Z3RyaWQ=_YnF1YWw=" + + x2 = psycopg.Xid.from_string("99_xxx_yyy") + str(x2) == "99_xxx_yyy" diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py new file mode 100644 index 0000000..a409a2e --- /dev/null +++ b/tests/test_tpc_async.py @@ -0,0 +1,310 @@ +import pytest + +import psycopg +from psycopg.pq import TransactionStatus + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb_skip("2-phase commit"), +] + + +async def test_tpc_disabled(aconn, apipeline): + cur = await aconn.execute("show max_prepared_transactions") + val = int((await cur.fetchone())[0]) + if val: + pytest.skip("prepared transactions enabled") + + await aconn.rollback() + await aconn.tpc_begin("x") + with pytest.raises(psycopg.NotSupportedError): + await aconn.tpc_prepare() + + +class TestTPC: + async def test_tpc_commit(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + await aconn.tpc_commit() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_commit_one_phase(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_commit() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert 
tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_commit_recovered(self, aconn_cls, aconn, dsn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + await aconn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + async with await aconn_cls.connect(dsn) as aconn: + xid = aconn.xid(1, "gtrid", "bqual") + await aconn.tpc_commit(xid) + assert aconn.info.transaction_status == TransactionStatus.IDLE + + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_rollback(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_rollback')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + await aconn.tpc_rollback() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_tpc_rollback_one_phase(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_rollback() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_tpc_rollback_recovered(self, aconn_cls, aconn, dsn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + await aconn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + async with await aconn_cls.connect(dsn) as aconn: + xid = aconn.xid(1, "gtrid", "bqual") + await aconn.tpc_rollback(xid) + assert aconn.info.transaction_status == TransactionStatus.IDLE + + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_status_after_recover(self, aconn, tpc): + assert aconn.info.transaction_status == TransactionStatus.IDLE + await aconn.tpc_recover() + assert aconn.info.transaction_status == TransactionStatus.IDLE + + cur = aconn.cursor() + await cur.execute("select 1") + assert aconn.info.transaction_status == TransactionStatus.INTRANS + await aconn.tpc_recover() + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + async def 
test_recovered_xids(self, aconn, tpc): + # insert a few test xns + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("begin; prepare transaction '1-foo'") + await cur.execute("begin; prepare transaction '2-bar'") + + # read the values to return + await cur.execute( + """ + select gid, prepared, owner, database from pg_prepared_xacts + where database = %s + """, + (aconn.info.dbname,), + ) + okvals = await cur.fetchall() + okvals.sort() + + xids = await aconn.tpc_recover() + xids = [xid for xid in xids if xid.database == aconn.info.dbname] + xids.sort(key=lambda x: x.gtrid) + + # check the values returned + assert len(okvals) == len(xids) + for (xid, (gid, prepared, owner, database)) in zip(xids, okvals): + assert xid.gtrid == gid + assert xid.prepared == prepared + assert xid.owner == owner + assert xid.database == database + + async def test_xid_encoding(self, aconn, tpc): + xid = aconn.xid(42, "gtrid", "bqual") + await aconn.tpc_begin(xid) + await aconn.tpc_prepare() + + cur = aconn.cursor() + await cur.execute( + "select gid from pg_prepared_xacts where database = %s", + (aconn.info.dbname,), + ) + assert "42_Z3RyaWQ=_YnF1YWw=" == (await cur.fetchone())[0] + + @pytest.mark.parametrize( + "fid, gtrid, bqual", + [ + (0, "", ""), + (42, "gtrid", "bqual"), + (0x7FFFFFFF, "x" * 64, "y" * 64), + ], + ) + async def test_xid_roundtrip(self, aconn_cls, aconn, dsn, tpc, fid, gtrid, bqual): + xid = aconn.xid(fid, gtrid, bqual) + await aconn.tpc_begin(xid) + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xids = [ + x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname + ] + assert len(xids) == 1 + xid = xids[0] + await aconn.tpc_rollback(xid) + + assert xid.format_id == fid + assert xid.gtrid == gtrid + assert xid.bqual == bqual + + @pytest.mark.parametrize( + "tid", + [ + "", + "hello, world!", + "x" * 199, # PostgreSQL's limit in transaction id length + ], + ) + async def test_unparsed_roundtrip(self, aconn_cls, aconn, dsn, tpc, tid): + await aconn.tpc_begin(tid) + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xids = [ + x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname + ] + assert len(xids) == 1 + xid = xids[0] + await aconn.tpc_rollback(xid) + + assert xid.format_id is None + assert xid.gtrid == tid + assert xid.bqual is None + + async def test_xid_unicode(self, aconn_cls, aconn, dsn, tpc): + x1 = aconn.xid(10, "uni", "code") + await aconn.tpc_begin(x1) + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xid = [ + x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname + ][0] + + assert 10 == xid.format_id + assert "uni" == xid.gtrid + assert "code" == xid.bqual + + async def test_xid_unicode_unparsed(self, aconn_cls, aconn, dsn, tpc): + # We don't expect people shooting snowmen as transaction ids, + # so if something explodes in an encode error I don't mind. + # Let's just check unicode is accepted as type. 
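The round-trip tests above all exercise the same administrative pattern: reconnect, list in-doubt transactions with tpc_recover(), and resolve the ones that belong to the current database. A hedged sketch of that pattern outside the test harness (the DSN and the commit-or-rollback policy are placeholders):

import psycopg

async def resolve_in_doubt(dsn, commit=False):
    async with await psycopg.AsyncConnection.connect(dsn) as conn:
        for xid in await conn.tpc_recover():
            if xid.database != conn.info.dbname:
                continue  # xids prepared in other databases are not ours to finish
            if commit:
                await conn.tpc_commit(xid)
            else:
                await conn.tpc_rollback(xid)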
+ await aconn.execute("set client_encoding to utf8") + await aconn.commit() + + await aconn.tpc_begin("transaction-id") + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xid = [ + x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname + ][0] + + assert xid.format_id is None + assert xid.gtrid == "transaction-id" + assert xid.bqual is None + + async def test_cancel_fails_prepared(self, aconn, tpc): + await aconn.tpc_begin("cancel") + await aconn.tpc_prepare() + with pytest.raises(psycopg.ProgrammingError): + aconn.cancel() + + async def test_tpc_recover_non_dbapi_connection(self, aconn_cls, aconn, dsn, tpc): + aconn.row_factory = psycopg.rows.dict_row + await aconn.tpc_begin("dict-connection") + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xids = await aconn.tpc_recover() + xid = [x for x in xids if x.database == aconn.info.dbname][0] + + assert xid.format_id is None + assert xid.gtrid == "dict-connection" + assert xid.bqual is None diff --git a/tests/test_transaction.py b/tests/test_transaction.py new file mode 100644 index 0000000..9391e00 --- /dev/null +++ b/tests/test_transaction.py @@ -0,0 +1,796 @@ +import sys +import logging +from threading import Thread, Event + +import pytest + +import psycopg +from psycopg import Rollback +from psycopg import errors as e + +# TODOCRDB: is this the expected behaviour? +crdb_skip_external_observer = pytest.mark.crdb( + "skip", reason="deadlock on observer connection" +) + + +@pytest.fixture +def conn(conn, pipeline): + return conn + + +@pytest.fixture(autouse=True) +def create_test_table(svcconn): + """Creates a table called 'test_table' for use in tests.""" + cur = svcconn.cursor() + cur.execute("drop table if exists test_table") + cur.execute("create table test_table (id text primary key)") + yield + cur.execute("drop table test_table") + + +def insert_row(conn, value): + sql = "INSERT INTO test_table VALUES (%s)" + if isinstance(conn, psycopg.Connection): + conn.cursor().execute(sql, (value,)) + else: + + async def f(): + cur = conn.cursor() + await cur.execute(sql, (value,)) + + return f() + + +def inserted(conn): + """Return the values inserted in the test table.""" + sql = "SELECT * FROM test_table" + if isinstance(conn, psycopg.Connection): + rows = conn.cursor().execute(sql).fetchall() + return set(v for (v,) in rows) + else: + + async def f(): + cur = conn.cursor() + await cur.execute(sql) + rows = await cur.fetchall() + return set(v for (v,) in rows) + + return f() + + +def in_transaction(conn): + if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE: + return False + elif conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS: + return True + else: + assert False, conn.pgconn.transaction_status + + +def get_exc_info(exc): + """Return the exc info for an exception or a success if exc is None""" + if not exc: + return (None,) * 3 + try: + raise exc + except exc: + return sys.exc_info() + + +class ExpectedException(Exception): + pass + + +def test_basic(conn, pipeline): + """Basic use of transaction() to BEGIN and COMMIT a transaction.""" + assert not in_transaction(conn) + with conn.transaction(): + if pipeline: + pipeline.sync() + assert in_transaction(conn) + assert not in_transaction(conn) + + +def test_exposes_associated_connection(conn): + """Transaction exposes its connection as a read-only property.""" + with conn.transaction() as tx: + assert tx.connection is conn + with 
pytest.raises(AttributeError): + tx.connection = conn + + +def test_exposes_savepoint_name(conn): + """Transaction exposes its savepoint name as a read-only property.""" + with conn.transaction(savepoint_name="foo") as tx: + assert tx.savepoint_name == "foo" + with pytest.raises(AttributeError): + tx.savepoint_name = "bar" + + +def test_cant_reenter(conn): + with conn.transaction() as tx: + pass + + with pytest.raises(TypeError): + with tx: + pass + + +def test_begins_on_enter(conn, pipeline): + """Transaction does not begin until __enter__() is called.""" + tx = conn.transaction() + assert not in_transaction(conn) + with tx: + if pipeline: + pipeline.sync() + assert in_transaction(conn) + assert not in_transaction(conn) + + +def test_commit_on_successful_exit(conn): + """Changes are committed on successful exit from the `with` block.""" + with conn.transaction(): + insert_row(conn, "foo") + + assert not in_transaction(conn) + assert inserted(conn) == {"foo"} + + +def test_rollback_on_exception_exit(conn): + """Changes are rolled back if an exception escapes the `with` block.""" + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "foo") + raise ExpectedException("This discards the insert") + + assert not in_transaction(conn) + assert not inserted(conn) + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_context_inerror_rollback_no_clobber(conn_cls, conn, pipeline, dsn, caplog): + if pipeline: + # Only 'conn' is possibly in pipeline mode, but the transaction and + # checks are on 'conn2'. + pytest.skip("not applicable") + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + with conn_cls.connect(dsn) as conn2: + with conn2.transaction(): + conn2.execute("select 1") + conn.execute( + "select pg_terminate_backend(%s::int)", + [conn2.pgconn.backend_pid], + ) + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.crdb_skip("copy") +def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + conn = conn_cls.connect(dsn) + try: + with pytest.raises(ZeroDivisionError): + with conn.transaction(): + conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + finally: + conn.close() + + +def test_interaction_dbapi_transaction(conn): + insert_row(conn, "foo") + + with conn.transaction(): + insert_row(conn, "bar") + raise Rollback + + with conn.transaction(): + insert_row(conn, "baz") + + assert in_transaction(conn) + conn.commit() + assert inserted(conn) == {"foo", "baz"} + + +def test_prohibits_use_of_commit_rollback_autocommit(conn): + """ + Within a Transaction block, it is forbidden to touch commit, rollback, + or the autocommit setting on the connection, as this would interfere + with the transaction scope being managed by the Transaction block. 
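From application code, the rule this test pins down looks like the sketch below (placeholder DSN; test_table is the table created by the fixture above):

import psycopg

with psycopg.connect("dbname=test") as conn:
    with conn.transaction():
        conn.execute("insert into test_table values ('x')")
        try:
            conn.commit()                     # forbidden inside the block
        except psycopg.ProgrammingError as exc:
            print(f"rejected as expected: {exc}")
    # the block itself issues the COMMIT on a clean exit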
+ """ + conn.autocommit = False + conn.commit() + conn.rollback() + + with conn.transaction(): + with pytest.raises(e.ProgrammingError): + conn.autocommit = False + with pytest.raises(e.ProgrammingError): + conn.commit() + with pytest.raises(e.ProgrammingError): + conn.rollback() + + conn.autocommit = False + conn.commit() + conn.rollback() + + +@pytest.mark.parametrize("autocommit", [False, True]) +def test_preserves_autocommit(conn, autocommit): + """ + Connection.autocommit is unchanged both during and after Transaction block. + """ + conn.autocommit = autocommit + with conn.transaction(): + assert conn.autocommit is autocommit + assert conn.autocommit is autocommit + + +def test_autocommit_off_but_no_tx_started_successful_exit(conn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are committed + """ + conn.autocommit = False + assert not in_transaction(conn) + with conn.transaction(): + insert_row(conn, "new") + assert not in_transaction(conn) + + # Changes committed + assert inserted(conn) == {"new"} + assert inserted(svcconn) == {"new"} + + +def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made within Transaction context are discarded + """ + conn.autocommit = False + assert not in_transaction(conn) + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "new") + raise ExpectedException() + assert not in_transaction(conn) + + # Changes discarded + assert not inserted(conn) + assert not inserted(svcconn) + + +@crdb_skip_external_observer +def test_autocommit_off_and_tx_in_progress_successful_exit(conn, pipeline, svcconn): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are left intact + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + conn.autocommit = False + insert_row(conn, "prior") + if pipeline: + pipeline.sync() + assert in_transaction(conn) + with conn.transaction(): + insert_row(conn, "new") + assert in_transaction(conn) + assert inserted(conn) == {"prior", "new"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +@crdb_skip_external_observer +def test_autocommit_off_and_tx_in_progress_exception_exit(conn, pipeline, svcconn): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made before the Transaction context are left intact + * Changes made within Transaction context are discarded + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. 
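Seen from outside the test suite, these two scenarios pin down the following: if the connection already has a transaction open, the block only adds a savepoint, and another connection sees nothing until the outer transaction commits. A sketch, assuming an initially empty test_table and a placeholder DSN:

import psycopg

with psycopg.connect("dbname=test") as conn, psycopg.connect("dbname=test") as observer:
    conn.execute("insert into test_table values ('prior')")   # opens the outer transaction
    with conn.transaction():                                   # SAVEPOINT, not BEGIN
        conn.execute("insert into test_table values ('new')")
    # the outer transaction is still open: nothing is visible elsewhere yet
    assert observer.execute("select * from test_table").fetchall() == []
    conn.commit()                                              # now both rows are visible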
+ """ + conn.autocommit = False + insert_row(conn, "prior") + if pipeline: + pipeline.sync() + assert in_transaction(conn) + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "new") + raise ExpectedException() + assert in_transaction(conn) + assert inserted(conn) == {"prior"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn): + """Changes from nested transaction contexts are all persisted on exit.""" + with conn.transaction(): + insert_row(conn, "outer-before") + with conn.transaction(): + insert_row(conn, "inner") + insert_row(conn, "outer-after") + assert not in_transaction(conn) + assert inserted(conn) == {"outer-before", "inner", "outer-after"} + assert inserted(svcconn) == {"outer-before", "inner", "outer-after"} + + +def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in outer context escapes. + """ + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + assert not in_transaction(conn) + assert not inserted(conn) + assert not inserted(svcconn) + + +def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in inner context escapes the outer context. + """ + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + assert not in_transaction(conn) + assert not inserted(conn) + assert not inserted(svcconn) + + +def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn): + """ + An exception escaping the inner transaction context causes changes made + within that inner context to be discarded, but the error can then be + handled in the outer context, allowing changes made in the outer context + (both before, and after, the inner context) to be successfully committed. + """ + with conn.transaction(): + insert_row(conn, "outer-before") + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + insert_row(conn, "outer-after") + assert not in_transaction(conn) + assert inserted(conn) == {"outer-before", "outer-after"} + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +def test_nested_three_levels_successful_exit(conn, svcconn): + """Exercise management of more than one savepoint.""" + with conn.transaction(): # BEGIN + insert_row(conn, "one") + with conn.transaction(): # SAVEPOINT s1 + insert_row(conn, "two") + with conn.transaction(): # SAVEPOINT s2 + insert_row(conn, "three") + assert not in_transaction(conn) + assert inserted(conn) == {"one", "two", "three"} + assert inserted(svcconn) == {"one", "two", "three"} + + +def test_named_savepoint_escapes_savepoint_name(conn): + with conn.transaction("s-1"): + pass + with conn.transaction("s1; drop table students"): + pass + + +def test_named_savepoints_successful_exit(conn, commands): + """ + Entering a transaction context will do one of these these things: + 1. Begin an outer transaction (if one isn't already in progress) + 2. Begin an outer transaction and create a savepoint (if one is named) + 3. 
Create a savepoint (if a transaction is already in progress) + either using the name provided, or auto-generating a savepoint name. + + ...and exiting the context successfully will "commit" the same. + """ + # Case 1 + # Using Transaction explicitly because conn.transaction() enters the contetx + assert not commands + with conn.transaction() as tx: + assert commands.popall() == ["BEGIN"] + assert not tx.savepoint_name + assert commands.popall() == ["COMMIT"] + + # Case 1 (with a transaction already started) + conn.cursor().execute("select 1") + assert commands.popall() == ["BEGIN"] + with conn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_1"'] + assert tx.savepoint_name == "_pg3_1" + assert commands.popall() == ['RELEASE "_pg3_1"'] + conn.rollback() + assert commands.popall() == ["ROLLBACK"] + + # Case 2 + with conn.transaction(savepoint_name="foo") as tx: + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] + assert tx.savepoint_name == "foo" + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name provided) + with conn.transaction(): + assert commands.popall() == ["BEGIN"] + with conn.transaction(savepoint_name="bar") as tx: + assert commands.popall() == ['SAVEPOINT "bar"'] + assert tx.savepoint_name == "bar" + assert commands.popall() == ['RELEASE "bar"'] + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name auto-generated) + with conn.transaction(): + assert commands.popall() == ["BEGIN"] + with conn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_2"'] + assert tx.savepoint_name == "_pg3_2" + assert commands.popall() == ['RELEASE "_pg3_2"'] + assert commands.popall() == ["COMMIT"] + + +def test_named_savepoints_exception_exit(conn, commands): + """ + Same as the previous test but checks that when exiting the context with an + exception, whatever transaction and/or savepoint was started on enter will + be rolled-back as appropriate. + """ + # Case 1 + with pytest.raises(ExpectedException): + with conn.transaction() as tx: + assert commands.popall() == ["BEGIN"] + assert not tx.savepoint_name + raise ExpectedException + assert commands.popall() == ["ROLLBACK"] + + # Case 2 + with pytest.raises(ExpectedException): + with conn.transaction(savepoint_name="foo") as tx: + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] + assert tx.savepoint_name == "foo" + raise ExpectedException + assert commands.popall() == ["ROLLBACK"] + + # Case 3 (with savepoint name provided) + with conn.transaction(): + assert commands.popall() == ["BEGIN"] + with pytest.raises(ExpectedException): + with conn.transaction(savepoint_name="bar") as tx: + assert commands.popall() == ['SAVEPOINT "bar"'] + assert tx.savepoint_name == "bar" + raise ExpectedException + assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"'] + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name auto-generated) + with conn.transaction(): + assert commands.popall() == ["BEGIN"] + with pytest.raises(ExpectedException): + with conn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_2"'] + assert tx.savepoint_name == "_pg3_2" + raise ExpectedException + assert commands.popall() == [ + 'ROLLBACK TO "_pg3_2"', + 'RELEASE "_pg3_2"', + ] + assert commands.popall() == ["COMMIT"] + + +def test_named_savepoints_with_repeated_names_works(conn): + """ + Using the same savepoint name repeatedly works correctly, but bypasses + some sanity checks. 
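The command sequences asserted above reduce to a simple mapping: the outermost block issues BEGIN and COMMIT (or ROLLBACK), and every nested block becomes a savepoint, either with the name passed to transaction() or an auto-generated "_pg3_N" one. A sketch of that mapping (placeholder DSN):

import psycopg

with psycopg.connect("dbname=test") as conn:
    with conn.transaction():                    # BEGIN
        with conn.transaction("sp") as tx:      # SAVEPOINT "sp"
            conn.execute("select 1")
            assert tx.savepoint_name == "sp"
        # inner block exited cleanly            -> RELEASE "sp"
    # outer block exited cleanly                -> COMMIT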
+ """ + # Works correctly if no inner transactions are rolled back + with conn.transaction(force_rollback=True): + with conn.transaction("sp"): + insert_row(conn, "tx1") + with conn.transaction("sp"): + insert_row(conn, "tx2") + with conn.transaction("sp"): + insert_row(conn, "tx3") + assert inserted(conn) == {"tx1", "tx2", "tx3"} + + # Works correctly if one level of inner transaction is rolled back + with conn.transaction(force_rollback=True): + with conn.transaction("s1"): + insert_row(conn, "tx1") + with conn.transaction("s1", force_rollback=True): + insert_row(conn, "tx2") + with conn.transaction("s1"): + insert_row(conn, "tx3") + assert inserted(conn) == {"tx1"} + assert inserted(conn) == {"tx1"} + + # Works correctly if multiple inner transactions are rolled back + # (This scenario mandates releasing savepoints after rolling back to them.) + with conn.transaction(force_rollback=True): + with conn.transaction("s1"): + insert_row(conn, "tx1") + with conn.transaction("s1") as tx2: + insert_row(conn, "tx2") + with conn.transaction("s1"): + insert_row(conn, "tx3") + raise Rollback(tx2) + assert inserted(conn) == {"tx1"} + assert inserted(conn) == {"tx1"} + + +def test_force_rollback_successful_exit(conn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + with conn.transaction(force_rollback=True): + insert_row(conn, "foo") + assert not inserted(conn) + assert not inserted(svcconn) + + +def test_force_rollback_exception_exit(conn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + with pytest.raises(ExpectedException): + with conn.transaction(force_rollback=True): + insert_row(conn, "foo") + raise ExpectedException() + assert not inserted(conn) + assert not inserted(svcconn) + + +@crdb_skip_external_observer +def test_explicit_rollback_discards_changes(conn, svcconn): + """ + Raising a Rollback exception in the middle of a block exits the block and + discards all changes made within that block. + + You can raise any of the following: + - Rollback (type) + - Rollback() (instance) + - Rollback(tx) (instance initialised with reference to the transaction) + All of these are equivalent. + """ + + def assert_no_rows(): + assert not inserted(conn) + assert not inserted(svcconn) + + with conn.transaction(): + insert_row(conn, "foo") + raise Rollback + assert_no_rows() + + with conn.transaction(): + insert_row(conn, "foo") + raise Rollback() + assert_no_rows() + + with conn.transaction() as tx: + insert_row(conn, "foo") + raise Rollback(tx) + assert_no_rows() + + +@crdb_skip_external_observer +def test_explicit_rollback_outer_tx_unaffected(conn, svcconn): + """ + Raising a Rollback exception in the middle of a block does not impact an + enclosing transaction block. + """ + with conn.transaction(): + insert_row(conn, "before") + with conn.transaction(): + insert_row(conn, "during") + raise Rollback + assert in_transaction(conn) + assert not inserted(svcconn) + insert_row(conn, "after") + assert inserted(conn) == {"before", "after"} + assert inserted(svcconn) == {"before", "after"} + + +def test_explicit_rollback_of_outer_transaction(conn): + """ + Raising a Rollback exception that references an outer transaction will + discard all changes from both inner and outer transaction blocks. 
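As these tests show, Rollback can target the innermost block (a bare raise) or a specific enclosing block (by passing its Transaction object). A minimal sketch of the bare form (placeholder DSN):

import psycopg
from psycopg import Rollback

with psycopg.connect("dbname=test") as conn:
    with conn.transaction():
        conn.execute("insert into test_table values ('kept')")
        with conn.transaction():
            conn.execute("insert into test_table values ('discarded')")
            raise Rollback                  # rolls back only the inner block
        # execution resumes here; the outer block commits 'kept' on exit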
+ """ + with conn.transaction() as outer_tx: + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise Rollback(outer_tx) + assert False, "This line of code should be unreachable." + assert not inserted(conn) + + +@crdb_skip_external_observer +def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn): + """ + Rolling-back an enclosing transaction does not impact an outer transaction. + """ + with conn.transaction(): + insert_row(conn, "outer-before") + with conn.transaction() as tx_enclosing: + insert_row(conn, "enclosing") + with conn.transaction(): + insert_row(conn, "inner") + raise Rollback(tx_enclosing) + insert_row(conn, "outer-after") + + assert inserted(conn) == {"outer-before", "outer-after"} + assert not inserted(svcconn) # Not yet committed + # Changes committed + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +def test_str(conn, pipeline): + with conn.transaction() as tx: + if pipeline: + assert "[INTRANS, pipeline=ON]" in str(tx) + else: + assert "[INTRANS]" in str(tx) + assert "(active)" in str(tx) + assert "'" not in str(tx) + with conn.transaction("wat") as tx2: + if pipeline: + assert "[INTRANS, pipeline=ON]" in str(tx2) + else: + assert "[INTRANS]" in str(tx2) + assert "'wat'" in str(tx2) + + if pipeline: + assert "[IDLE, pipeline=ON]" in str(tx) + else: + assert "[IDLE]" in str(tx) + assert "(terminated)" in str(tx) + + with pytest.raises(ZeroDivisionError): + with conn.transaction() as tx: + 1 / 0 + + assert "(terminated)" in str(tx) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +def test_out_of_order_exit(conn, exit_error): + conn.autocommit = True + + t1 = conn.transaction() + t1.__enter__() + + t2 = conn.transaction() + t2.__enter__() + + with pytest.raises(e.ProgrammingError): + t1.__exit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + t2.__exit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +def test_out_of_order_implicit_begin(conn, exit_error): + conn.execute("select 1") + + t1 = conn.transaction() + t1.__enter__() + + t2 = conn.transaction() + t2.__enter__() + + with pytest.raises(e.ProgrammingError): + t1.__exit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + t2.__exit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +def test_out_of_order_exit_same_name(conn, exit_error): + conn.autocommit = True + + t1 = conn.transaction("save") + t1.__enter__() + t2 = conn.transaction("save") + t2.__enter__() + + with pytest.raises(e.ProgrammingError): + t1.__exit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + t2.__exit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("what", ["commit", "rollback", "error"]) +def test_concurrency(conn, what): + conn.autocommit = True + + evs = [Event() for i in range(3)] + + def worker(unlock, wait_on): + with pytest.raises(e.ProgrammingError) as ex: + with conn.transaction(): + unlock.set() + wait_on.wait() + conn.execute("select 1") + + if what == "error": + 1 / 0 + elif what == "rollback": + raise Rollback() + else: + assert what == "commit" + + if what == "error": + assert "transaction rollback" in str(ex.value) + assert isinstance(ex.value.__context__, ZeroDivisionError) + elif what == "rollback": + assert "transaction rollback" in str(ex.value) + assert isinstance(ex.value.__context__, Rollback) + else: + assert 
"transaction commit" in str(ex.value) + + # Start a first transaction in a thread + t1 = Thread(target=worker, kwargs={"unlock": evs[0], "wait_on": evs[1]}) + t1.start() + evs[0].wait() + + # Start a nested transaction in a thread + t2 = Thread(target=worker, kwargs={"unlock": evs[1], "wait_on": evs[2]}) + t2.start() + + # Terminate the first transaction before the second does + t1.join() + evs[2].set() + t2.join() diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py new file mode 100644 index 0000000..55e1c9c --- /dev/null +++ b/tests/test_transaction_async.py @@ -0,0 +1,743 @@ +import asyncio +import logging + +import pytest + +from psycopg import Rollback +from psycopg import errors as e +from psycopg._compat import create_task + +from .test_transaction import in_transaction, insert_row, inserted, get_exc_info +from .test_transaction import ExpectedException, crdb_skip_external_observer +from .test_transaction import create_test_table # noqa # autouse fixture + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +async def aconn(aconn, apipeline): + return aconn + + +async def test_basic(aconn, apipeline): + """Basic use of transaction() to BEGIN and COMMIT a transaction.""" + assert not in_transaction(aconn) + async with aconn.transaction(): + if apipeline: + await apipeline.sync() + assert in_transaction(aconn) + assert not in_transaction(aconn) + + +async def test_exposes_associated_connection(aconn): + """Transaction exposes its connection as a read-only property.""" + async with aconn.transaction() as tx: + assert tx.connection is aconn + with pytest.raises(AttributeError): + tx.connection = aconn + + +async def test_exposes_savepoint_name(aconn): + """Transaction exposes its savepoint name as a read-only property.""" + async with aconn.transaction(savepoint_name="foo") as tx: + assert tx.savepoint_name == "foo" + with pytest.raises(AttributeError): + tx.savepoint_name = "bar" + + +async def test_cant_reenter(aconn): + async with aconn.transaction() as tx: + pass + + with pytest.raises(TypeError): + async with tx: + pass + + +async def test_begins_on_enter(aconn, apipeline): + """Transaction does not begin until __enter__() is called.""" + tx = aconn.transaction() + assert not in_transaction(aconn) + async with tx: + if apipeline: + await apipeline.sync() + assert in_transaction(aconn) + assert not in_transaction(aconn) + + +async def test_commit_on_successful_exit(aconn): + """Changes are committed on successful exit from the `with` block.""" + async with aconn.transaction(): + await insert_row(aconn, "foo") + + assert not in_transaction(aconn) + assert await inserted(aconn) == {"foo"} + + +async def test_rollback_on_exception_exit(aconn): + """Changes are rolled back if an exception escapes the `with` block.""" + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise ExpectedException("This discards the insert") + + assert not in_transaction(aconn) + assert not await inserted(aconn) + + +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_context_inerror_rollback_no_clobber( + aconn_cls, aconn, apipeline, dsn, caplog +): + if apipeline: + # Only 'aconn' is possibly in pipeline mode, but the transaction and + # checks are on 'conn2'. 
+ pytest.skip("not applicable") + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + async with await aconn_cls.connect(dsn) as conn2: + async with conn2.transaction(): + await conn2.execute("select 1") + await aconn.execute( + "select pg_terminate_backend(%s::int)", + [conn2.pgconn.backend_pid], + ) + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.crdb_skip("copy") +async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + conn = await aconn_cls.connect(dsn) + try: + with pytest.raises(ZeroDivisionError): + async with conn.transaction(): + conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + finally: + await conn.close() + + +async def test_interaction_dbapi_transaction(aconn): + await insert_row(aconn, "foo") + + async with aconn.transaction(): + await insert_row(aconn, "bar") + raise Rollback + + async with aconn.transaction(): + await insert_row(aconn, "baz") + + assert in_transaction(aconn) + await aconn.commit() + assert await inserted(aconn) == {"foo", "baz"} + + +async def test_prohibits_use_of_commit_rollback_autocommit(aconn): + """ + Within a Transaction block, it is forbidden to touch commit, rollback, + or the autocommit setting on the connection, as this would interfere + with the transaction scope being managed by the Transaction block. + """ + await aconn.set_autocommit(False) + await aconn.commit() + await aconn.rollback() + + async with aconn.transaction(): + with pytest.raises(e.ProgrammingError): + await aconn.set_autocommit(False) + with pytest.raises(e.ProgrammingError): + await aconn.commit() + with pytest.raises(e.ProgrammingError): + await aconn.rollback() + + await aconn.set_autocommit(False) + await aconn.commit() + await aconn.rollback() + + +@pytest.mark.parametrize("autocommit", [False, True]) +async def test_preserves_autocommit(aconn, autocommit): + """ + Connection.autocommit is unchanged both during and after Transaction block. 
+ """ + await aconn.set_autocommit(autocommit) + async with aconn.transaction(): + assert aconn.autocommit is autocommit + assert aconn.autocommit is autocommit + + +async def test_autocommit_off_but_no_tx_started_successful_exit(aconn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are committed + """ + await aconn.set_autocommit(False) + assert not in_transaction(aconn) + async with aconn.transaction(): + await insert_row(aconn, "new") + assert not in_transaction(aconn) + + # Changes committed + assert await inserted(aconn) == {"new"} + assert inserted(svcconn) == {"new"} + + +async def test_autocommit_off_but_no_tx_started_exception_exit(aconn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made within Transaction context are discarded + """ + await aconn.set_autocommit(False) + assert not in_transaction(aconn) + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "new") + raise ExpectedException() + assert not in_transaction(aconn) + + # Changes discarded + assert not await inserted(aconn) + assert not inserted(svcconn) + + +@crdb_skip_external_observer +async def test_autocommit_off_and_tx_in_progress_successful_exit( + aconn, apipeline, svcconn +): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are left intact + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + await aconn.set_autocommit(False) + await insert_row(aconn, "prior") + if apipeline: + await apipeline.sync() + assert in_transaction(aconn) + async with aconn.transaction(): + await insert_row(aconn, "new") + assert in_transaction(aconn) + assert await inserted(aconn) == {"prior", "new"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +@crdb_skip_external_observer +async def test_autocommit_off_and_tx_in_progress_exception_exit( + aconn, apipeline, svcconn +): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made before the Transaction context are left intact + * Changes made within Transaction context are discarded + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. 
+ """ + await aconn.set_autocommit(False) + await insert_row(aconn, "prior") + if apipeline: + await apipeline.sync() + assert in_transaction(aconn) + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "new") + raise ExpectedException() + assert in_transaction(aconn) + assert await inserted(aconn) == {"prior"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +async def test_nested_all_changes_persisted_on_successful_exit(aconn, svcconn): + """Changes from nested transaction contexts are all persisted on exit.""" + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + async with aconn.transaction(): + await insert_row(aconn, "inner") + await insert_row(aconn, "outer-after") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"outer-before", "inner", "outer-after"} + assert inserted(svcconn) == {"outer-before", "inner", "outer-after"} + + +async def test_nested_all_changes_discarded_on_outer_exception(aconn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in outer context escapes. + """ + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + assert not in_transaction(aconn) + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_nested_all_changes_discarded_on_inner_exception(aconn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in inner context escapes the outer context. + """ + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + assert not in_transaction(aconn) + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_nested_inner_scope_exception_handled_in_outer_scope(aconn, svcconn): + """ + An exception escaping the inner transaction context causes changes made + within that inner context to be discarded, but the error can then be + handled in the outer context, allowing changes made in the outer context + (both before, and after, the inner context) to be successfully committed. 
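The same pattern sketched as application code: catch the inner failure in the outer block and the outer block still commits, with only the inner savepoint rolled back (placeholder DSN):

import asyncio
import psycopg

async def main():
    async with await psycopg.AsyncConnection.connect("dbname=test") as aconn:
        async with aconn.transaction():
            await aconn.execute("insert into test_table values ('outer')")
            try:
                async with aconn.transaction():
                    await aconn.execute("insert into test_table values ('inner')")
                    raise ValueError("boom")
            except ValueError:
                pass   # 'inner' is rolled back; 'outer' survives and is committed

asyncio.run(main())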
+ """ + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + await insert_row(aconn, "outer-after") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"outer-before", "outer-after"} + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +async def test_nested_three_levels_successful_exit(aconn, svcconn): + """Exercise management of more than one savepoint.""" + async with aconn.transaction(): # BEGIN + await insert_row(aconn, "one") + async with aconn.transaction(): # SAVEPOINT s1 + await insert_row(aconn, "two") + async with aconn.transaction(): # SAVEPOINT s2 + await insert_row(aconn, "three") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"one", "two", "three"} + assert inserted(svcconn) == {"one", "two", "three"} + + +async def test_named_savepoint_escapes_savepoint_name(aconn): + async with aconn.transaction("s-1"): + pass + async with aconn.transaction("s1; drop table students"): + pass + + +async def test_named_savepoints_successful_exit(aconn, acommands): + """ + Entering a transaction context will do one of these these things: + 1. Begin an outer transaction (if one isn't already in progress) + 2. Begin an outer transaction and create a savepoint (if one is named) + 3. Create a savepoint (if a transaction is already in progress) + either using the name provided, or auto-generating a savepoint name. + + ...and exiting the context successfully will "commit" the same. + """ + commands = acommands + + # Case 1 + # Using Transaction explicitly because conn.transaction() enters the contetx + async with aconn.transaction() as tx: + assert commands.popall() == ["BEGIN"] + assert not tx.savepoint_name + assert commands.popall() == ["COMMIT"] + + # Case 1 (with a transaction already started) + await aconn.cursor().execute("select 1") + assert commands.popall() == ["BEGIN"] + async with aconn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_1"'] + assert tx.savepoint_name == "_pg3_1" + + assert commands.popall() == ['RELEASE "_pg3_1"'] + await aconn.rollback() + assert commands.popall() == ["ROLLBACK"] + + # Case 2 + async with aconn.transaction(savepoint_name="foo") as tx: + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] + assert tx.savepoint_name == "foo" + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name provided) + async with aconn.transaction(): + assert commands.popall() == ["BEGIN"] + async with aconn.transaction(savepoint_name="bar") as tx: + assert commands.popall() == ['SAVEPOINT "bar"'] + assert tx.savepoint_name == "bar" + assert commands.popall() == ['RELEASE "bar"'] + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name auto-generated) + async with aconn.transaction(): + assert commands.popall() == ["BEGIN"] + async with aconn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_2"'] + assert tx.savepoint_name == "_pg3_2" + assert commands.popall() == ['RELEASE "_pg3_2"'] + assert commands.popall() == ["COMMIT"] + + +async def test_named_savepoints_exception_exit(aconn, acommands): + """ + Same as the previous test but checks that when exiting the context with an + exception, whatever transaction and/or savepoint was started on enter will + be rolled-back as appropriate. 
+ """ + commands = acommands + + # Case 1 + with pytest.raises(ExpectedException): + async with aconn.transaction() as tx: + assert commands.popall() == ["BEGIN"] + assert not tx.savepoint_name + raise ExpectedException + assert commands.popall() == ["ROLLBACK"] + + # Case 2 + with pytest.raises(ExpectedException): + async with aconn.transaction(savepoint_name="foo") as tx: + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] + assert tx.savepoint_name == "foo" + raise ExpectedException + assert commands.popall() == ["ROLLBACK"] + + # Case 3 (with savepoint name provided) + async with aconn.transaction(): + assert commands.popall() == ["BEGIN"] + with pytest.raises(ExpectedException): + async with aconn.transaction(savepoint_name="bar") as tx: + assert commands.popall() == ['SAVEPOINT "bar"'] + assert tx.savepoint_name == "bar" + raise ExpectedException + assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"'] + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name auto-generated) + async with aconn.transaction(): + assert commands.popall() == ["BEGIN"] + with pytest.raises(ExpectedException): + async with aconn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_2"'] + assert tx.savepoint_name == "_pg3_2" + raise ExpectedException + assert commands.popall() == [ + 'ROLLBACK TO "_pg3_2"', + 'RELEASE "_pg3_2"', + ] + assert commands.popall() == ["COMMIT"] + + +async def test_named_savepoints_with_repeated_names_works(aconn): + """ + Using the same savepoint name repeatedly works correctly, but bypasses + some sanity checks. + """ + # Works correctly if no inner transactions are rolled back + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("sp"): + await insert_row(aconn, "tx1") + async with aconn.transaction("sp"): + await insert_row(aconn, "tx2") + async with aconn.transaction("sp"): + await insert_row(aconn, "tx3") + assert await inserted(aconn) == {"tx1", "tx2", "tx3"} + + # Works correctly if one level of inner transaction is rolled back + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("s1"): + await insert_row(aconn, "tx1") + async with aconn.transaction("s1", force_rollback=True): + await insert_row(aconn, "tx2") + async with aconn.transaction("s1"): + await insert_row(aconn, "tx3") + assert await inserted(aconn) == {"tx1"} + assert await inserted(aconn) == {"tx1"} + + # Works correctly if multiple inner transactions are rolled back + # (This scenario mandates releasing savepoints after rolling back to them.) + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("s1"): + await insert_row(aconn, "tx1") + async with aconn.transaction("s1") as tx2: + await insert_row(aconn, "tx2") + async with aconn.transaction("s1"): + await insert_row(aconn, "tx3") + raise Rollback(tx2) + assert await inserted(aconn) == {"tx1"} + assert await inserted(aconn) == {"tx1"} + + +async def test_force_rollback_successful_exit(aconn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + async with aconn.transaction(force_rollback=True): + await insert_row(aconn, "foo") + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_force_rollback_exception_exit(aconn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. 
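force_rollback=True, exercised here and just above, runs the block normally but always rolls it back on exit, which is convenient for dry-run style checks. A sketch (placeholder DSN):

import asyncio
import psycopg

async def main():
    async with await psycopg.AsyncConnection.connect("dbname=test") as aconn:
        async with aconn.transaction(force_rollback=True):
            await aconn.execute("insert into test_table values ('never persisted')")
        # nothing was committed, whether or not an exception was raised

asyncio.run(main())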
+ """ + with pytest.raises(ExpectedException): + async with aconn.transaction(force_rollback=True): + await insert_row(aconn, "foo") + raise ExpectedException() + assert not await inserted(aconn) + assert not inserted(svcconn) + + +@crdb_skip_external_observer +async def test_explicit_rollback_discards_changes(aconn, svcconn): + """ + Raising a Rollback exception in the middle of a block exits the block and + discards all changes made within that block. + + You can raise any of the following: + - Rollback (type) + - Rollback() (instance) + - Rollback(tx) (instance initialised with reference to the transaction) + All of these are equivalent. + """ + + async def assert_no_rows(): + assert not await inserted(aconn) + assert not inserted(svcconn) + + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise Rollback + await assert_no_rows() + + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise Rollback() + await assert_no_rows() + + async with aconn.transaction() as tx: + await insert_row(aconn, "foo") + raise Rollback(tx) + await assert_no_rows() + + +@crdb_skip_external_observer +async def test_explicit_rollback_outer_tx_unaffected(aconn, svcconn): + """ + Raising a Rollback exception in the middle of a block does not impact an + enclosing transaction block. + """ + async with aconn.transaction(): + await insert_row(aconn, "before") + async with aconn.transaction(): + await insert_row(aconn, "during") + raise Rollback + assert in_transaction(aconn) + assert not inserted(svcconn) + await insert_row(aconn, "after") + assert await inserted(aconn) == {"before", "after"} + assert inserted(svcconn) == {"before", "after"} + + +async def test_explicit_rollback_of_outer_transaction(aconn): + """ + Raising a Rollback exception that references an outer transaction will + discard all changes from both inner and outer transaction blocks. + """ + async with aconn.transaction() as outer_tx: + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise Rollback(outer_tx) + assert False, "This line of code should be unreachable." + assert not await inserted(aconn) + + +@crdb_skip_external_observer +async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(aconn, svcconn): + """ + Rolling-back an enclosing transaction does not impact an outer transaction. 
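The variant tested here hands Rollback the Transaction object of an enclosing block, so an inner block can abandon that whole savepoint while the outermost transaction continues. A sketch (placeholder DSN):

import asyncio
import psycopg
from psycopg import Rollback

async def main():
    async with await psycopg.AsyncConnection.connect("dbname=test") as aconn:
        async with aconn.transaction():
            await aconn.execute("insert into test_table values ('kept')")
            async with aconn.transaction() as enclosing:
                await aconn.execute("insert into test_table values ('discarded 1')")
                async with aconn.transaction():
                    await aconn.execute("insert into test_table values ('discarded 2')")
                    raise Rollback(enclosing)   # unwinds up to the 'enclosing' block
            await aconn.execute("insert into test_table values ('also kept')")
        # committed rows: 'kept' and 'also kept'

asyncio.run(main())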
+ """ + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + async with aconn.transaction() as tx_enclosing: + await insert_row(aconn, "enclosing") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise Rollback(tx_enclosing) + await insert_row(aconn, "outer-after") + + assert await inserted(aconn) == {"outer-before", "outer-after"} + assert not inserted(svcconn) # Not yet committed + # Changes committed + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +async def test_str(aconn, apipeline): + async with aconn.transaction() as tx: + if apipeline: + assert "[INTRANS]" not in str(tx) + await apipeline.sync() + assert "[INTRANS, pipeline=ON]" in str(tx) + else: + assert "[INTRANS]" in str(tx) + assert "(active)" in str(tx) + assert "'" not in str(tx) + async with aconn.transaction("wat") as tx2: + if apipeline: + assert "[INTRANS, pipeline=ON]" in str(tx2) + else: + assert "[INTRANS]" in str(tx2) + assert "'wat'" in str(tx2) + + if apipeline: + assert "[IDLE, pipeline=ON]" in str(tx) + else: + assert "[IDLE]" in str(tx) + assert "(terminated)" in str(tx) + + with pytest.raises(ZeroDivisionError): + async with aconn.transaction() as tx: + 1 / 0 + + assert "(terminated)" in str(tx) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +async def test_out_of_order_exit(aconn, exit_error): + await aconn.set_autocommit(True) + + t1 = aconn.transaction() + await t1.__aenter__() + + t2 = aconn.transaction() + await t2.__aenter__() + + with pytest.raises(e.ProgrammingError): + await t1.__aexit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + await t2.__aexit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +async def test_out_of_order_implicit_begin(aconn, exit_error): + await aconn.execute("select 1") + + t1 = aconn.transaction() + await t1.__aenter__() + + t2 = aconn.transaction() + await t2.__aenter__() + + with pytest.raises(e.ProgrammingError): + await t1.__aexit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + await t2.__aexit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +async def test_out_of_order_exit_same_name(aconn, exit_error): + await aconn.set_autocommit(True) + + t1 = aconn.transaction("save") + await t1.__aenter__() + t2 = aconn.transaction("save") + await t2.__aenter__() + + with pytest.raises(e.ProgrammingError): + await t1.__aexit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + await t2.__aexit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("what", ["commit", "rollback", "error"]) +async def test_concurrency(aconn, what): + await aconn.set_autocommit(True) + + evs = [asyncio.Event() for i in range(3)] + + async def worker(unlock, wait_on): + with pytest.raises(e.ProgrammingError) as ex: + async with aconn.transaction(): + unlock.set() + await wait_on.wait() + await aconn.execute("select 1") + + if what == "error": + 1 / 0 + elif what == "rollback": + raise Rollback() + else: + assert what == "commit" + + if what == "error": + assert "transaction rollback" in str(ex.value) + assert isinstance(ex.value.__context__, ZeroDivisionError) + elif what == "rollback": + assert "transaction rollback" in str(ex.value) + assert isinstance(ex.value.__context__, Rollback) + else: + assert "transaction commit" in str(ex.value) + + # Start a first transaction in a task + t1 = 
create_task(worker(unlock=evs[0], wait_on=evs[1])) + await evs[0].wait() + + # Start a nested transaction in a task + t2 = create_task(worker(unlock=evs[1], wait_on=evs[2])) + + # Terminate the first transaction before the second does + await asyncio.gather(t1) + evs[2].set() + await asyncio.gather(t2) diff --git a/tests/test_typeinfo.py b/tests/test_typeinfo.py new file mode 100644 index 0000000..d0e57e6 --- /dev/null +++ b/tests/test_typeinfo.py @@ -0,0 +1,145 @@ +import pytest + +import psycopg +from psycopg import sql +from psycopg.pq import TransactionStatus +from psycopg.types import TypeInfo + + +@pytest.mark.parametrize("name", ["text", sql.Identifier("text")]) +@pytest.mark.parametrize("status", ["IDLE", "INTRANS"]) +def test_fetch(conn, name, status): + status = getattr(TransactionStatus, status) + if status == TransactionStatus.INTRANS: + conn.execute("select 1") + + assert conn.info.transaction_status == status + info = TypeInfo.fetch(conn, name) + assert conn.info.transaction_status == status + + assert info.name == "text" + # TODO: add the schema? + # assert info.schema == "pg_catalog" + + assert info.oid == psycopg.adapters.types["text"].oid + assert info.array_oid == psycopg.adapters.types["text"].array_oid + assert info.regtype == "text" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name", ["text", sql.Identifier("text")]) +@pytest.mark.parametrize("status", ["IDLE", "INTRANS"]) +async def test_fetch_async(aconn, name, status): + status = getattr(TransactionStatus, status) + if status == TransactionStatus.INTRANS: + await aconn.execute("select 1") + + assert aconn.info.transaction_status == status + info = await TypeInfo.fetch(aconn, name) + assert aconn.info.transaction_status == status + + assert info.name == "text" + # assert info.schema == "pg_catalog" + assert info.oid == psycopg.adapters.types["text"].oid + assert info.array_oid == psycopg.adapters.types["text"].array_oid + + +@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")]) +@pytest.mark.parametrize("status", ["IDLE", "INTRANS"]) +def test_fetch_not_found(conn, name, status): + status = getattr(TransactionStatus, status) + if status == TransactionStatus.INTRANS: + conn.execute("select 1") + + assert conn.info.transaction_status == status + info = TypeInfo.fetch(conn, name) + assert conn.info.transaction_status == status + assert info is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")]) +@pytest.mark.parametrize("status", ["IDLE", "INTRANS"]) +async def test_fetch_not_found_async(aconn, name, status): + status = getattr(TransactionStatus, status) + if status == TransactionStatus.INTRANS: + await aconn.execute("select 1") + + assert aconn.info.transaction_status == status + info = await TypeInfo.fetch(aconn, name) + assert aconn.info.transaction_status == status + + assert info is None + + +@pytest.mark.crdb_skip("composite") +@pytest.mark.parametrize( + "name", ["testschema.testtype", sql.Identifier("testschema", "testtype")] +) +def test_fetch_by_schema_qualified_string(conn, name): + conn.execute("create schema if not exists testschema") + conn.execute("create type testschema.testtype as (foo text)") + + info = TypeInfo.fetch(conn, name) + assert info.name == "testtype" + # assert info.schema == "testschema" + cur = conn.execute( + """ + select oid, typarray from pg_type + where oid = 'testschema.testtype'::regtype + """ + ) + assert cur.fetchone() == (info.oid, info.array_oid) + + +@pytest.mark.parametrize( + "name", + [ + "text", + 
# TODO: support these? + # "pg_catalog.text", + # sql.Identifier("text"), + # sql.Identifier("pg_catalog", "text"), + ], +) +def test_registry_by_builtin_name(conn, name): + info = psycopg.adapters.types[name] + assert info.name == "text" + assert info.oid == 25 + + +def test_registry_empty(): + r = psycopg.types.TypesRegistry() + assert r.get("text") is None + with pytest.raises(KeyError): + r["text"] + + +@pytest.mark.parametrize("oid, aoid", [(1, 2), (1, 0), (0, 2), (0, 0)]) +def test_registry_invalid_oid(oid, aoid): + r = psycopg.types.TypesRegistry() + ti = psycopg.types.TypeInfo("test", oid, aoid) + r.add(ti) + assert r["test"] is ti + if oid: + assert r[oid] is ti + if aoid: + assert r[aoid] is ti + with pytest.raises(KeyError): + r[0] + + +def test_registry_copy(): + r = psycopg.types.TypesRegistry(psycopg.postgres.types) + assert r.get("text") is r["text"] is r[25] + assert r["text"].oid == 25 + + +def test_registry_isolated(): + orig = psycopg.postgres.types + tinfo = orig["text"] + r = psycopg.types.TypesRegistry(orig) + tdummy = psycopg.types.TypeInfo("dummy", tinfo.oid, tinfo.array_oid) + r.add(tdummy) + assert r[25] is r["dummy"] is tdummy + assert orig[25] is r["text"] is tinfo diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 0000000..fff9cec --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,449 @@ +import os + +import pytest + +HERE = os.path.dirname(os.path.abspath(__file__)) + + +@pytest.mark.parametrize( + "filename", + ["adapters_example.py", "typing_example.py"], +) +def test_typing_example(mypy, filename): + cp = mypy.run_on_file(os.path.join(HERE, filename)) + errors = cp.stdout.decode("utf8", "replace").splitlines() + assert not errors + assert cp.returncode == 0 + + +@pytest.mark.parametrize( + "conn, type", + [ + ( + "psycopg.connect()", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.tuple_row)", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Connection[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.namedtuple_row)", + "psycopg.Connection[NamedTuple]", + ), + ( + "psycopg.connect(row_factory=rows.class_row(Thing))", + "psycopg.Connection[Thing]", + ), + ( + "psycopg.connect(row_factory=thing_row)", + "psycopg.Connection[Thing]", + ), + ( + "psycopg.Connection.connect()", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.Connection.connect(row_factory=rows.dict_row)", + "psycopg.Connection[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncConnection[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncConnection[Dict[str, Any]]", + ), + ], +) +def test_connection_type(conn, type, mypy): + stmts = f"obj = {conn}" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "conn, curs, type", + [ + ( + "psycopg.connect()", + "conn.cursor()", + "psycopg.Cursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor()", + "psycopg.Cursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor(row_factory=rows.namedtuple_row)", + "psycopg.Cursor[NamedTuple]", + ), + ( + "psycopg.connect(row_factory=rows.class_row(Thing))", + "conn.cursor()", + "psycopg.Cursor[Thing]", + ), + ( + "psycopg.connect(row_factory=thing_row)", + "conn.cursor()", + "psycopg.Cursor[Thing]", + ), + ( + "psycopg.connect()", + "conn.cursor(row_factory=thing_row)", 
+ "psycopg.Cursor[Thing]", + ), + # Async cursors + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor()", + "psycopg.AsyncCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(row_factory=thing_row)", + "psycopg.AsyncCursor[Thing]", + ), + # Server-side cursors + ( + "psycopg.connect()", + "conn.cursor(name='foo')", + "psycopg.ServerCursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor(name='foo')", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + ( + "psycopg.connect()", + "conn.cursor(name='foo', row_factory=rows.dict_row)", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + # Async server-side cursors + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(name='foo')", + "psycopg.AsyncServerCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "conn.cursor(name='foo')", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(name='foo', row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ], +) +def test_cursor_type(conn, curs, type, mypy): + stmts = f"""\ +conn = {conn} +obj = {curs} +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "conn, curs, type", + [ + ( + "psycopg.connect()", + "psycopg.Cursor(conn)", + "psycopg.Cursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Cursor(conn)", + "psycopg.Cursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Cursor(conn, row_factory=rows.namedtuple_row)", + "psycopg.Cursor[NamedTuple]", + ), + # Async cursors + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncCursor(conn)", + "psycopg.AsyncCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncCursor(conn)", + "psycopg.AsyncCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncCursor(conn, row_factory=thing_row)", + "psycopg.AsyncCursor[Thing]", + ), + # Server-side cursors + ( + "psycopg.connect()", + "psycopg.ServerCursor(conn, 'foo')", + "psycopg.ServerCursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.ServerCursor(conn, name='foo')", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.ServerCursor(conn, 'foo', row_factory=rows.namedtuple_row)", + "psycopg.ServerCursor[NamedTuple]", + ), + # Async server-side cursors + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncServerCursor(conn, name='foo')", + "psycopg.AsyncServerCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor(conn, name='foo')", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncServerCursor(conn, name='foo', row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ], +) +def test_cursor_type_init(conn, curs, type, mypy): + stmts = f"""\ +conn = {conn} +obj = {curs} +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "Optional[Tuple[Any, ...]]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "Optional[Dict[str, Any]]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "Optional[Thing]", + ), + ], +) 
+@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_fetchone_type(conn_class, server_side, curs, type, mypy): + await_ = "await" if "Async" in conn_class else "" + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_} psycopg.{conn_class}.connect() +curs = {curs} +obj = {await_} curs.fetchone() +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "Tuple[Any, ...]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "Dict[str, Any]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "Thing", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_iter_type(conn_class, server_side, curs, type, mypy): + if "Async" in conn_class: + async_ = "async " + await_ = "await " + else: + async_ = await_ = "" + + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_}psycopg.{conn_class}.connect() +curs = {curs} +{async_}for obj in curs: + pass +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize("method", ["fetchmany", "fetchall"]) +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "List[Tuple[Any, ...]]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "List[Dict[str, Any]]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "List[Thing]", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy): + await_ = "await" if "Async" in conn_class else "" + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_} psycopg.{conn_class}.connect() +curs = {curs} +obj = {await_} curs.{method}() +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_cur_subclass_execute(mypy, conn_class, server_side): + async_ = "async " if "Async" in conn_class else "" + await_ = "await" if "Async" in conn_class else "" + cur_base_class = "".join( + [ + "Async" if "Async" in conn_class else "", + "Server" if server_side else "", + "Cursor", + ] + ) + cur_name = "'foo'" if server_side else "" + + src = f"""\ +from typing import Any, cast +import psycopg +from psycopg.rows import Row, TupleRow + +class MyCursor(psycopg.{cur_base_class}[Row]): + pass + +{async_}def test() -> None: + conn = {await_} psycopg.{conn_class}.connect() + + cur: MyCursor[TupleRow] + reveal_type(cur) + + cur = cast(MyCursor[TupleRow], conn.cursor({cur_name})) + {async_}with cur as cur2: + reveal_type(cur2) + cur3 = {await_} cur2.execute("") + reveal_type(cur3) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 3 + types = [mypy.get_revealed(line) for line in out] + assert types[0] == types[1] + assert types[0] == types[2] + + +def _test_reveal(stmts, type, mypy): + ignore = "" if type.startswith("Optional") else "# type: ignore[assignment]" + stmts = "\n".join(f" {line}" for line in stmts.splitlines()) + + src = f"""\ +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence +from typing import Tuple, Union +import psycopg +from psycopg import rows + +class Thing: + def 
__init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + +def thing_row( + cur: Union[psycopg.Cursor[Any], psycopg.AsyncCursor[Any]], +) -> Callable[[Sequence[Any]], Thing]: + assert cur.description + names = [d.name for d in cur.description] + + def make_row(t: Sequence[Any]) -> Thing: + return Thing(**dict(zip(names, t))) + + return make_row + +async def tmp() -> None: +{stmts} + reveal_type(obj) + +ref: {type} = None {ignore} +reveal_type(ref) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 2, "\n".join(out) + got, want = [mypy.get_revealed(line) for line in out] + assert got == want + + +@pytest.mark.xfail(reason="https://github.com/psycopg/psycopg/issues/308") +@pytest.mark.parametrize( + "conn, type", + [ + ( + "MyConnection.connect()", + "MyConnection[Tuple[Any, ...]]", + ), + ( + "MyConnection.connect(row_factory=rows.tuple_row)", + "MyConnection[Tuple[Any, ...]]", + ), + ( + "MyConnection.connect(row_factory=rows.dict_row)", + "MyConnection[Dict[str, Any]]", + ), + ], +) +def test_generic_connect(conn, type, mypy): + src = f""" +from typing import Any, Dict, Tuple +import psycopg +from psycopg import rows + +class MyConnection(psycopg.Connection[rows.Row]): + pass + +obj = {conn} +reveal_type(obj) + +ref: {type} = None # type: ignore[assignment] +reveal_type(ref) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 2, "\n".join(out) + got, want = [mypy.get_revealed(line) for line in out] + assert got == want diff --git a/tests/test_waiting.py b/tests/test_waiting.py new file mode 100644 index 0000000..63237e8 --- /dev/null +++ b/tests/test_waiting.py @@ -0,0 +1,159 @@ +import select # noqa: used in pytest.mark.skipif +import socket +import sys + +import pytest + +import psycopg +from psycopg import waiting +from psycopg import generators +from psycopg.pq import ConnStatus, ExecStatus + +skip_if_not_linux = pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="non-Linux platform" +) + +waitfns = [ + "wait", + "wait_selector", + pytest.param( + "wait_select", marks=pytest.mark.skipif("not hasattr(select, 'select')") + ), + pytest.param( + "wait_epoll", marks=pytest.mark.skipif("not hasattr(select, 'epoll')") + ), + pytest.param("wait_c", marks=pytest.mark.skipif("not psycopg._cmodule._psycopg")), +] + +timeouts = [pytest.param({}, id="blank")] +timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]] + + +@pytest.mark.parametrize("timeout", timeouts) +def test_wait_conn(dsn, timeout): + gen = generators.connect(dsn) + conn = waiting.wait_conn(gen, **timeout) + assert conn.status == ConnStatus.OK + + +def test_wait_conn_bad(dsn): + gen = generators.connect("dbname=nosuchdb") + with pytest.raises(psycopg.OperationalError): + waiting.wait_conn(gen) + + +@pytest.mark.parametrize("waitfn", waitfns) +@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready)) +@skip_if_not_linux +def test_wait_ready(waitfn, wait, ready): + waitfn = getattr(waiting, waitfn) + + def gen(): + r = yield wait + return r + + with socket.socket() as s: + r = waitfn(gen(), s.fileno()) + assert r & ready + + +@pytest.mark.parametrize("waitfn", waitfns) +@pytest.mark.parametrize("timeout", timeouts) +def test_wait(pgconn, waitfn, timeout): + waitfn = getattr(waiting, waitfn) + + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + (res,) = waitfn(gen, pgconn.socket, **timeout) + assert res.status == 
ExecStatus.TUPLES_OK + + +@pytest.mark.parametrize("waitfn", waitfns) +def test_wait_bad(pgconn, waitfn): + waitfn = getattr(waiting, waitfn) + + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + waitfn(gen, pgconn.socket) + + +@pytest.mark.slow +@pytest.mark.skipif( + "sys.platform == 'win32'", reason="win32 works ok, but FDs are mysterious" +) +@pytest.mark.parametrize("waitfn", waitfns) +def test_wait_large_fd(dsn, waitfn): + waitfn = getattr(waiting, waitfn) + + files = [] + try: + try: + for i in range(1100): + files.append(open(__file__)) + except OSError: + pytest.skip("can't open the number of files needed for the test") + + pgconn = psycopg.pq.PGconn.connect(dsn.encode()) + try: + assert pgconn.socket > 1024 + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + if waitfn is waiting.wait_select: + with pytest.raises(ValueError): + waitfn(gen, pgconn.socket) + else: + (res,) = waitfn(gen, pgconn.socket) + assert res.status == ExecStatus.TUPLES_OK + finally: + pgconn.finish() + finally: + for f in files: + f.close() + + +@pytest.mark.parametrize("timeout", timeouts) +@pytest.mark.asyncio +async def test_wait_conn_async(dsn, timeout): + gen = generators.connect(dsn) + conn = await waiting.wait_conn_async(gen, **timeout) + assert conn.status == ConnStatus.OK + + +@pytest.mark.asyncio +async def test_wait_conn_async_bad(dsn): + gen = generators.connect("dbname=nosuchdb") + with pytest.raises(psycopg.OperationalError): + await waiting.wait_conn_async(gen) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready)) +@skip_if_not_linux +async def test_wait_ready_async(wait, ready): + def gen(): + r = yield wait + return r + + with socket.socket() as s: + r = await waiting.wait_async(gen(), s.fileno()) + assert r & ready + + +@pytest.mark.asyncio +async def test_wait_async(pgconn): + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + (res,) = await waiting.wait_async(gen, pgconn.socket) + assert res.status == ExecStatus.TUPLES_OK + + +@pytest.mark.asyncio +async def test_wait_async_bad(pgconn): + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + socket = pgconn.socket + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + await waiting.wait_async(gen, socket) diff --git a/tests/test_windows.py b/tests/test_windows.py new file mode 100644 index 0000000..09e61ba --- /dev/null +++ b/tests/test_windows.py @@ -0,0 +1,23 @@ +import pytest +import asyncio +import sys + +from psycopg.errors import InterfaceError + + +@pytest.mark.skipif(sys.platform != "win32", reason="windows only test") +def test_windows_error(aconn_cls, dsn): + loop = asyncio.ProactorEventLoop() # type: ignore[attr-defined] + + async def go(): + with pytest.raises( + InterfaceError, + match="Psycopg cannot use the 'ProactorEventLoop'", + ): + await aconn_cls.connect(dsn) + + try: + loop.run_until_complete(go()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/types/__init__.py diff --git a/tests/types/test_array.py b/tests/types/test_array.py new file mode 100644 index 0000000..74c17a6 --- /dev/null +++ b/tests/types/test_array.py @@ -0,0 +1,338 @@ +from typing import List, Any +from decimal import Decimal + +import pytest + +import psycopg +from psycopg import pq +from psycopg 
import sql +from psycopg.adapt import PyFormat, Transformer, Dumper +from psycopg.types import TypeInfo +from psycopg._compat import prod +from psycopg.postgres import types as builtins + + +tests_str = [ + ([[[[[["a"]]]]]], "{{{{{{a}}}}}}"), + ([[[[[[None]]]]]], "{{{{{{NULL}}}}}}"), + ([[[[[["NULL"]]]]]], '{{{{{{"NULL"}}}}}}'), + (["foo", "bar", "baz"], "{foo,bar,baz}"), + (["foo", None, "baz"], "{foo,null,baz}"), + (["foo", "null", "", "baz"], '{foo,"null","",baz}'), + ( + [["foo", "bar"], ["baz", "qux"], ["quux", "quuux"]], + "{{foo,bar},{baz,qux},{quux,quuux}}", + ), + ( + [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]], + r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}', + ), +] + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("type", ["text", "int4"]) +def test_dump_empty_list(conn, fmt_in, type): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}::{type}[] = %s::{type}[]", ([], "{}")) + assert cur.fetchone()[0] + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("obj, want", tests_str) +def test_dump_list_str(conn, obj, want, fmt_in): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}::text[] = %s::text[]", (obj, want)) + assert cur.fetchone()[0] + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_empty_list_str(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::text[]", ([],)) + assert cur.fetchone()[0] == [] + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("want, obj", tests_str) +def test_load_list_str(conn, obj, want, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::text[]", (obj,)) + assert cur.fetchone()[0] == want + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_all_chars(conn, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(1, 256): + c = chr(i) + cur.execute(f"select %{fmt_in.value}::text[]", ([c],)) + assert cur.fetchone()[0] == [c] + + a = list(map(chr, range(1, 256))) + a.append("\u20ac") + cur.execute(f"select %{fmt_in.value}::text[]", (a,)) + assert cur.fetchone()[0] == a + + s = "".join(a) + cur.execute(f"select %{fmt_in.value}::text[]", ([s],)) + assert cur.fetchone()[0] == [s] + + +tests_int = [ + ([10, 20, -30], "{10,20,-30}"), + ([10, None, 30], "{10,null,30}"), + ([[10, 20], [30, 40]], "{{10,20},{30,40}}"), +] + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("obj, want", tests_int) +def test_dump_list_int(conn, obj, want): + cur = conn.cursor() + cur.execute("select %s::int[] = %s::int[]", (obj, want)) + assert cur.fetchone()[0] + + +@pytest.mark.parametrize( + "input", + [ + [["a"], ["b", "c"]], + [["a"], []], + [[["a"]], ["b"]], + # [["a"], [["b"]]], # todo, but expensive (an isinstance per item) + # [True, b"a"], # TODO expensive too + ], +) +def test_bad_binary_array(input): + tx = Transformer() + with pytest.raises(psycopg.DataError): + tx.get_dumper(input, PyFormat.BINARY).dump(input) + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("want, obj", tests_int) +def test_load_list_int(conn, obj, want, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::int[]", (obj,)) + assert cur.fetchone()[0] == want + + stmt = sql.SQL("copy (select {}::int[]) to stdout (format {})").format( + obj, sql.SQL(fmt_out.name) + ) + with 
cur.copy(stmt) as copy: + copy.set_types(["int4[]"]) + (got,) = copy.read_row() + + assert got == want + + +@pytest.mark.crdb_skip("composite") +def test_array_register(conn): + conn.execute("create table mytype (data text)") + cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""") + res = cur.fetchone() + assert res[0] == "(foo)" + assert res[1] == "{(foo)}" + + info = TypeInfo.fetch(conn, "mytype") + info.register(conn) + + cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""") + res = cur.fetchone() + assert res[0] == "(foo)" + assert res[1] == ["(foo)"] + + +@pytest.mark.crdb("skip", reason="aclitem") +def test_array_of_unknown_builtin(conn): + user = conn.execute("select user").fetchone()[0] + # we cannot load this type, but we understand it is an array + val = f"{user}=arwdDxt/{user}" + cur = conn.execute(f"select '{val}'::aclitem, array['{val}']::aclitem[]") + res = cur.fetchone() + assert cur.description[0].type_code == builtins["aclitem"].oid + assert res[0] == val + assert cur.description[1].type_code == builtins["aclitem"].array_oid + assert res[1] == [val] + + +@pytest.mark.parametrize( + "num, type", + [ + (0, "int2"), + (2**15 - 1, "int2"), + (-(2**15), "int2"), + (2**15, "int4"), + (2**31 - 1, "int4"), + (-(2**31), "int4"), + (2**31, "int8"), + (2**63 - 1, "int8"), + (-(2**63), "int8"), + (2**63, "numeric"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_numbers_array(num, type, fmt_in): + for array in ([num], [1, num]): + tx = Transformer() + dumper = tx.get_dumper(array, fmt_in) + dumper.dump(array) + assert dumper.oid == builtins[type].array_oid + + +@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split()) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out): + wrapper = getattr(psycopg.types.numeric, wrapper) + if wrapper is Decimal: + want_cls = Decimal + else: + assert wrapper.__mro__[1] in (int, float) + want_cls = wrapper.__mro__[1] + + obj = [wrapper(1), wrapper(0), wrapper(-1), None] + cur = conn.cursor(binary=fmt_out) + got = cur.execute(f"select %{fmt_in.value}", [obj]).fetchone()[0] + assert got == obj + for i in got: + if i is not None: + assert type(i) is want_cls + + +def test_mix_types(conn): + with pytest.raises(psycopg.DataError): + conn.execute("select %s", ([1, 0.5],)) + + with pytest.raises(psycopg.DataError): + conn.execute("select %s", ([1, Decimal("0.5")],)) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_empty_list_mix(conn, fmt_in): + objs = list(range(3)) + conn.execute("create table testarrays (col1 bigint[], col2 bigint[])") + # pro tip: don't get confused with the types + f1, f2 = conn.execute( + f"insert into testarrays values (%{fmt_in.value}, %{fmt_in.value}) returning *", + (objs, []), + ).fetchone() + assert f1 == objs + assert f2 == [] + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_empty_list(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table test (id serial primary key, data date[])") + with conn.transaction(): + cur.execute( + f"insert into test (data) values (%{fmt_in.value}) returning id", ([],) + ) + id = cur.fetchone()[0] + cur.execute("select data from test") + assert cur.fetchone() == ([],) + + # test untyped list in a filter + cur.execute(f"select data from test where id = any(%{fmt_in.value})", ([id],)) + assert cur.fetchone() + cur.execute(f"select data from test where id = any(%{fmt_in.value})", ([],)) + assert 
not cur.fetchone() + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_empty_list_after_choice(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table test (id serial primary key, data float[])") + cur.executemany( + f"insert into test (data) values (%{fmt_in.value})", [([1.0],), ([],)] + ) + cur.execute("select data from test order by id") + assert cur.fetchall() == [([1.0],), ([],)] + + +@pytest.mark.crdb_skip("geometric types") +def test_dump_list_no_comma_separator(conn): + class Box: + def __init__(self, x1, y1, x2, y2): + self.coords = (x1, y1, x2, y2) + + class BoxDumper(Dumper): + + format = pq.Format.TEXT + oid = psycopg.postgres.types["box"].oid + + def dump(self, box): + return ("(%s,%s),(%s,%s)" % box.coords).encode() + + conn.adapters.register_dumper(Box, BoxDumper) + + cur = conn.execute("select (%s::box)::text", (Box(1, 2, 3, 4),)) + got = cur.fetchone()[0] + assert got == "(3,4),(1,2)" + + cur = conn.execute( + "select (%s::box[])::text", ([Box(1, 2, 3, 4), Box(5, 4, 3, 2)],) + ) + got = cur.fetchone()[0] + assert got == "{(3,4),(1,2);(5,4),(3,2)}" + + +@pytest.mark.crdb_skip("geometric types") +def test_load_array_no_comma_separator(conn): + cur = conn.execute("select '{(2,2),(1,1);(5,6),(3,4)}'::box[]") + # Not parsed at the moment, but split ok on ; separator + assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"] + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_nested_array(conn, fmt_out): + dims = [3, 4, 5, 6] + a: List[Any] = list(range(prod(dims))) + for dim in dims[-1:0:-1]: + a = [a[i : i + dim] for i in range(0, len(a), dim)] + + assert a[2][3][4][5] == prod(dims) - 1 + + sa = str(a).replace("[", "{").replace("]", "}") + got = conn.execute("select %s::int[][][][]", [sa], binary=fmt_out).fetchone()[0] + assert got == a + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize( + "obj, want", + [ + ("'[0:1]={a,b}'::text[]", ["a", "b"]), + ("'[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}'::int[]", [[[1, 2, 3], [4, 5, 6]]]), + ], +) +def test_array_with_bounds(conn, obj, want, fmt_out): + got = conn.execute(f"select {obj}", binary=fmt_out).fetchone()[0] + assert got == want + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_all_chars_with_bounds(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(1, 256): + c = chr(i) + cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", ([c],)) + assert cur.fetchone()[0] == ["a", "b", c] + + a = list(map(chr, range(1, 256))) + a.append("\u20ac") + cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", (a,)) + assert cur.fetchone()[0] == ["a", "b"] + a + + s = "".join(a) + cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", ([s],)) + assert cur.fetchone()[0] == ["a", "b", s] diff --git a/tests/types/test_bool.py b/tests/types/test_bool.py new file mode 100644 index 0000000..edd4dad --- /dev/null +++ b/tests/types/test_bool.py @@ -0,0 +1,47 @@ +import pytest + +from psycopg import pq +from psycopg import sql +from psycopg.adapt import Transformer, PyFormat +from psycopg.postgres import types as builtins + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("b", [True, False]) +def test_roundtrip_bool(conn, b, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + result = cur.execute(f"select %{fmt_in.value}", (b,)).fetchone()[0] + assert 
cur.pgresult.fformat(0) == fmt_out + if b is not None: + assert cur.pgresult.ftype(0) == builtins["bool"].oid + assert result is b + + result = cur.execute(f"select %{fmt_in.value}", ([b],)).fetchone()[0] + assert cur.pgresult.fformat(0) == fmt_out + if b is not None: + assert cur.pgresult.ftype(0) == builtins["bool"].array_oid + assert result[0] is b + + +@pytest.mark.parametrize("val", [True, False]) +def test_quote_bool(conn, val): + + tx = Transformer() + assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == str(val).lower().encode( + "ascii" + ) + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val))) + assert cur.fetchone()[0] is val + + +def test_quote_none(conn): + + tx = Transformer() + assert tx.get_dumper(None, PyFormat.TEXT).quote(None) == b"NULL" + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None))) + assert cur.fetchone()[0] is None diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py new file mode 100644 index 0000000..47beecf --- /dev/null +++ b/tests/types/test_composite.py @@ -0,0 +1,396 @@ +import pytest + +from psycopg import pq, postgres, sql +from psycopg.adapt import PyFormat +from psycopg.postgres import types as builtins +from psycopg.types.range import Range +from psycopg.types.composite import CompositeInfo, register_composite +from psycopg.types.composite import TupleDumper, TupleBinaryDumper + +from ..utils import eur +from ..fix_crdb import is_crdb, crdb_skip_message + + +pytestmark = pytest.mark.crdb_skip("composite") + +tests_str = [ + ("", ()), + # Funnily enough there's no way to represent (None,) in Postgres + ("null", ()), + ("null,null", (None, None)), + ("null, ''", (None, "")), + ( + "42,'foo','ba,r','ba''z','qu\"x'", + ("42", "foo", "ba,r", "ba'z", 'qu"x'), + ), + ("'foo''', '''foo', '\"bar', 'bar\"' ", ("foo'", "'foo", '"bar', 'bar"')), +] + + +@pytest.mark.parametrize("rec, want", tests_str) +def test_load_record(conn, want, rec): + cur = conn.cursor() + res = cur.execute(f"select row({rec})").fetchone()[0] + assert res == want + + +@pytest.mark.parametrize("rec, obj", tests_str) +def test_dump_tuple(conn, rec, obj): + cur = conn.cursor() + fields = [f"f{i} text" for i in range(len(obj))] + cur.execute( + f""" + drop type if exists tmptype; + create type tmptype as ({', '.join(fields)}); + """ + ) + info = CompositeInfo.fetch(conn, "tmptype") + register_composite(info, conn) + + res = conn.execute("select %s::tmptype", [obj]).fetchone()[0] + assert res == obj + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_all_chars(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(1, 256): + res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0] + assert res == (chr(i),) + + cur.execute("select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256))) + res = cur.fetchone()[0] + assert res == tuple(map(chr, range(1, 256))) + + s = "".join(map(chr, range(1, 256))) + res = cur.execute("select row(%s::text)", [s]).fetchone()[0] + assert res == (s,) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty_range(conn, fmt_in): + conn.execute( + """ + drop type if exists tmptype; + create type tmptype as (num integer, range daterange, nums integer[]) + """ + ) + info = CompositeInfo.fetch(conn, "tmptype") + register_composite(info, conn) + + cur = conn.execute( + f"select pg_typeof(%{fmt_in.value})", + [info.python_type(10, Range(empty=True), [])], + ) + assert cur.fetchone()[0] == "tmptype" + + 
+@pytest.mark.parametrize( + "rec, want", + [ + ("", ()), + ("null", (None,)), # Unlike text format, this is a thing + ("null,null", (None, None)), + ("null, ''", (None, b"")), + ( + "42,'foo','ba,r','ba''z','qu\"x'", + (42, b"foo", b"ba,r", b"ba'z", b'qu"x'), + ), + ( + "'foo''', '''foo', '\"bar', 'bar\"' ", + (b"foo'", b"'foo", b'"bar', b'bar"'), + ), + ( + "10::int, null::text, 20::float, null::text, 'foo'::text, 'bar'::bytea ", + (10, None, 20.0, None, "foo", b"bar"), + ), + ], +) +def test_load_record_binary(conn, want, rec): + cur = conn.cursor(binary=True) + res = cur.execute(f"select row({rec})").fetchone()[0] + assert res == want + for o1, o2 in zip(res, want): + assert type(o1) is type(o2) + + +@pytest.fixture(scope="session") +def testcomp(svcconn): + if is_crdb(svcconn): + pytest.skip(crdb_skip_message("composite")) + cur = svcconn.cursor() + cur.execute( + """ + create schema if not exists testschema; + + drop type if exists testcomp cascade; + drop type if exists testschema.testcomp cascade; + + create type testcomp as (foo text, bar int8, baz float8); + create type testschema.testcomp as (foo text, bar int8, qux bool); + """ + ) + return CompositeInfo.fetch(svcconn, "testcomp") + + +fetch_cases = [ + ( + "testcomp", + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + "testschema.testcomp", + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), + ( + sql.Identifier("testcomp"), + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + sql.Identifier("testschema", "testcomp"), + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), +] + + +@pytest.mark.parametrize("name, fields", fetch_cases) +def test_fetch_info(conn, testcomp, name, fields): + info = CompositeInfo.fetch(conn, name) + assert info.name == "testcomp" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.field_names) == 3 + assert len(info.field_types) == 3 + for i, (name, t) in enumerate(fields): + assert info.field_names[i] == name + assert info.field_types[i] == builtins[t].oid + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name, fields", fetch_cases) +async def test_fetch_info_async(aconn, testcomp, name, fields): + info = await CompositeInfo.fetch(aconn, name) + assert info.name == "testcomp" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.field_names) == 3 + assert len(info.field_types) == 3 + for i, (name, t) in enumerate(fields): + assert info.field_names[i] == name + assert info.field_types[i] == builtins[t].oid + + +@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT]) +def test_dump_tuple_all_chars(conn, fmt_in, testcomp): + cur = conn.cursor() + for i in range(1, 256): + (res,) = cur.execute( + f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in.value}::testcomp", + (i, (chr(i), 1, 1.0)), + ).fetchone() + assert res is True + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_composite_all_chars(conn, fmt_in, testcomp): + cur = conn.cursor() + register_composite(testcomp, cur) + factory = testcomp.python_type + for i in range(1, 256): + obj = factory(chr(i), 1, 1.0) + (res,) = cur.execute( + f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in.value}", (i, obj) + ).fetchone() + assert res is True + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_composite_null(conn, fmt_in, testcomp): + cur = conn.cursor() + register_composite(testcomp, cur) + factory = testcomp.python_type + + obj = factory("foo", 1, None) + rec = cur.execute( + f""" + select row('foo', 1, 
NULL)::testcomp = %(obj){fmt_in.value}, + %(obj){fmt_in.value}::text + """, + {"obj": obj}, + ).fetchone() + assert rec[0] is True, rec[1] + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_composite(conn, testcomp, fmt_out): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info, conn) + + cur = conn.cursor(binary=fmt_out) + res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0] + assert res.foo == "hello" + assert res.bar == 10 + assert res.baz == 20.0 + assert isinstance(res.baz, float) + + res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0] + assert len(res) == 1 + assert res[0].baz == 30.0 + assert isinstance(res[0].baz, float) + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_composite_factory(conn, testcomp, fmt_out): + info = CompositeInfo.fetch(conn, "testcomp") + + class MyThing: + def __init__(self, *args): + self.foo, self.bar, self.baz = args + + register_composite(info, conn, factory=MyThing) + assert info.python_type is MyThing + + cur = conn.cursor(binary=fmt_out) + res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0] + assert isinstance(res, MyThing) + assert res.baz == 20.0 + assert isinstance(res.baz, float) + + res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0] + assert len(res) == 1 + assert res[0].baz == 30.0 + assert isinstance(res[0].baz, float) + + +def test_register_scope(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info) + for fmt in pq.Format: + for oid in (info.oid, info.array_oid): + assert postgres.adapters._loaders[fmt].pop(oid) + + for f in PyFormat: + assert postgres.adapters._dumpers[f].pop(info.python_type) + + cur = conn.cursor() + register_composite(info, cur) + for fmt in pq.Format: + for oid in (info.oid, info.array_oid): + assert oid not in postgres.adapters._loaders[fmt] + assert oid not in conn.adapters._loaders[fmt] + assert oid in cur.adapters._loaders[fmt] + + register_composite(info, conn) + for fmt in pq.Format: + for oid in (info.oid, info.array_oid): + assert oid not in postgres.adapters._loaders[fmt] + assert oid in conn.adapters._loaders[fmt] + + +def test_type_dumper_registered(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info, conn) + assert issubclass(info.python_type, tuple) + assert info.python_type.__name__ == "testcomp" + d = conn.adapters.get_dumper(info.python_type, "s") + assert issubclass(d, TupleDumper) + assert d is not TupleDumper + + tc = info.python_type("foo", 42, 3.14) + cur = conn.execute("select pg_typeof(%s)", [tc]) + assert cur.fetchone()[0] == "testcomp" + + +def test_type_dumper_registered_binary(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info, conn) + assert issubclass(info.python_type, tuple) + assert info.python_type.__name__ == "testcomp" + d = conn.adapters.get_dumper(info.python_type, "b") + assert issubclass(d, TupleBinaryDumper) + assert d is not TupleBinaryDumper + + tc = info.python_type("foo", 42, 3.14) + cur = conn.execute("select pg_typeof(%b)", [tc]) + assert cur.fetchone()[0] == "testcomp" + + +def test_callable_dumper_not_registered(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + + def fac(*args): + return args + (args[-1],) + + register_composite(info, conn, factory=fac) + assert info.python_type is None + + # but the loader is registered + cur = conn.execute("select '(foo,42,3.14)'::testcomp") + assert 
cur.fetchone()[0] == ("foo", 42, 3.14, 3.14) + + +def test_no_info_error(conn): + with pytest.raises(TypeError, match="composite"): + register_composite(None, conn) # type: ignore[arg-type] + + +def test_invalid_fields_names(conn): + conn.execute("set client_encoding to utf8") + conn.execute( + f""" + create type "a-b" as ("c-d" text, "{eur}" int); + create type "-x-{eur}" as ("w-ww" "a-b", "0" int); + """ + ) + ab = CompositeInfo.fetch(conn, '"a-b"') + x = CompositeInfo.fetch(conn, f'"-x-{eur}"') + register_composite(ab, conn) + register_composite(x, conn) + obj = x.python_type(ab.python_type("foo", 10), 20) + conn.execute(f"""create table meh (wat "-x-{eur}")""") + conn.execute("insert into meh values (%s)", [obj]) + got = conn.execute("select wat from meh").fetchone()[0] + assert obj == got + + +@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "1", "'"]) +def test_literal_invalid_name(conn, name): + conn.execute("set client_encoding to utf8") + conn.execute( + sql.SQL("create type {name} as (foo text)").format(name=sql.Identifier(name)) + ) + info = CompositeInfo.fetch(conn, sql.Identifier(name).as_string(conn)) + register_composite(info, conn) + obj = info.python_type("hello") + assert sql.Literal(obj).as_string(conn) == f"'(hello)'::\"{name}\"" + cur = conn.execute(sql.SQL("select {}").format(obj)) + got = cur.fetchone()[0] + assert got == obj + assert type(got) is type(obj) + + +@pytest.mark.parametrize( + "name, attr", + [ + ("a-b", "a_b"), + (f"{eur}", "f_"), + ("üåäö", "üåäö"), + ("order", "order"), + ("1", "f1"), + ], +) +def test_literal_invalid_attr(conn, name, attr): + conn.execute("set client_encoding to utf8") + conn.execute( + sql.SQL("create type test_attr as ({name} text)").format( + name=sql.Identifier(name) + ) + ) + info = CompositeInfo.fetch(conn, "test_attr") + register_composite(info, conn) + obj = info.python_type("hello") + assert getattr(obj, attr) == "hello" + cur = conn.execute(sql.SQL("select {}").format(obj)) + got = cur.fetchone()[0] + assert got == obj + assert type(got) is type(obj) diff --git a/tests/types/test_datetime.py b/tests/types/test_datetime.py new file mode 100644 index 0000000..11fe493 --- /dev/null +++ b/tests/types/test_datetime.py @@ -0,0 +1,813 @@ +import datetime as dt + +import pytest + +from psycopg import DataError, pq, sql +from psycopg.adapt import PyFormat + +crdb_skip_datestyle = pytest.mark.crdb("skip", reason="set datestyle/intervalstyle") +crdb_skip_negative_interval = pytest.mark.crdb("skip", reason="negative interval") +crdb_skip_invalid_tz = pytest.mark.crdb( + "skip", reason="crdb doesn't allow invalid timezones" +) + +datestyles_in = [ + pytest.param(datestyle, marks=crdb_skip_datestyle) + for datestyle in ["DMY", "MDY", "YMD"] +] +datestyles_out = [ + pytest.param(datestyle, marks=crdb_skip_datestyle) + for datestyle in ["ISO", "Postgres", "SQL", "German"] +] + +intervalstyles = [ + pytest.param(datestyle, marks=crdb_skip_datestyle) + for datestyle in ["sql_standard", "postgres", "postgres_verbose", "iso_8601"] +] + + +class TestDate: + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "0001-01-01"), + ("1000,1,1", "1000-01-01"), + ("2000,1,1", "2000-01-01"), + ("2000,12,31", "2000-12-31"), + ("3000,1,1", "3000-01-01"), + ("max", "9999-12-31"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_date(self, conn, val, expr, fmt_in): + val = as_date(val) + cur = conn.cursor() + cur.execute(f"select '{expr}'::date = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + 
cur.execute( + sql.SQL("select {}::date = {}").format( + sql.Literal(val), sql.Placeholder(format=fmt_in) + ), + (val,), + ) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_dump_date_datestyle(self, conn, datestyle_in): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = ISO,{datestyle_in}") + cur.execute("select 'epoch'::date + 1 = %t", (dt.date(1970, 1, 2),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "0001-01-01"), + ("1000,1,1", "1000-01-01"), + ("2000,1,1", "2000-01-01"), + ("2000,12,31", "2000-12-31"), + ("3000,1,1", "3000-01-01"), + ("max", "9999-12-31"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_date(self, conn, val, expr, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select '{expr}'::date") + assert cur.fetchone()[0] == as_date(val) + + @pytest.mark.parametrize("datestyle_out", datestyles_out) + def test_load_date_datestyle(self, conn, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select '2000-01-02'::date") + assert cur.fetchone()[0] == dt.date(2000, 1, 2) + + @pytest.mark.parametrize("val", ["min", "max"]) + @pytest.mark.parametrize("datestyle_out", datestyles_out) + def test_load_date_overflow(self, conn, val, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select %t + %s::int", (as_date(val), -1 if val == "min" else 1)) + with pytest.raises(DataError): + cur.fetchone()[0] + + @pytest.mark.parametrize("val", ["min", "max"]) + def test_load_date_overflow_binary(self, conn, val): + cur = conn.cursor(binary=True) + cur.execute("select %s + %s::int", (as_date(val), -1 if val == "min" else 1)) + with pytest.raises(DataError): + cur.fetchone()[0] + + overflow_samples = [ + ("-infinity", "date too small"), + ("1000-01-01 BC", "date too small"), + ("10000-01-01", "date too large"), + ("infinity", "date too large"), + ] + + @pytest.mark.parametrize("datestyle_out", datestyles_out) + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_load_overflow_message(self, conn, datestyle_out, val, msg): + cur = conn.cursor() + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select %s::date", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_load_overflow_message_binary(self, conn, val, msg): + cur = conn.cursor(binary=True) + cur.execute("select %s::date", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + def test_infinity_date_example(self, conn): + # NOTE: this is an example in the docs. 
Make sure it doesn't regress when + # adding binary datetime adapters + from datetime import date + from psycopg.types.datetime import DateLoader, DateDumper + + class InfDateDumper(DateDumper): + def dump(self, obj): + if obj == date.max: + return b"infinity" + else: + return super().dump(obj) + + class InfDateLoader(DateLoader): + def load(self, data): + if data == b"infinity": + return date.max + else: + return super().load(data) + + cur = conn.cursor() + cur.adapters.register_dumper(date, InfDateDumper) + cur.adapters.register_loader("date", InfDateLoader) + + rec = cur.execute( + "SELECT %s::text, %s::text", [date(2020, 12, 31), date.max] + ).fetchone() + assert rec == ("2020-12-31", "infinity") + rec = cur.execute("select '2020-12-31'::date, 'infinity'::date").fetchone() + assert rec == (date(2020, 12, 31), date(9999, 12, 31)) + + +class TestDatetime: + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "0001-01-01 00:00"), + ("258,1,8,1,12,32,358261", "0258-1-8 1:12:32.358261"), + ("1000,1,1,0,0", "1000-01-01 00:00"), + ("2000,1,1,0,0", "2000-01-01 00:00"), + ("2000,1,2,3,4,5,6", "2000-01-02 03:04:05.000006"), + ("2000,1,2,3,4,5,678", "2000-01-02 03:04:05.000678"), + ("2000,1,2,3,0,0,456789", "2000-01-02 03:00:00.456789"), + ("2000,1,1,0,0,0,1", "2000-01-01 00:00:00.000001"), + ("2034,02,03,23,34,27,951357", "2034-02-03 23:34:27.951357"), + ("2200,1,1,0,0,0,1", "2200-01-01 00:00:00.000001"), + ("2300,1,1,0,0,0,1", "2300-01-01 00:00:00.000001"), + ("7000,1,1,0,0,0,1", "7000-01-01 00:00:00.000001"), + ("max", "9999-12-31 23:59:59.999999"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_datetime(self, conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute("set timezone to '+02:00'") + cur.execute(f"select %{fmt_in.value}", (as_dt(val),)) + cur.execute(f"select '{expr}'::timestamp = %{fmt_in.value}", (as_dt(val),)) + cur.execute( + f""" + select '{expr}'::timestamp = %(val){fmt_in.value}, + '{expr}', %(val){fmt_in.value}::text + """, + {"val": as_dt(val)}, + ) + ok, want, got = cur.fetchone() + assert ok, (want, got) + + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_dump_datetime_datestyle(self, conn, datestyle_in): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = ISO, {datestyle_in}") + cur.execute( + "select 'epoch'::timestamp + '1d 3h 4m 5s'::interval = %t", + (dt.datetime(1970, 1, 2, 3, 4, 5),), + ) + assert cur.fetchone()[0] is True + + load_datetime_samples = [ + ("min", "0001-01-01"), + ("1000,1,1", "1000-01-01"), + ("2000,1,1", "2000-01-01"), + ("2000,1,2,3,4,5,6", "2000-01-02 03:04:05.000006"), + ("2000,1,2,3,4,5,678", "2000-01-02 03:04:05.000678"), + ("2000,1,2,3,0,0,456789", "2000-01-02 03:00:00.456789"), + ("2000,12,31", "2000-12-31"), + ("3000,1,1", "3000-01-01"), + ("max", "9999-12-31 23:59:59.999999"), + ] + + @pytest.mark.parametrize("val, expr", load_datetime_samples) + @pytest.mark.parametrize("datestyle_out", datestyles_out) + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_load_datetime(self, conn, val, expr, datestyle_in, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}") + cur.execute("set timezone to '+02:00'") + cur.execute(f"select '{expr}'::timestamp") + assert cur.fetchone()[0] == as_dt(val) + + @pytest.mark.parametrize("val, expr", load_datetime_samples) + def test_load_datetime_binary(self, conn, val, expr): + cur = conn.cursor(binary=True) + cur.execute("set timezone to '+02:00'") + 
cur.execute(f"select '{expr}'::timestamp") + assert cur.fetchone()[0] == as_dt(val) + + @pytest.mark.parametrize("val", ["min", "max"]) + @pytest.mark.parametrize("datestyle_out", datestyles_out) + def test_load_datetime_overflow(self, conn, val, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute( + "select %t::timestamp + %s * '1s'::interval", + (as_dt(val), -1 if val == "min" else 1), + ) + with pytest.raises(DataError): + cur.fetchone()[0] + + @pytest.mark.parametrize("val", ["min", "max"]) + def test_load_datetime_overflow_binary(self, conn, val): + cur = conn.cursor(binary=True) + cur.execute( + "select %t::timestamp + %s * '1s'::interval", + (as_dt(val), -1 if val == "min" else 1), + ) + with pytest.raises(DataError): + cur.fetchone()[0] + + overflow_samples = [ + ("-infinity", "timestamp too small"), + ("1000-01-01 12:00 BC", "timestamp too small"), + ("10000-01-01 12:00", "timestamp too large"), + ("infinity", "timestamp too large"), + ] + + @pytest.mark.parametrize("datestyle_out", datestyles_out) + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_overflow_message(self, conn, datestyle_out, val, msg): + cur = conn.cursor() + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select %s::timestamp", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_overflow_message_binary(self, conn, val, msg): + cur = conn.cursor(binary=True) + cur.execute("select %s::timestamp", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + @crdb_skip_datestyle + def test_load_all_month_names(self, conn): + cur = conn.cursor(binary=False) + cur.execute("set datestyle = 'Postgres'") + for i in range(12): + d = dt.datetime(2000, i + 1, 15) + cur.execute("select %s", [d]) + assert cur.fetchone()[0] == d + + +class TestDateTimeTz: + @pytest.mark.parametrize( + "val, expr", + [ + ("min~-2", "0001-01-01 00:00-02:00"), + ("min~-12", "0001-01-01 00:00-12:00"), + ( + "258,1,8,1,12,32,358261~1:2:3", + "0258-1-8 1:12:32.358261+01:02:03", + ), + ("1000,1,1,0,0~2", "1000-01-01 00:00+2"), + ("2000,1,1,0,0~2", "2000-01-01 00:00+2"), + ("2000,1,1,0,0~12", "2000-01-01 00:00+12"), + ("2000,1,1,0,0~-12", "2000-01-01 00:00-12"), + ("2000,1,1,0,0~01:02:03", "2000-01-01 00:00+01:02:03"), + ("2000,1,1,0,0~-01:02:03", "2000-01-01 00:00-01:02:03"), + ("2000,12,31,23,59,59,999999~2", "2000-12-31 23:59:59.999999+2"), + ( + "2034,02,03,23,34,27,951357~-4:27", + "2034-02-03 23:34:27.951357-04:27", + ), + ("2300,1,1,0,0,0,1~1", "2300-01-01 00:00:00.000001+1"), + ("3000,1,1,0,0~2", "3000-01-01 00:00+2"), + ("7000,1,1,0,0,0,1~-1:2:3", "7000-01-01 00:00:00.000001-01:02:03"), + ("max~2", "9999-12-31 23:59:59.999999"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_datetimetz(self, conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute("set timezone to '-02:00'") + cur.execute( + f""" + select '{expr}'::timestamptz = %(val){fmt_in.value}, + '{expr}', %(val){fmt_in.value}::text + """, + {"val": as_dt(val)}, + ) + ok, want, got = cur.fetchone() + assert ok, (want, got) + + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_dump_datetimetz_datestyle(self, conn, datestyle_in): + tzinfo = dt.timezone(dt.timedelta(hours=2)) + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = ISO, {datestyle_in}") + 
cur.execute("set timezone to '-02:00'") + cur.execute( + "select 'epoch'::timestamptz + '1d 3h 4m 5.678s'::interval = %t", + (dt.datetime(1970, 1, 2, 5, 4, 5, 678000, tzinfo=tzinfo),), + ) + assert cur.fetchone()[0] is True + + load_datetimetz_samples = [ + ("2000,1,1~2", "2000-01-01", "-02:00"), + ("2000,1,2,3,4,5,6~2", "2000-01-02 03:04:05.000006", "-02:00"), + ("2000,1,2,3,4,5,678~1", "2000-01-02 03:04:05.000678", "Europe/Rome"), + ("2000,7,2,3,4,5,678~2", "2000-07-02 03:04:05.000678", "Europe/Rome"), + ("2000,1,2,3,0,0,456789~2", "2000-01-02 03:00:00.456789", "-02:00"), + ("2000,1,2,3,0,0,456789~-2", "2000-01-02 03:00:00.456789", "+02:00"), + ("2000,12,31~2", "2000-12-31", "-02:00"), + ("1900,1,1~05:21:10", "1900-01-01", "Asia/Calcutta"), + ] + + @crdb_skip_datestyle + @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples) + @pytest.mark.parametrize("datestyle_out", ["ISO"]) + def test_load_datetimetz(self, conn, val, expr, timezone, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, DMY") + cur.execute(f"set timezone to '{timezone}'") + got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0] + assert got == as_dt(val) + + @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples) + def test_load_datetimetz_binary(self, conn, val, expr, timezone): + cur = conn.cursor(binary=True) + cur.execute(f"set timezone to '{timezone}'") + got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0] + assert got == as_dt(val) + + @pytest.mark.xfail # parse timezone names + @crdb_skip_datestyle + @pytest.mark.parametrize("val, expr", [("2000,1,1~2", "2000-01-01")]) + @pytest.mark.parametrize("datestyle_out", ["SQL", "Postgres", "German"]) + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_load_datetimetz_tzname(self, conn, val, expr, datestyle_in, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}") + cur.execute("set timezone to '-02:00'") + cur.execute(f"select '{expr}'::timestamptz") + assert cur.fetchone()[0] == as_dt(val) + + @pytest.mark.parametrize( + "tzname, expr, tzoff", + [ + ("UTC", "2000-1-1", 0), + ("UTC", "2000-7-1", 0), + ("Europe/Rome", "2000-1-1", 3600), + ("Europe/Rome", "2000-7-1", 7200), + ("Europe/Rome", "1000-1-1", 2996), + pytest.param("NOSUCH0", "2000-1-1", 0, marks=crdb_skip_invalid_tz), + pytest.param("0", "2000-1-1", 0, marks=crdb_skip_invalid_tz), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_datetimetz_tz(self, conn, fmt_out, tzname, expr, tzoff): + conn.execute("select set_config('TimeZone', %s, true)", [tzname]) + cur = conn.cursor(binary=fmt_out) + ts = cur.execute("select %s::timestamptz", [expr]).fetchone()[0] + assert ts.utcoffset().total_seconds() == tzoff + + @pytest.mark.parametrize( + "val, type", + [ + ("2000,1,2,3,4,5,6", "timestamp"), + ("2000,1,2,3,4,5,6~0", "timestamptz"), + ("2000,1,2,3,4,5,6~2", "timestamptz"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_datetime_tz_or_not_tz(self, conn, val, type, fmt_in): + val = as_dt(val) + cur = conn.cursor() + cur.execute( + f""" + select pg_typeof(%{fmt_in.value})::regtype = %s::regtype, %{fmt_in.value} + """, + [val, type, val], + ) + rec = cur.fetchone() + assert rec[0] is True, type + assert rec[1] == val + + @pytest.mark.crdb_skip("copy") + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '2000-01-01 
01:02:03.123456-10:20'::timestamptz, + '11111111'::int4 + ) to stdout + """ + ) as copy: + copy.set_types(["timestamptz", "int4"]) + rec = copy.read_row() + + tz = dt.timezone(-dt.timedelta(hours=10, minutes=20)) + want = dt.datetime(2000, 1, 1, 1, 2, 3, 123456, tzinfo=tz) + assert rec[0] == want + assert rec[1] == 11111111 + + overflow_samples = [ + ("-infinity", "timestamp too small"), + ("1000-01-01 12:00+00 BC", "timestamp too small"), + ("10000-01-01 12:00+00", "timestamp too large"), + ("infinity", "timestamp too large"), + ] + + @pytest.mark.parametrize("datestyle_out", datestyles_out) + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_overflow_message(self, conn, datestyle_out, val, msg): + cur = conn.cursor() + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select %s::timestamptz", (val,)) + if datestyle_out == "ISO": + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + else: + with pytest.raises(NotImplementedError): + cur.fetchone()[0] + + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_overflow_message_binary(self, conn, val, msg): + cur = conn.cursor(binary=True) + cur.execute("select %s::timestamptz", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + @pytest.mark.parametrize( + "valname, tzval, tzname", + [ + ("max", "-06", "America/Chicago"), + ("min", "+09:18:59", "Asia/Tokyo"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_max_with_timezone(self, conn, fmt_out, valname, tzval, tzname): + # This happens e.g. in Django when it caches forever. + # e.g. see Django test cache.tests.DBCacheTests.test_forever_timeout + val = getattr(dt.datetime, valname).replace(microsecond=0) + tz = dt.timezone(as_tzoffset(tzval)) + want = val.replace(tzinfo=tz) + + conn.execute("set timezone to '%s'" % tzname) + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::timestamptz", [str(val) + tzval]) + got = cur.fetchone()[0] + + assert got == want + + extra = "1 day" if valname == "max" else "-1 day" + with pytest.raises(DataError): + cur.execute( + "select %s::timestamptz + %s::interval", + [str(val) + tzval, extra], + ) + got = cur.fetchone()[0] + + +class TestTime: + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "00:00"), + ("10,20,30,40", "10:20:30.000040"), + ("max", "23:59:59.999999"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_time(self, conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute( + f""" + select '{expr}'::time = %(val){fmt_in.value}, + '{expr}'::time::text, %(val){fmt_in.value}::text + """, + {"val": as_time(val)}, + ) + ok, want, got = cur.fetchone() + assert ok, (got, want) + + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "00:00"), + ("1,2", "01:02"), + ("10,20", "10:20"), + ("10,20,30", "10:20:30"), + ("10,20,30,40", "10:20:30.000040"), + ("max", "23:59:59.999999"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_time(self, conn, val, expr, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select '{expr}'::time") + assert cur.fetchone()[0] == as_time(val) + + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_time_24(self, conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select '24:00'::time") + with pytest.raises(DataError): + cur.fetchone()[0] + + +class TestTimeTz: + @pytest.mark.parametrize( + "val, expr", + [ + ("min~-10", "00:00-10:00"), + ("min~+12", 
"00:00+12:00"), + ("10,20,30,40~-2", "10:20:30.000040-02:00"), + ("10,20,30,40~0", "10:20:30.000040Z"), + ("10,20,30,40~+2:30", "10:20:30.000040+02:30"), + ("max~-12", "23:59:59.999999-12:00"), + ("max~+12", "23:59:59.999999+12:00"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_timetz(self, conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute("set timezone to '-02:00'") + cur.execute(f"select '{expr}'::timetz = %{fmt_in.value}", (as_time(val),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize( + "val, expr, timezone", + [ + ("0,0~-12", "00:00", "12:00"), + ("0,0~12", "00:00", "-12:00"), + ("3,4,5,6~2", "03:04:05.000006", "-02:00"), + ("3,4,5,6~7:8", "03:04:05.000006", "-07:08"), + ("3,0,0,456789~2", "03:00:00.456789", "-02:00"), + ("3,0,0,456789~-2", "03:00:00.456789", "+02:00"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_timetz(self, conn, val, timezone, expr, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"set timezone to '{timezone}'") + cur.execute(f"select '{expr}'::timetz") + assert cur.fetchone()[0] == as_time(val) + + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_timetz_24(self, conn, fmt_out): + cur = conn.cursor() + cur.execute("select '24:00'::timetz") + with pytest.raises(DataError): + cur.fetchone()[0] + + @pytest.mark.parametrize( + "val, type", + [ + ("3,4,5,6", "time"), + ("3,4,5,6~0", "timetz"), + ("3,4,5,6~2", "timetz"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_time_tz_or_not_tz(self, conn, val, type, fmt_in): + val = as_time(val) + cur = conn.cursor() + cur.execute( + f""" + select pg_typeof(%{fmt_in.value})::regtype = %s::regtype, %{fmt_in.value} + """, + [val, type, val], + ) + rec = cur.fetchone() + assert rec[0] is True, type + assert rec[1] == val + + @pytest.mark.crdb_skip("copy") + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '01:02:03.123456-10:20'::timetz, + '11111111'::int4 + ) to stdout + """ + ) as copy: + copy.set_types(["timetz", "int4"]) + rec = copy.read_row() + + tz = dt.timezone(-dt.timedelta(hours=10, minutes=20)) + want = dt.time(1, 2, 3, 123456, tzinfo=tz) + assert rec[0] == want + assert rec[1] == 11111111 + + +class TestInterval: + dump_timedelta_samples = [ + ("min", "-999999999 days"), + ("1d", "1 day"), + pytest.param("-1d", "-1 day", marks=crdb_skip_negative_interval), + ("1s", "1 s"), + pytest.param("-1s", "-1 s", marks=crdb_skip_negative_interval), + pytest.param("-1m", "-0.000001 s", marks=crdb_skip_negative_interval), + ("1m", "0.000001 s"), + ("max", "999999999 days 23:59:59.999999"), + ] + + @pytest.mark.parametrize("val, expr", dump_timedelta_samples) + @pytest.mark.parametrize("intervalstyle", intervalstyles) + def test_dump_interval(self, conn, val, expr, intervalstyle): + cur = conn.cursor() + cur.execute(f"set IntervalStyle to '{intervalstyle}'") + cur.execute(f"select '{expr}'::interval = %t", (as_td(val),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize("val, expr", dump_timedelta_samples) + def test_dump_interval_binary(self, conn, val, expr): + cur = conn.cursor() + cur.execute(f"select '{expr}'::interval = %b", (as_td(val),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize( + "val, expr", + [ + ("1s", "1 sec"), + ("-1s", "-1 sec"), + ("60s", "1 min"), + ("3600s", "1 hour"), + ("1s,1000m", "1.001 sec"), + ("1s,1m", "1.000001 sec"), + ("1d", "1 day"), + ("-10d", "-10 day"), + ("1d,1s,1m", "1 
day 1.000001 sec"), + ("-86399s,-999999m", "-23:59:59.999999"), + ("-3723s,-400000m", "-1:2:3.4"), + ("3723s,400000m", "1:2:3.4"), + ("86399s,999999m", "23:59:59.999999"), + ("30d", "30 day"), + ("365d", "1 year"), + ("-365d", "-1 year"), + ("-730d", "-2 years"), + ("1460d", "4 year"), + ("30d", "1 month"), + ("-30d", "-1 month"), + ("60d", "2 month"), + ("-90d", "-3 month"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_interval(self, conn, val, expr, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select '{expr}'::interval") + assert cur.fetchone()[0] == as_td(val) + + @crdb_skip_datestyle + @pytest.mark.xfail # weird interval outputs + @pytest.mark.parametrize("val, expr", [("1d,1s", "1 day 1 sec")]) + @pytest.mark.parametrize( + "intervalstyle", + ["sql_standard", "postgres_verbose", "iso_8601"], + ) + def test_load_interval_intervalstyle(self, conn, val, expr, intervalstyle): + cur = conn.cursor(binary=False) + cur.execute(f"set IntervalStyle to '{intervalstyle}'") + cur.execute(f"select '{expr}'::interval") + assert cur.fetchone()[0] == as_td(val) + + @pytest.mark.parametrize("fmt_out", pq.Format) + @pytest.mark.parametrize("val", ["min", "max"]) + def test_load_interval_overflow(self, conn, val, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute( + "select %s + %s * '1s'::interval", + (as_td(val), -1 if val == "min" else 1), + ) + with pytest.raises(DataError): + cur.fetchone()[0] + + @pytest.mark.crdb_skip("copy") + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '1 days +00:00:01.000001'::interval, + 'foo bar'::text + ) to stdout + """ + ) as copy: + copy.set_types(["interval", "text"]) + rec = copy.read_row() + + want = dt.timedelta(days=1, seconds=1, microseconds=1) + assert rec[0] == want + assert rec[1] == "foo bar" + + +# +# Support +# + + +def as_date(s): + return dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s) + + +def as_time(s): + if "~" in s: + s, off = s.split("~") + else: + off = None + + if "," in s: + rv = dt.time(*map(int, s.split(","))) # type: ignore[arg-type] + else: + rv = getattr(dt.time, s) + if off: + rv = rv.replace(tzinfo=as_tzinfo(off)) + + return rv + + +def as_dt(s): + if "~" not in s: + return as_naive_dt(s) + + s, off = s.split("~") + rv = as_naive_dt(s) + off = as_tzoffset(off) + rv = (rv - off).replace(tzinfo=dt.timezone.utc) + return rv + + +def as_naive_dt(s): + if "," in s: + rv = dt.datetime(*map(int, s.split(","))) # type: ignore[arg-type] + else: + rv = getattr(dt.datetime, s) + + return rv + + +def as_tzoffset(s): + if s.startswith("-"): + mul = -1 + s = s[1:] + else: + mul = 1 + + fields = ("hours", "minutes", "seconds") + return mul * dt.timedelta(**dict(zip(fields, map(int, s.split(":"))))) + + +def as_tzinfo(s): + off = as_tzoffset(s) + return dt.timezone(off) + + +def as_td(s): + if s in ("min", "max"): + return getattr(dt.timedelta, s) + + suffixes = {"d": "days", "s": "seconds", "m": "microseconds"} + kwargs = {} + for part in s.split(","): + kwargs[suffixes[part[-1]]] = int(part[:-1]) + + return dt.timedelta(**kwargs) diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py new file mode 100644 index 0000000..8dfb6d4 --- /dev/null +++ b/tests/types/test_enum.py @@ -0,0 +1,363 @@ +from enum import Enum, auto + +import pytest + +from psycopg import pq, sql, errors as e +from psycopg.adapt import PyFormat +from psycopg.types import TypeInfo +from psycopg.types.enum import EnumInfo, register_enum + 
+from ..fix_crdb import crdb_encoding + + +class PureTestEnum(Enum): + FOO = auto() + BAR = auto() + BAZ = auto() + + +class StrTestEnum(str, Enum): + ONE = "ONE" + TWO = "TWO" + THREE = "THREE" + + +NonAsciiEnum = Enum( + "NonAsciiEnum", + {"X\xe0": "x\xe0", "X\xe1": "x\xe1", "COMMA": "foo,bar"}, + type=str, +) + + +class IntTestEnum(int, Enum): + ONE = 1 + TWO = 2 + THREE = 3 + + +enum_cases = [PureTestEnum, StrTestEnum, IntTestEnum] +encodings = ["utf8", crdb_encoding("latin1")] + + +@pytest.fixture(scope="session", autouse=True) +def make_test_enums(request, svcconn): + for enum in enum_cases + [NonAsciiEnum]: + ensure_enum(enum, svcconn) + + +def ensure_enum(enum, conn): + name = enum.__name__.lower() + labels = list(enum.__members__) + conn.execute( + sql.SQL( + """ + drop type if exists {name}; + create type {name} as enum ({labels}); + """ + ).format(name=sql.Identifier(name), labels=sql.SQL(",").join(labels)) + ) + return name, enum, labels + + +def test_fetch_info(conn): + info = EnumInfo.fetch(conn, "StrTestEnum") + assert info.name == "strtestenum" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.labels) == len(StrTestEnum) + assert info.labels == list(StrTestEnum.__members__) + + +@pytest.mark.asyncio +async def test_fetch_info_async(aconn): + info = await EnumInfo.fetch(aconn, "PureTestEnum") + assert info.name == "puretestenum" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.labels) == len(PureTestEnum) + assert info.labels == list(PureTestEnum.__members__) + + +def test_register_makes_a_type(conn): + info = EnumInfo.fetch(conn, "IntTestEnum") + assert info + assert info.enum is None + register_enum(info, context=conn) + assert info.enum is not None + assert [e.name for e in info.enum] == list(IntTestEnum.__members__) + + +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_loader(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum=enum) + + for label in info.labels: + cur = conn.execute( + f"select %{fmt_in.value}::{enum.__name__}", [label], binary=fmt_out + ) + assert cur.fetchone()[0] == enum[label] + + +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out): + enum = NonAsciiEnum + conn.execute(f"set client_encoding to {encoding}") + + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum=enum) + + for label in info.labels: + cur = conn.execute( + f"select %{fmt_in.value}::{info.name}", [label], binary=fmt_out + ) + assert cur.fetchone()[0] == enum[label] + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_loader_sqlascii(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + conn.execute("set client_encoding to sql_ascii") + + for label in info.labels: + cur = conn.execute( + f"select %{fmt_in.value}::{info.name}", [label], binary=fmt_out + ) + assert cur.fetchone()[0] == enum[label] + + +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_dumper(conn, enum, fmt_in, fmt_out): + info = 
EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + + for item in enum: + cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out) + assert cur.fetchone()[0] == item + + +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out): + enum = NonAsciiEnum + conn.execute(f"set client_encoding to {encoding}") + + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + + for item in enum: + cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out) + assert cur.fetchone()[0] == item + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_dumper_sqlascii(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + conn.execute("set client_encoding to sql_ascii") + + for item in enum: + cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out) + assert cur.fetchone()[0] == item + + +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_dumper(conn, enum, fmt_in, fmt_out): + for item in enum: + if enum is PureTestEnum: + want = item.name + else: + want = item.value + + cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out) + assert cur.fetchone()[0] == want + + +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") + for item in NonAsciiEnum: + cur = conn.execute(f"select %{fmt_in.value}", [item.value], binary=fmt_out) + assert cur.fetchone()[0] == item.value + + +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_loader(conn, enum, fmt_in, fmt_out): + for label in enum.__members__: + cur = conn.execute( + f"select %{fmt_in.value}::{enum.__name__}", [label], binary=fmt_out + ) + want = enum[label].name + if fmt_out == pq.Format.BINARY: + want = want.encode() + assert cur.fetchone()[0] == want + + +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") + + for label in NonAsciiEnum.__members__: + cur = conn.execute( + f"select %{fmt_in.value}::nonasciienum", [label], binary=fmt_out + ) + if fmt_out == pq.Format.TEXT: + assert cur.fetchone()[0] == label + else: + assert cur.fetchone()[0] == label.encode(encoding) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_array_loader(conn, fmt_in, fmt_out): + enum = PureTestEnum + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + + labels = list(enum.__members__) + cur = conn.execute( + f"select %{fmt_in.value}::{info.name}[]", [labels], binary=fmt_out + ) + assert cur.fetchone()[0] == list(enum) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def 
test_enum_array_dumper(conn, fmt_in, fmt_out): + enum = StrTestEnum + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + + cur = conn.execute(f"select %{fmt_in.value}::text[]", [list(enum)], binary=fmt_out) + assert cur.fetchone()[0] == list(enum.__members__) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_array_loader(conn, fmt_in, fmt_out): + enum = IntTestEnum + info = TypeInfo.fetch(conn, enum.__name__) + info.register(conn) + labels = list(enum.__members__) + cur = conn.execute( + f"select %{fmt_in.value}::{info.name}[]", [labels], binary=fmt_out + ) + if fmt_out == pq.Format.TEXT: + assert cur.fetchone()[0] == labels + else: + assert cur.fetchone()[0] == [item.encode() for item in labels] + + +def test_enum_error(conn): + conn.autocommit = True + + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, StrTestEnum) + + with pytest.raises(e.DataError): + conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone() + with pytest.raises(e.DataError): + conn.execute("select 'BAR'::puretestenum").fetchone() + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize( + "mapping", + [ + {StrTestEnum.ONE: "FOO", StrTestEnum.TWO: "BAR", StrTestEnum.THREE: "BAZ"}, + [ + (StrTestEnum.ONE, "FOO"), + (StrTestEnum.TWO, "BAR"), + (StrTestEnum.THREE, "BAZ"), + ], + ], +) +def test_remap(conn, fmt_in, fmt_out, mapping): + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, StrTestEnum, mapping=mapping) + + for member, label in [("ONE", "FOO"), ("TWO", "BAR"), ("THREE", "BAZ")]: + cur = conn.execute(f"select %{fmt_in.value}::text", [StrTestEnum[member]]) + assert cur.fetchone()[0] == label + cur = conn.execute(f"select '{label}'::puretestenum", binary=fmt_out) + assert cur.fetchone()[0] is StrTestEnum[member] + + +def test_remap_rename(conn): + enum = Enum("RenamedEnum", "FOO BAR QUX") + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, enum, mapping={enum.QUX: "BAZ"}) + + for member, label in [("FOO", "FOO"), ("BAR", "BAR"), ("QUX", "BAZ")]: + cur = conn.execute("select %s::text", [enum[member]]) + assert cur.fetchone()[0] == label + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[member] + + +def test_remap_more_python(conn): + enum = Enum("LargerEnum", "FOO BAR BAZ QUX QUUX QUUUX") + info = EnumInfo.fetch(conn, "puretestenum") + mapping = {enum[m]: "BAZ" for m in ["QUX", "QUUX", "QUUUX"]} + register_enum(info, conn, enum, mapping=mapping) + + for member, label in [("FOO", "FOO"), ("BAZ", "BAZ"), ("QUUUX", "BAZ")]: + cur = conn.execute("select %s::text", [enum[member]]) + assert cur.fetchone()[0] == label + + for member, label in [("FOO", "FOO"), ("QUUUX", "BAZ")]: + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[member] + + +def test_remap_more_postgres(conn): + enum = Enum("SmallerEnum", "FOO") + info = EnumInfo.fetch(conn, "puretestenum") + mapping = [(enum.FOO, "BAR"), (enum.FOO, "BAZ")] + register_enum(info, conn, enum, mapping=mapping) + + cur = conn.execute("select %s::text", [enum.FOO]) + assert cur.fetchone()[0] == "BAZ" + + for label in PureTestEnum.__members__: + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum.FOO + + +def test_remap_by_value(conn): + enum = Enum( # type: ignore + "ByValue", + {m.lower(): m for m in PureTestEnum.__members__}, + 
) + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, enum, mapping={m: m.value for m in enum}) + + for label in PureTestEnum.__members__: + cur = conn.execute("select %s::text", [enum[label.lower()]]) + assert cur.fetchone()[0] == label + + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[label.lower()] diff --git a/tests/types/test_hstore.py b/tests/types/test_hstore.py new file mode 100644 index 0000000..5142d58 --- /dev/null +++ b/tests/types/test_hstore.py @@ -0,0 +1,107 @@ +import pytest + +import psycopg +from psycopg.types import TypeInfo +from psycopg.types.hstore import HstoreLoader, register_hstore + +pytestmark = pytest.mark.crdb_skip("hstore") + + +@pytest.mark.parametrize( + "s, d", + [ + ("", {}), + ('"a"=>"1", "b"=>"2"', {"a": "1", "b": "2"}), + ('"a" => "1" , "b" => "2"', {"a": "1", "b": "2"}), + ('"a"=>NULL, "b"=>"2"', {"a": None, "b": "2"}), + (r'"a"=>"\"", "\""=>"2"', {"a": '"', '"': "2"}), + ('"a"=>"\'", "\'"=>"2"', {"a": "'", "'": "2"}), + ('"a"=>"1", "b"=>NULL', {"a": "1", "b": None}), + (r'"a\\"=>"1"', {"a\\": "1"}), + (r'"a\""=>"1"', {'a"': "1"}), + (r'"a\\\""=>"1"', {r"a\"": "1"}), + (r'"a\\\\\""=>"1"', {r'a\\"': "1"}), + ('"\xe8"=>"\xe0"', {"\xe8": "\xe0"}), + ], +) +def test_parse_ok(s, d): + loader = HstoreLoader(0, None) + assert loader.load(s.encode()) == d + + +@pytest.mark.parametrize( + "s", + [ + "a", + '"a"', + r'"a\\""=>"1"', + r'"a\\\\""=>"1"', + '"a=>"1"', + '"a"=>"1", "b"=>NUL', + ], +) +def test_parse_bad(s): + with pytest.raises(psycopg.DataError): + loader = HstoreLoader(0, None) + loader.load(s.encode()) + + +def test_register_conn(hstore, conn): + info = TypeInfo.fetch(conn, "hstore") + register_hstore(info, conn) + assert conn.adapters.types[info.oid].name == "hstore" + + cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + assert cur.fetchone() == (None, {}, {"a": "b"}) + + +def test_register_curs(hstore, conn): + info = TypeInfo.fetch(conn, "hstore") + cur = conn.cursor() + register_hstore(info, cur) + assert conn.adapters.types.get(info.oid) is None + assert cur.adapters.types[info.oid].name == "hstore" + + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + assert cur.fetchone() == (None, {}, {"a": "b"}) + + +def test_register_globally(conn_cls, hstore, dsn, svcconn, global_adapters): + info = TypeInfo.fetch(svcconn, "hstore") + register_hstore(info) + assert psycopg.adapters.types[info.oid].name == "hstore" + + assert svcconn.adapters.types.get(info.oid) is None + conn = conn_cls.connect(dsn) + assert conn.adapters.types[info.oid].name == "hstore" + + cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + assert cur.fetchone() == (None, {}, {"a": "b"}) + conn.close() + + +ab = list(map(chr, range(32, 128))) +samp = [ + {}, + {"a": "b", "c": None}, + dict(zip(ab, ab)), + {"".join(ab): "".join(ab)}, +] + + +@pytest.mark.parametrize("d", samp) +def test_roundtrip(hstore, conn, d): + register_hstore(TypeInfo.fetch(conn, "hstore"), conn) + d1 = conn.execute("select %s", [d]).fetchone()[0] + assert d == d1 + + +def test_roundtrip_array(hstore, conn): + register_hstore(TypeInfo.fetch(conn, "hstore"), conn) + samp1 = conn.execute("select %s", (samp,)).fetchone()[0] + assert samp1 == samp + + +def test_no_info_error(conn): + with pytest.raises(TypeError, match="hstore.*extension"): + register_hstore(None, conn) # type: ignore[arg-type] diff --git a/tests/types/test_json.py b/tests/types/test_json.py new file mode 100644 index 
0000000..50e8ce3 --- /dev/null +++ b/tests/types/test_json.py @@ -0,0 +1,182 @@ +import json +from copy import deepcopy + +import pytest + +import psycopg.types +from psycopg import pq +from psycopg import sql +from psycopg.adapt import PyFormat +from psycopg.types.json import set_json_dumps, set_json_loads + +samples = [ + "null", + "true", + '"te\'xt"', + '"\\u00e0\\u20ac"', + "123", + "123.45", + '["a", 100]', + '{"a": 100}', +] + + +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_wrapper_regtype(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + cur = conn.cursor() + cur.execute( + f"select pg_typeof(%{fmt_in.value})::regtype = %s::regtype", + (wrapper([]), wrapper.__name__.lower()), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump(conn, val, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = json.loads(val) + cur = conn.cursor() + cur.execute( + f"select %{fmt_in.value}::text = %s::{wrapper.__name__.lower()}::text", + (wrapper(obj), val), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.crdb_skip("json array") +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_array_dump(conn, val, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = json.loads(val) + cur = conn.cursor() + cur.execute( + f"select %{fmt_in.value}::text = array[%s::{wrapper.__name__.lower()}]::text", + ([wrapper(obj)], val), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("jtype", ["json", "jsonb"]) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load(conn, val, jtype, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select %s::{jtype}", (val,)) + assert cur.fetchone()[0] == json.loads(val) + + +@pytest.mark.crdb_skip("json array") +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("jtype", ["json", "jsonb"]) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_array(conn, val, jtype, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select array[%s::{jtype}]", (val,)) + assert cur.fetchone()[0] == [json.loads(val)] + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("jtype", ["json", "jsonb"]) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_copy(conn, val, jtype, fmt_out): + cur = conn.cursor() + stmt = sql.SQL("copy (select {}::{}) to stdout (format {})").format( + val, sql.Identifier(jtype), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([jtype]) + (got,) = copy.read_row() + + assert got == json.loads(val) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +def test_dump_customise(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = {"foo": "bar"} + cur = conn.cursor() + + set_json_dumps(my_dumps) + try: + cur.execute(f"select %{fmt_in.value}->>'baz' = 'qux'", (wrapper(obj),)) + assert cur.fetchone()[0] is True + finally: + set_json_dumps(json.dumps) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +def test_dump_customise_context(conn, wrapper, fmt_in): + wrapper = 
getattr(psycopg.types.json, wrapper) + obj = {"foo": "bar"} + cur1 = conn.cursor() + cur2 = conn.cursor() + + set_json_dumps(my_dumps, cur2) + cur1.execute(f"select %{fmt_in.value}->>'baz'", (wrapper(obj),)) + assert cur1.fetchone()[0] is None + cur2.execute(f"select %{fmt_in.value}->>'baz'", (wrapper(obj),)) + assert cur2.fetchone()[0] == "qux" + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +def test_dump_customise_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = {"foo": "bar"} + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}->>'baz' = 'qux'", (wrapper(obj, my_dumps),)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("binary", [True, False]) +@pytest.mark.parametrize("pgtype", ["json", "jsonb"]) +def test_load_customise(conn, binary, pgtype): + cur = conn.cursor(binary=binary) + + set_json_loads(my_loads) + try: + cur.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""") + obj = cur.fetchone()[0] + assert obj["foo"] == "bar" + assert obj["answer"] == 42 + finally: + set_json_loads(json.loads) + + +@pytest.mark.parametrize("binary", [True, False]) +@pytest.mark.parametrize("pgtype", ["json", "jsonb"]) +def test_load_customise_context(conn, binary, pgtype): + cur1 = conn.cursor(binary=binary) + cur2 = conn.cursor(binary=binary) + + set_json_loads(my_loads, cur2) + cur1.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""") + got = cur1.fetchone()[0] + assert got["foo"] == "bar" + assert "answer" not in got + + cur2.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""") + got = cur2.fetchone()[0] + assert got["foo"] == "bar" + assert got["answer"] == 42 + + +def my_dumps(obj): + obj = deepcopy(obj) + obj["baz"] = "qux" + return json.dumps(obj) + + +def my_loads(data): + obj = json.loads(data) + obj["answer"] = 42 + return obj diff --git a/tests/types/test_multirange.py b/tests/types/test_multirange.py new file mode 100644 index 0000000..2ab5152 --- /dev/null +++ b/tests/types/test_multirange.py @@ -0,0 +1,434 @@ +import pickle +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg import pq, sql +from psycopg import errors as e +from psycopg.adapt import PyFormat +from psycopg.types.range import Range +from psycopg.types import multirange +from psycopg.types.multirange import Multirange, MultirangeInfo +from psycopg.types.multirange import register_multirange + +from ..utils import eur +from .test_range import create_test_range + +pytestmark = [ + pytest.mark.pg(">= 14"), + pytest.mark.crdb_skip("range"), +] + + +class TestMultirangeObject: + def test_empty(self): + mr = Multirange[int]() + assert not mr + assert len(mr) == 0 + + mr = Multirange([]) + assert not mr + assert len(mr) == 0 + + def test_sequence(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + assert mr + assert len(mr) == 3 + assert mr[2] == Range(50, 60) + assert mr[-2] == Range(30, 40) + + def test_bad_type(self): + with pytest.raises(TypeError): + Multirange(Range(10, 20)) # type: ignore[arg-type] + + with pytest.raises(TypeError): + Multirange([10]) # type: ignore[arg-type] + + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + + with pytest.raises(TypeError): + mr[0] = "foo" # type: ignore[call-overload] + + with pytest.raises(TypeError): + mr[0:1] = "foo" # type: ignore[assignment] + + with pytest.raises(TypeError): + mr[0:1] = ["foo"] # type: ignore[list-item] + + with pytest.raises(TypeError): + mr.insert(0, "foo") # type: 
ignore[arg-type] + + def test_setitem(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + mr[1] = Range(31, 41) + assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)]) + + def test_setitem_slice(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + mr[1:3] = [Range(31, 41), Range(51, 61)] + assert mr == Multirange([Range(10, 20), Range(31, 41), Range(51, 61)]) + + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + with pytest.raises(TypeError, match="can only assign an iterable"): + mr[1:3] = Range(31, 41) # type: ignore[call-overload] + + mr[1:3] = [Range(31, 41)] + assert mr == Multirange([Range(10, 20), Range(31, 41)]) + + def test_delitem(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + del mr[1] + assert mr == Multirange([Range(10, 20), Range(50, 60)]) + + del mr[-2] + assert mr == Multirange([Range(50, 60)]) + + def test_insert(self): + mr = Multirange([Range(10, 20), Range(50, 60)]) + mr.insert(1, Range(31, 41)) + assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)]) + + def test_relations(self): + mr1 = Multirange([Range(10, 20), Range(30, 40)]) + mr2 = Multirange([Range(11, 20), Range(30, 40)]) + mr3 = Multirange([Range(9, 20), Range(30, 40)]) + assert mr1 <= mr1 + assert not mr1 < mr1 + assert mr1 >= mr1 + assert not mr1 > mr1 + assert mr1 < mr2 + assert mr1 <= mr2 + assert mr1 > mr3 + assert mr1 >= mr3 + assert mr1 != mr2 + assert not mr1 == mr2 + + def test_pickling(self): + r = Multirange([Range(0, 4)]) + assert pickle.loads(pickle.dumps(r)) == r + + def test_str(self): + mr = Multirange([Range(10, 20), Range(30, 40)]) + assert str(mr) == "{[10, 20), [30, 40)}" + + def test_repr(self): + mr = Multirange([Range(10, 20), Range(30, 40)]) + expected = "Multirange([Range(10, 20, '[)'), Range(30, 40, '[)')])" + assert repr(mr) == expected + + +tzinfo = dt.timezone(dt.timedelta(hours=2)) + +samples = [ + ("int4multirange", [Range(None, None, "()")]), + ("int4multirange", [Range(10, 20), Range(30, 40)]), + ("int8multirange", [Range(None, None, "()")]), + ("int8multirange", [Range(10, 20), Range(30, 40)]), + ( + "nummultirange", + [ + Range(None, Decimal(-100)), + Range(Decimal(100), Decimal("100.123")), + ], + ), + ( + "datemultirange", + [Range(dt.date(2000, 1, 1), dt.date(2020, 1, 1))], + ), + ( + "tsmultirange", + [ + Range( + dt.datetime(2000, 1, 1, 00, 00), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999), + ) + ], + ), + ( + "tstzmultirange", + [ + Range( + dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo), + ), + Range( + dt.datetime(2030, 1, 1, 00, 00, tzinfo=tzinfo), + dt.datetime(2040, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo), + ), + ], + ), +] + +mr_names = """ + int4multirange int8multirange nummultirange + datemultirange tsmultirange tstzmultirange""".split() + +mr_classes = """ + Int4Multirange Int8Multirange NumericMultirange + DateMultirange TimestampMultirange TimestamptzMultirange""".split() + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty(conn, pgtype, fmt_in): + mr = Multirange() # type: ignore[var-annotated] + cur = conn.execute(f"select '{{}}'::{pgtype} = %{fmt_in.value}", (mr,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in): + dumper = getattr(multirange, wrapper + "Dumper") 
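+ # e.g. "Int4Multirange" resolves to multirange.Int4Multirange and "Int4MultirangeDumper" to the dumper class whose oid is checked below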
+ wrapper = getattr(multirange, wrapper) + mr = wrapper() + rec = conn.execute( + f""" + select '{{}}' = %(mr){fmt_in.value}, + %(mr){fmt_in.value}::text, + pg_typeof(%(mr){fmt_in.value})::oid + """, + {"mr": mr}, + ).fetchone() + assert rec[0] is True, rec[1] + assert rec[2] == dumper.oid + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize( + "fmt_in", + [ + PyFormat.AUTO, + PyFormat.TEXT, + # There are many ways to work around this (use text, use a cast on the + # placeholder, use specific Range subclasses). + pytest.param( + PyFormat.BINARY, + marks=pytest.mark.xfail( + reason="can't dump array of untyped binary multirange without cast" + ), + ), + ], +) +def test_dump_builtin_array(conn, pgtype, fmt_in): + mr1 = Multirange() # type: ignore[var-annotated] + mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated] + cur = conn.execute( + f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}] = %{fmt_in.value}", + ([mr1, mr2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in): + mr1 = Multirange() # type: ignore[var-annotated] + mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated] + cur = conn.execute( + f""" + select array['{{}}'::{pgtype}, + '{{(,)}}'::{pgtype}] = %{fmt_in.value}::{pgtype}[] + """, + ([mr1, mr2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(multirange, wrapper) + mr1 = wrapper() + mr2 = wrapper([Range(bounds="()")]) + cur = conn.execute( + f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in.value}""", ([mr1, mr2],) + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype, ranges", samples) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_multirange(conn, pgtype, ranges, fmt_in): + mr = Multirange(ranges) + rname = pgtype.replace("multi", "") + phs = ", ".join([f"%s::{rname}"] * len(ranges)) + cur = conn.execute(f"select {pgtype}({phs}) = %{fmt_in.value}", ranges + [mr]) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_empty(conn, pgtype, fmt_out): + mr = Multirange() # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select '{{}}'::{pgtype}").fetchone() + assert type(got) is Multirange + assert got == mr + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_array(conn, pgtype, fmt_out): + mr1 = Multirange() # type: ignore[var-annotated] + mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute( + f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}]" + ).fetchone() + assert got == [mr1, mr2] + + +@pytest.mark.parametrize("pgtype, ranges", samples) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_range(conn, pgtype, ranges, fmt_out): + mr = Multirange(ranges) + rname = pgtype.replace("multi", "") + phs = ", ".join([f"%s::{rname}"] * len(ranges)) + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select {pgtype}({phs})", ranges) + assert cur.fetchone()[0] == mr + + +@pytest.mark.parametrize( + "min, max, 
bounds", + [ + ("2000,1,1", "2001,1,1", "[)"), + ("2000,1,1", None, "[)"), + (None, "2001,1,1", "()"), + (None, None, "()"), + (None, None, "empty"), + ], +) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in(conn, min, max, bounds, format): + cur = conn.cursor() + cur.execute("create table copymr (id serial primary key, mr datemultirange)") + + if bounds != "empty": + min = dt.date(*map(int, min.split(","))) if min else None + max = dt.date(*map(int, max.split(","))) if max else None + r = Range[dt.date](min, max, bounds) + else: + r = Range(empty=True) + + mr = Multirange([r]) + try: + with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy: + copy.write_row([mr]) + except e.InternalError_: + if not min and not max and format == pq.Format.BINARY: + pytest.xfail("TODO: add annotation to dump multirange with no type info") + else: + raise + + rec = cur.execute("select mr from copymr order by id").fetchone() + if not r.isempty: + assert rec[0] == mr + else: + assert rec[0] == Multirange() + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_wrappers(conn, wrapper, format): + cur = conn.cursor() + cur.execute("create table copymr (id serial primary key, mr datemultirange)") + + mr = getattr(multirange, wrapper)() + + with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy: + copy.write_row([mr]) + + rec = cur.execute("select mr from copymr order by id").fetchone() + assert rec[0] == mr + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_set_type(conn, pgtype, format): + cur = conn.cursor() + cur.execute(f"create table copymr (id serial primary key, mr {pgtype})") + + mr = Multirange() # type: ignore[var-annotated] + + with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy: + copy.set_types([pgtype]) + copy.write_row([mr]) + + rec = cur.execute("select mr from copymr order by id").fetchone() + assert rec[0] == mr + + +@pytest.fixture(scope="session") +def testmr(svcconn): + create_test_range(svcconn) + + +fetch_cases = [ + ("testmultirange", "text"), + ("testschema.testmultirange", "float8"), + (sql.Identifier("testmultirange"), "text"), + (sql.Identifier("testschema", "testmultirange"), "float8"), +] + + +@pytest.mark.parametrize("name, subtype", fetch_cases) +def test_fetch_info(conn, testmr, name, subtype): + info = MultirangeInfo.fetch(conn, name) + assert info.name == "testmultirange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == conn.adapters.types[subtype].oid + + +def test_fetch_info_not_found(conn): + assert MultirangeInfo.fetch(conn, "nosuchrange") is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name, subtype", fetch_cases) +async def test_fetch_info_async(aconn, testmr, name, subtype): # noqa: F811 + info = await MultirangeInfo.fetch(aconn, name) + assert info.name == "testmultirange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == aconn.adapters.types[subtype].oid + + +@pytest.mark.asyncio +async def test_fetch_info_not_found_async(aconn): + assert await MultirangeInfo.fetch(aconn, "nosuchrange") is None + + +def test_dump_custom_empty(conn, testmr): + info = MultirangeInfo.fetch(conn, "testmultirange") + register_multirange(info, conn) + + r = Multirange() # type: ignore[var-annotated] + cur = conn.execute("select '{}'::testmultirange = %s", (r,)) + assert 
cur.fetchone()[0] is True + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_custom_empty(conn, testmr, fmt_out): + info = MultirangeInfo.fetch(conn, "testmultirange") + register_multirange(info, conn) + + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute("select '{}'::testmultirange").fetchone() + assert isinstance(got, Multirange) + assert not got + + +@pytest.mark.parametrize("name", ["a-b", f"{eur}"]) +def test_literal_invalid_name(conn, name): + conn.execute("set client_encoding to utf8") + conn.execute(f'create type "{name}" as range (subtype = text)') + info = MultirangeInfo.fetch(conn, f'"{name}_multirange"') + register_multirange(info, conn) + obj = Multirange([Range("a", "z", "[]")]) + assert sql.Literal(obj).as_string(conn) == f"'{{[a,z]}}'::\"{name}_multirange\"" + cur = conn.execute(sql.SQL("select {}").format(obj)) + assert cur.fetchone()[0] == obj diff --git a/tests/types/test_net.py b/tests/types/test_net.py new file mode 100644 index 0000000..8739398 --- /dev/null +++ b/tests/types/test_net.py @@ -0,0 +1,135 @@ +import ipaddress + +import pytest + +from psycopg import pq +from psycopg import sql +from psycopg.adapt import PyFormat + +crdb_skip_cidr = pytest.mark.crdb_skip("cidr") + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("val", ["192.168.0.1", "2001:db8::"]) +def test_address_dump(conn, fmt_in, val): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value} = %s::inet", (ipaddress.ip_address(val), val)) + assert cur.fetchone()[0] is True + cur.execute( + f"select %{fmt_in.value} = array[null, %s]::inet[]", + ([None, ipaddress.ip_interface(val)], val), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/128"]) +def test_interface_dump(conn, fmt_in, val): + cur = conn.cursor() + rec = cur.execute( + f"select %(val){fmt_in.value} = %(repr)s::inet," + f" %(val){fmt_in.value}, %(repr)s::inet", + {"val": ipaddress.ip_interface(val), "repr": val}, + ).fetchone() + assert rec[0] is True, f"{rec[1]} != {rec[2]}" + cur.execute( + f"select %{fmt_in.value} = array[null, %s]::inet[]", + ([None, ipaddress.ip_interface(val)], val), + ) + assert cur.fetchone()[0] is True + + +@crdb_skip_cidr +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"]) +def test_network_dump(conn, fmt_in, val): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value} = %s::cidr", (ipaddress.ip_network(val), val)) + assert cur.fetchone()[0] is True + cur.execute( + f"select %{fmt_in.value} = array[NULL, %s]::cidr[]", + ([None, ipaddress.ip_network(val)], val), + ) + assert cur.fetchone()[0] is True + + +@crdb_skip_cidr +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_network_mixed_size_array(conn, fmt_in): + val = [ + ipaddress.IPv4Network("192.168.0.1/32"), + ipaddress.IPv6Network("::1/128"), + ] + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}", (val,)) + got = cur.fetchone()[0] + assert val == got + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"]) +def test_inet_load_address(conn, fmt_out, val): + addr = ipaddress.ip_address(val.split("/", 1)[0]) + cur = conn.cursor(binary=fmt_out) + + cur.execute("select %s::inet", (val,)) + assert cur.fetchone()[0] == addr + + cur.execute("select array[null, %s::inet]", (val,)) + assert cur.fetchone()[0] == [None, addr] + + 
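+ # read the same value back through COPY too, exercising the copy loader in the same output format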
stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["inet"]) + (got,) = copy.read_row() + + assert got == addr + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/127"]) +def test_inet_load_network(conn, fmt_out, val): + pyval = ipaddress.ip_interface(val) + cur = conn.cursor(binary=fmt_out) + + cur.execute("select %s::inet", (val,)) + assert cur.fetchone()[0] == pyval + + cur.execute("select array[null, %s::inet]", (val,)) + assert cur.fetchone()[0] == [None, pyval] + + stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["inet"]) + (got,) = copy.read_row() + + assert got == pyval + + +@crdb_skip_cidr +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"]) +def test_cidr_load(conn, fmt_out, val): + pyval = ipaddress.ip_network(val) + cur = conn.cursor(binary=fmt_out) + + cur.execute("select %s::cidr", (val,)) + assert cur.fetchone()[0] == pyval + + cur.execute("select array[null, %s::cidr]", (val,)) + assert cur.fetchone()[0] == [None, pyval] + + stmt = sql.SQL("copy (select {}::cidr) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["cidr"]) + (got,) = copy.read_row() + + assert got == pyval diff --git a/tests/types/test_none.py b/tests/types/test_none.py new file mode 100644 index 0000000..4c008fd --- /dev/null +++ b/tests/types/test_none.py @@ -0,0 +1,12 @@ +from psycopg import sql +from psycopg.adapt import Transformer, PyFormat + + +def test_quote_none(conn): + + tx = Transformer() + assert tx.get_dumper(None, PyFormat.TEXT).quote(None) == b"NULL" + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None))) + assert cur.fetchone()[0] is None diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py new file mode 100644 index 0000000..a27bc84 --- /dev/null +++ b/tests/types/test_numeric.py @@ -0,0 +1,625 @@ +import enum +from decimal import Decimal +from math import isnan, isinf, exp + +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg.adapt import Transformer, PyFormat +from psycopg.types.numeric import FloatLoader + +from ..fix_crdb import is_crdb + +# +# Tests with int +# + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::int"), + (1, "'1'::int"), + (-1, "'-1'::int"), + (42, "'42'::smallint"), + (-42, "'-42'::smallint"), + (int(2**63 - 1), "'9223372036854775807'::bigint"), + (int(-(2**63)), "'-9223372036854775808'::bigint"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_int(conn, val, expr, fmt_in): + assert isinstance(val, int) + cur = conn.cursor() + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::smallint"), + (1, "'1'::smallint"), + (-1, "'-1'::smallint"), + (42, "'42'::smallint"), + (-42, "'-42'::smallint"), + (int(2**15 - 1), f"'{2 ** 15 - 1}'::smallint"), + (int(-(2**15)), f"'{-2 ** 15}'::smallint"), + (int(2**15), f"'{2 ** 15}'::integer"), + (int(-(2**15) - 1), f"'{-2 ** 15 - 1}'::integer"), + (int(2**31 - 1), f"'{2 ** 31 - 1}'::integer"), + (int(-(2**31)), f"'{-2 ** 31}'::integer"), + (int(2**31), f"'{2 ** 31}'::bigint"), + 
(int(-(2**31) - 1), f"'{-2 ** 31 - 1}'::bigint"), + (int(2**63 - 1), f"'{2 ** 63 - 1}'::bigint"), + (int(-(2**63)), f"'{-2 ** 63}'::bigint"), + (int(2**63), f"'{2 ** 63}'::numeric"), + (int(-(2**63) - 1), f"'{-2 ** 63 - 1}'::numeric"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_int_subtypes(conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + cur.execute( + f"select {expr} = %(v){fmt_in.value}, {expr}::text, %(v){fmt_in.value}::text", + {"v": val}, + ) + ok, want, got = cur.fetchone() + assert got == want + assert ok + + +class MyEnum(enum.IntEnum): + foo = 42 + + +class MyMixinEnum(enum.IntEnum): + foo = 42000000 + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("enum", [MyEnum, MyMixinEnum]) +def test_dump_enum(conn, fmt_in, enum): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}", (enum.foo,)) + (res,) = cur.fetchone() + assert res == enum.foo.value + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, b"0"), + (1, b"1"), + (-1, b" -1"), + (42, b"42"), + (-42, b" -42"), + (int(2**63 - 1), b"9223372036854775807"), + (int(-(2**63)), b" -9223372036854775808"), + (int(2**63), b"9223372036854775808"), + (int(-(2**63 + 1)), b" -9223372036854775809"), + (int(2**100), b"1267650600228229401496703205376"), + (int(-(2**100)), b" -1267650600228229401496703205376"), + ], +) +def test_quote_int(conn, val, expr): + tx = Transformer() + assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + assert cur.fetchone() == (val, -val) + + +@pytest.mark.parametrize( + "val, pgtype, want", + [ + ("0", "integer", 0), + ("1", "integer", 1), + ("-1", "integer", -1), + ("0", "int2", 0), + ("0", "int4", 0), + ("0", "int8", 0), + ("0", "integer", 0), + ("0", "oid", 0), + # bounds + ("-32768", "smallint", -32768), + ("+32767", "smallint", 32767), + ("-2147483648", "integer", -2147483648), + ("+2147483647", "integer", 2147483647), + ("-9223372036854775808", "bigint", -9223372036854775808), + ("9223372036854775807", "bigint", 9223372036854775807), + ("4294967295", "oid", 4294967295), + ], +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_int(conn, val, pgtype, want, fmt_out): + if pgtype == "integer" and is_crdb(conn): + pgtype = "int4" # "integer" is "int8" on crdb + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select %s::{pgtype}", (val,)) + assert cur.pgresult.fformat(0) == fmt_out + assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].oid + result = cur.fetchone()[0] + assert result == want + assert type(result) is type(want) + + # arrays work too + cur.execute(f"select array[%s::{pgtype}]", (val,)) + assert cur.pgresult.fformat(0) == fmt_out + assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].array_oid + result = cur.fetchone()[0] + assert result == [want] + assert type(result[0]) is type(want) + + +# +# Tests with float +# + + +@pytest.mark.parametrize( + "val, expr", + [ + (0.0, "'0'"), + (1.0, "'1'"), + (-1.0, "'-1'"), + (float("nan"), "'NaN'"), + (float("inf"), "'Infinity'"), + (float("-inf"), "'-Infinity'"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_float(conn, val, expr, fmt_in): + assert isinstance(val, float) + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value} = {expr}::float8", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, 
expr", + [ + (0.0, b"0.0"), + (1.0, b"1.0"), + (10000000000000000.0, b"1e+16"), + (1000000.1, b"1000000.1"), + (-100000.000001, b" -100000.000001"), + (-1.0, b" -1.0"), + (float("nan"), b"'NaN'::float8"), + (float("inf"), b"'Infinity'::float8"), + (float("-inf"), b"'-Infinity'::float8"), + ], +) +def test_quote_float(conn, val, expr): + tx = Transformer() + assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + r = cur.fetchone() + if isnan(val): + assert isnan(r[0]) and isnan(r[1]) + else: + if isinstance(r[0], Decimal): + r = tuple(map(float, r)) + + assert r == (val, -val) + + +@pytest.mark.parametrize( + "val, expr", + [ + (exp(1), "exp(1.0)"), + (-exp(1), "-exp(1.0)"), + (1e30, "'1e30'"), + (1e-30, "1e-30"), + (-1e30, "'-1e30'"), + (-1e-30, "-1e-30"), + ], +) +def test_dump_float_approx(conn, val, expr): + assert isinstance(val, float) + cur = conn.cursor() + cur.execute(f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, pgtype, want", + [ + ("0", "float4", 0.0), + ("0.0", "float4", 0.0), + ("42", "float4", 42.0), + ("-42", "float4", -42.0), + ("0.0", "float8", 0.0), + ("0.0", "real", 0.0), + ("0.0", "double precision", 0.0), + ("0.0", "float4", 0.0), + ("nan", "float4", float("nan")), + ("inf", "float4", float("inf")), + ("-inf", "float4", -float("inf")), + ("nan", "float8", float("nan")), + ("inf", "float8", float("inf")), + ("-inf", "float8", -float("inf")), + ], +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_float(conn, val, pgtype, want, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select %s::{pgtype}", (val,)) + assert cur.pgresult.fformat(0) == fmt_out + assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].oid + result = cur.fetchone()[0] + + def check(result, want): + assert type(result) is type(want) + if isnan(want): + assert isnan(result) + elif isinf(want): + assert isinf(result) + assert (result < 0) is (want < 0) + else: + assert result == want + + check(result, want) + + cur.execute(f"select array[%s::{pgtype}]", (val,)) + assert cur.pgresult.fformat(0) == fmt_out + assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].array_oid + result = cur.fetchone()[0] + assert isinstance(result, list) + check(result[0], want) + + +@pytest.mark.parametrize( + "expr, pgtype, want", + [ + ("exp(1.0)", "float4", 2.71828), + ("-exp(1.0)", "float4", -2.71828), + ("exp(1.0)", "float8", 2.71828182845905), + ("-exp(1.0)", "float8", -2.71828182845905), + ("1.42e10", "float4", 1.42e10), + ("-1.42e10", "float4", -1.42e10), + ("1.42e40", "float8", 1.42e40), + ("-1.42e40", "float8", -1.42e40), + ], +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_float_approx(conn, expr, pgtype, want, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::%s" % (expr, pgtype)) + assert cur.pgresult.fformat(0) == fmt_out + result = cur.fetchone()[0] + assert result == pytest.approx(want) + + +@pytest.mark.crdb_skip("copy") +def test_load_float_copy(conn): + cur = conn.cursor(binary=False) + with cur.copy("copy (select 3.14::float8, 'hi'::text) to stdout;") as copy: + copy.set_types(["float8", "text"]) + rec = copy.read_row() + + assert rec[0] == pytest.approx(3.14) + assert rec[1] == "hi" + + +# +# Tests with decimal 
+# + + +@pytest.mark.parametrize( + "val", + [ + "0", + "-0", + "0.0", + "0.000000000000000000001", + "-0.000000000000000000001", + "nan", + "snan", + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_roundtrip_numeric(conn, val, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + val = Decimal(val) + cur.execute(f"select %{fmt_in.value}", (val,)) + result = cur.fetchone()[0] + assert isinstance(result, Decimal) + if val.is_nan(): + assert result.is_nan() + else: + assert result == val + + +@pytest.mark.parametrize( + "val, expr", + [ + ("0", b"0"), + ("0.0", b"0.0"), + ("0.00000000000000001", b"1E-17"), + ("-0.00000000000000001", b" -1E-17"), + ("nan", b"'NaN'::numeric"), + ("snan", b"'NaN'::numeric"), + ], +) +def test_quote_numeric(conn, val, expr): + val = Decimal(val) + tx = Transformer() + assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + r = cur.fetchone() + + if val.is_nan(): + assert isnan(r[0]) and isnan(r[1]) + else: + assert r == (val, -val) + + +@pytest.mark.crdb_skip("binary decimal") +@pytest.mark.parametrize( + "expr", + ["NaN", "1", "1.0", "-1", "0.0", "0.01", "11", "1.1", "1.01", "0", "0.00"] + + [ + "0.0000000", + "0.00001", + "1.00001", + "-1.00000000000000", + "-2.00000000000000", + "1000000000.12345", + "100.123456790000000000000000", + "1.0e-1000", + "1e1000", + "0.000000000000000000000000001", + "1.0000000000000000000000001", + "1000000000000000000000000.001", + "1000000000000000000000000000.001", + "9999999999999999999999999999.9", + ], +) +def test_dump_numeric_binary(conn, expr): + cur = conn.cursor() + val = Decimal(expr) + cur.execute("select %b::text, %s::decimal::text", [val, expr]) + want, got = cur.fetchone() + assert got == want + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt_in", + [ + f + if f != PyFormat.BINARY + else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal")) + for f in PyFormat + ], +) +def test_dump_numeric_exhaustive(conn, fmt_in): + cur = conn.cursor() + + funcs = [ + (lambda i: "1" + "0" * i), + (lambda i: "1" + "0" * i + "." + "0" * i), + (lambda i: "-1" + "0" * i), + (lambda i: "0." + "0" * i + "1"), + (lambda i: "-0." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "10"), + (lambda i: "1" + "0" * i + ".001"), + (lambda i: "9" + "9" * i), + (lambda i: "9" + "." + "9" * i), + (lambda i: "9" + "9" * i + ".9"), + (lambda i: "9" + "9" * i + "." 
+ "9" * i), + (lambda i: "1.1e%s" % i), + (lambda i: "1.1e-%s" % i), + ] + + for i in range(100): + for f in funcs: + expr = f(i) + val = Decimal(expr) + cur.execute(f"select %{fmt_in.value}::text, %s::decimal::text", [val, expr]) + got, want = cur.fetchone() + assert got == want + + +@pytest.mark.pg(">= 14") +@pytest.mark.parametrize( + "val, expr", + [ + ("inf", "Infinity"), + ("-inf", "-Infinity"), + ], +) +def test_dump_numeric_binary_inf(conn, val, expr): + cur = conn.cursor() + val = Decimal(val) + cur.execute("select %b", [val]) + + +@pytest.mark.parametrize( + "expr", + ["nan", "0", "1", "-1", "0.0", "0.01"] + + [ + "0.0000000", + "-1.00000000000000", + "-2.00000000000000", + "1000000000.12345", + "100.123456790000000000000000", + "1.0e-1000", + "1e1000", + "0.000000000000000000000000001", + "1.0000000000000000000000001", + "1000000000000000000000000.001", + "1000000000000000000000000000.001", + "9999999999999999999999999999.9", + ], +) +def test_load_numeric_binary(conn, expr): + cur = conn.cursor(binary=1) + res = cur.execute(f"select '{expr}'::numeric").fetchone()[0] + val = Decimal(expr) + if val.is_nan(): + assert res.is_nan() + else: + assert res == val + if "e" not in expr: + assert str(res) == str(val) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_numeric_exhaustive(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + + funcs = [ + (lambda i: "1" + "0" * i), + (lambda i: "1" + "0" * i + "." + "0" * i), + (lambda i: "-1" + "0" * i), + (lambda i: "0." + "0" * i + "1"), + (lambda i: "-0." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "10"), + (lambda i: "1" + "0" * i + ".001"), + (lambda i: "9" + "9" * i), + (lambda i: "9" + "." + "9" * i), + (lambda i: "9" + "9" * i + ".9"), + (lambda i: "9" + "9" * i + "." 
+ "9" * i), + ] + + for i in range(100): + for f in funcs: + snum = f(i) + want = Decimal(snum) + got = cur.execute(f"select '{snum}'::decimal").fetchone()[0] + assert want == got + assert str(want) == str(got) + + +@pytest.mark.pg(">= 14") +@pytest.mark.parametrize( + "val, expr", + [ + ("inf", "Infinity"), + ("-inf", "-Infinity"), + ], +) +def test_load_numeric_binary_inf(conn, val, expr): + cur = conn.cursor(binary=1) + res = cur.execute(f"select '{expr}'::numeric").fetchone()[0] + val = Decimal(val) + assert res == val + + +@pytest.mark.parametrize( + "val", + [ + "0", + "0.0", + "0.000000000000000000001", + "-0.000000000000000000001", + "nan", + ], +) +def test_numeric_as_float(conn, val): + cur = conn.cursor() + cur.adapters.register_loader("numeric", FloatLoader) + + val = Decimal(val) + cur.execute("select %s as val", (val,)) + result = cur.fetchone()[0] + assert isinstance(result, float) + if val.is_nan(): + assert isnan(result) + else: + assert result == pytest.approx(float(val)) + + # the customization works with arrays too + cur.execute("select %s as arr", ([val],)) + result = cur.fetchone()[0] + assert isinstance(result, list) + assert isinstance(result[0], float) + if val.is_nan(): + assert isnan(result[0]) + else: + assert result[0] == pytest.approx(float(val)) + + +# +# Mixed tests +# + + +@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"]) +def test_minus_minus(conn, pgtype): + cur = conn.cursor() + cast = f"::{pgtype}" if pgtype is not None else "" + cur.execute(f"select -%s{cast}", [-1]) + result = cur.fetchone()[0] + assert result == 1 + + +@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"]) +def test_minus_minus_quote(conn, pgtype): + cur = conn.cursor() + cast = f"::{pgtype}" if pgtype is not None else "" + cur.execute(sql.SQL("select -{}{}").format(sql.Literal(-1), sql.SQL(cast))) + result = cur.fetchone()[0] + assert result == 1 + + +@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split()) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.numeric, wrapper) + obj = wrapper(1) + cur = conn.execute( + f"select %(obj){fmt_in.value} = 1, %(obj){fmt_in.value}", {"obj": obj} + ) + rec = cur.fetchone() + assert rec[0], rec[1] + + +@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split()) +def test_dump_wrapper_oid(wrapper): + wrapper = getattr(psycopg.types.numeric, wrapper) + base = wrapper.__mro__[1] + assert base in (int, float) + n = base(3.14) + assert str(wrapper(n)) == str(n) + assert repr(wrapper(n)) == f"{wrapper.__name__}({n})" + + +@pytest.mark.crdb("skip", reason="all types returned as bigint? 
TODOCRDB") +@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split()) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_repr_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.numeric, wrapper) + cur = conn.execute(f"select pg_typeof(%{fmt_in.value})::oid", [wrapper(0)]) + oid = cur.fetchone()[0] + assert oid == psycopg.postgres.types[wrapper.__name__.lower()].oid + + +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize( + "typename", + "integer int2 int4 int8 float4 float8 numeric".split() + ["double precision"], +) +def test_oid_lookup(conn, typename, fmt_out): + dumper = conn.adapters.get_dumper_by_oid(conn.adapters.types[typename].oid, fmt_out) + assert dumper.oid == conn.adapters.types[typename].oid + assert dumper.format == fmt_out diff --git a/tests/types/test_range.py b/tests/types/test_range.py new file mode 100644 index 0000000..1efd398 --- /dev/null +++ b/tests/types/test_range.py @@ -0,0 +1,677 @@ +import pickle +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg import pq, sql +from psycopg import errors as e +from psycopg.adapt import PyFormat +from psycopg.types import range as range_module +from psycopg.types.range import Range, RangeInfo, register_range + +from ..utils import eur +from ..fix_crdb import is_crdb, crdb_skip_message + +pytestmark = pytest.mark.crdb_skip("range") + +type2sub = { + "int4range": "int4", + "int8range": "int8", + "numrange": "numeric", + "daterange": "date", + "tsrange": "timestamp", + "tstzrange": "timestamptz", +} + +tzinfo = dt.timezone(dt.timedelta(hours=2)) + +samples = [ + ("int4range", None, None, "()"), + ("int4range", 10, 20, "[]"), + ("int4range", -(2**31), (2**31) - 1, "[)"), + ("int8range", None, None, "()"), + ("int8range", 10, 20, "[)"), + ("int8range", -(2**63), (2**63) - 1, "[)"), + ("numrange", Decimal(-100), Decimal("100.123"), "(]"), + ("numrange", Decimal(100), None, "()"), + ("numrange", None, Decimal(100), "()"), + ("daterange", dt.date(2000, 1, 1), dt.date(2020, 1, 1), "[)"), + ( + "tsrange", + dt.datetime(2000, 1, 1, 00, 00), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999), + "[]", + ), + ( + "tstzrange", + dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo), + "()", + ), +] + +range_names = """ + int4range int8range numrange daterange tsrange tstzrange + """.split() + +range_classes = """ + Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange + """.split() + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty(conn, pgtype, fmt_in): + r = Range(empty=True) # type: ignore[var-annotated] + cur = conn.execute(f"select 'empty'::{pgtype} = %{fmt_in.value}", (r,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", range_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(range_module, wrapper) + r = wrapper(empty=True) + cur = conn.execute(f"select 'empty' = %{fmt_in.value}", (r,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize( + "fmt_in", + [ + PyFormat.AUTO, + PyFormat.TEXT, + # There are many ways to work around this (use text, use a cast on the + # placeholder, use specific Range subclasses). 
+ pytest.param( + PyFormat.BINARY, + marks=pytest.mark.xfail( + reason="can't dump an array of untypes binary range without cast" + ), + ), + ], +) +def test_dump_builtin_array(conn, pgtype, fmt_in): + r1 = Range(empty=True) # type: ignore[var-annotated] + r2 = Range(bounds="()") # type: ignore[var-annotated] + cur = conn.execute( + f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in.value}", + ([r1, r2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in): + r1 = Range(empty=True) # type: ignore[var-annotated] + r2 = Range(bounds="()") # type: ignore[var-annotated] + cur = conn.execute( + f"select array['empty'::{pgtype}, '(,)'::{pgtype}] " + f"= %{fmt_in.value}::{pgtype}[]", + ([r1, r2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", range_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(range_module, wrapper) + r1 = wrapper(empty=True) + r2 = wrapper(bounds="()") + cur = conn.execute(f"""select '{{empty,"(,)"}}' = %{fmt_in.value}""", ([r1, r2],)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype, min, max, bounds", samples) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in): + r = Range(min, max, bounds) # type: ignore[var-annotated] + sub = type2sub[pgtype] + cur = conn.execute( + f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %{fmt_in.value}", + (min, max, bounds, r), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_empty(conn, pgtype, fmt_out): + r = Range(empty=True) # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone() + assert type(got) is Range + assert got == r + assert not got + assert got.isempty + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_inf(conn, pgtype, fmt_out): + r = Range(bounds="()") # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select '(,)'::{pgtype}").fetchone() + assert type(got) is Range + assert got == r + assert got + assert not got.isempty + assert got.lower_inf + assert got.upper_inf + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_array(conn, pgtype, fmt_out): + r1 = Range(empty=True) # type: ignore[var-annotated] + r2 = Range(bounds="()") # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select array['empty'::{pgtype}, '(,)'::{pgtype}]").fetchone() + assert got == [r1, r2] + + +@pytest.mark.parametrize("pgtype, min, max, bounds", samples) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out): + r = Range(min, max, bounds) # type: ignore[var-annotated] + sub = type2sub[pgtype] + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)) + # normalise discrete ranges + if r.upper_inc and isinstance(r.upper, int): + bounds = "[)" if r.lower_inc else "()" + r = type(r)(r.lower, r.upper + 1, bounds) + assert cur.fetchone()[0] == r + + +@pytest.mark.parametrize( + 
"min, max, bounds", + [ + ("2000,1,1", "2001,1,1", "[)"), + ("2000,1,1", None, "[)"), + (None, "2001,1,1", "()"), + (None, None, "()"), + (None, None, "empty"), + ], +) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in(conn, min, max, bounds, format): + cur = conn.cursor() + cur.execute("create table copyrange (id serial primary key, r daterange)") + + if bounds != "empty": + min = dt.date(*map(int, min.split(","))) if min else None + max = dt.date(*map(int, max.split(","))) if max else None + r = Range[dt.date](min, max, bounds) + else: + r = Range(empty=True) + + try: + with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy: + copy.write_row([r]) + except e.ProtocolViolation: + if not min and not max and format == pq.Format.BINARY: + pytest.xfail("TODO: add annotation to dump ranges with no type info") + else: + raise + + rec = cur.execute("select r from copyrange order by id").fetchone() + assert rec[0] == r + + +@pytest.mark.parametrize("bounds", "() empty".split()) +@pytest.mark.parametrize("wrapper", range_classes) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_wrappers(conn, bounds, wrapper, format): + cur = conn.cursor() + cur.execute("create table copyrange (id serial primary key, r daterange)") + + cls = getattr(range_module, wrapper) + r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds) + + with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy: + copy.write_row([r]) + + rec = cur.execute("select r from copyrange order by id").fetchone() + assert rec[0] == r + + +@pytest.mark.parametrize("bounds", "() empty".split()) +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_set_type(conn, bounds, pgtype, format): + cur = conn.cursor() + cur.execute(f"create table copyrange (id serial primary key, r {pgtype})") + + if bounds == "empty": + r = Range(empty=True) # type: ignore[var-annotated] + else: + r = Range(None, None, bounds) + + with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy: + copy.set_types([pgtype]) + copy.write_row([r]) + + rec = cur.execute("select r from copyrange order by id").fetchone() + assert rec[0] == r + + +@pytest.fixture(scope="session") +def testrange(svcconn): + create_test_range(svcconn) + + +def create_test_range(conn): + if is_crdb(conn): + pytest.skip(crdb_skip_message("range")) + + conn.execute( + """ + create schema if not exists testschema; + + drop type if exists testrange cascade; + drop type if exists testschema.testrange cascade; + + create type testrange as range (subtype = text, collation = "C"); + create type testschema.testrange as range (subtype = float8); + """ + ) + + +fetch_cases = [ + ("testrange", "text"), + ("testschema.testrange", "float8"), + (sql.Identifier("testrange"), "text"), + (sql.Identifier("testschema", "testrange"), "float8"), +] + + +@pytest.mark.parametrize("name, subtype", fetch_cases) +def test_fetch_info(conn, testrange, name, subtype): + info = RangeInfo.fetch(conn, name) + assert info.name == "testrange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == conn.adapters.types[subtype].oid + + +def test_fetch_info_not_found(conn): + assert RangeInfo.fetch(conn, "nosuchrange") is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name, subtype", fetch_cases) +async def test_fetch_info_async(aconn, testrange, name, subtype): + info = await RangeInfo.fetch(aconn, name) + assert 
info.name == "testrange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == aconn.adapters.types[subtype].oid + + +@pytest.mark.asyncio +async def test_fetch_info_not_found_async(aconn): + assert await RangeInfo.fetch(aconn, "nosuchrange") is None + + +def test_dump_custom_empty(conn, testrange): + info = RangeInfo.fetch(conn, "testrange") + register_range(info, conn) + + r = Range[str](empty=True) + cur = conn.execute("select 'empty'::testrange = %s", (r,)) + assert cur.fetchone()[0] is True + + +def test_dump_quoting(conn, testrange): + info = RangeInfo.fetch(conn, "testrange") + register_range(info, conn) + cur = conn.cursor() + for i in range(1, 254): + cur.execute( + """ + select ascii(lower(%(r)s)) = %(low)s + and ascii(upper(%(r)s)) = %(up)s + """, + {"r": Range(chr(i), chr(i + 1)), "low": i, "up": i + 1}, + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_custom_empty(conn, testrange, fmt_out): + info = RangeInfo.fetch(conn, "testrange") + register_range(info, conn) + + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute("select 'empty'::testrange").fetchone() + assert isinstance(got, Range) + assert got.isempty + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_quoting(conn, testrange, fmt_out): + info = RangeInfo.fetch(conn, "testrange") + register_range(info, conn) + cur = conn.cursor(binary=fmt_out) + for i in range(1, 254): + cur.execute( + "select testrange(chr(%(low)s::int), chr(%(up)s::int))", + {"low": i, "up": i + 1}, + ) + got: Range[str] = cur.fetchone()[0] + assert isinstance(got, Range) + assert got.lower and ord(got.lower) == i + assert got.upper and ord(got.upper) == i + 1 + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_mixed_array_types(conn, fmt_out): + conn.execute("create table testmix (a daterange[], b tstzrange[])") + r1 = Range(dt.date(2000, 1, 1), dt.date(2001, 1, 1), "[)") + r2 = Range( + dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), + dt.datetime(2001, 1, 1, tzinfo=dt.timezone.utc), + "[)", + ) + conn.execute("insert into testmix values (%s, %s)", [[r1], [r2]]) + got = conn.execute("select * from testmix").fetchone() + assert got == ([r1], [r2]) + + +class TestRangeObject: + def test_noparam(self): + r = Range() # type: ignore[var-annotated] + + assert not r.isempty + assert r.lower is None + assert r.upper is None + assert r.lower_inf + assert r.upper_inf + assert not r.lower_inc + assert not r.upper_inc + + def test_empty(self): + r = Range(empty=True) # type: ignore[var-annotated] + + assert r.isempty + assert r.lower is None + assert r.upper is None + assert not r.lower_inf + assert not r.upper_inf + assert not r.lower_inc + assert not r.upper_inc + + def test_nobounds(self): + r = Range(10, 20) + assert r.lower == 10 + assert r.upper == 20 + assert not r.isempty + assert not r.lower_inf + assert not r.upper_inf + assert r.lower_inc + assert not r.upper_inc + + def test_bounds(self): + for bounds, lower_inc, upper_inc in [ + ("[)", True, False), + ("(]", False, True), + ("()", False, False), + ("[]", True, True), + ]: + r = Range(10, 20, bounds) + assert r.bounds == bounds + assert r.lower == 10 + assert r.upper == 20 + assert not r.isempty + assert not r.lower_inf + assert not r.upper_inf + assert r.lower_inc == lower_inc + assert r.upper_inc == upper_inc + + def test_keywords(self): + r = Range(upper=20) + r.lower is None + r.upper == 20 + assert not r.isempty + assert r.lower_inf + assert not r.upper_inf + assert not 
r.lower_inc + assert not r.upper_inc + + r = Range(lower=10, bounds="(]") + r.lower == 10 + r.upper is None + assert not r.isempty + assert not r.lower_inf + assert r.upper_inf + assert not r.lower_inc + assert not r.upper_inc + + def test_bad_bounds(self): + with pytest.raises(ValueError): + Range(bounds="(") + with pytest.raises(ValueError): + Range(bounds="[}") + + def test_in(self): + r = Range[int](empty=True) + assert 10 not in r + assert "x" not in r # type: ignore[operator] + + r = Range() + assert 10 in r + + r = Range(lower=10, bounds="[)") + assert 9 not in r + assert 10 in r + assert 11 in r + + r = Range(lower=10, bounds="()") + assert 9 not in r + assert 10 not in r + assert 11 in r + + r = Range(upper=20, bounds="()") + assert 19 in r + assert 20 not in r + assert 21 not in r + + r = Range(upper=20, bounds="(]") + assert 19 in r + assert 20 in r + assert 21 not in r + + r = Range(10, 20) + assert 9 not in r + assert 10 in r + assert 11 in r + assert 19 in r + assert 20 not in r + assert 21 not in r + + r = Range(10, 20, "(]") + assert 9 not in r + assert 10 not in r + assert 11 in r + assert 19 in r + assert 20 in r + assert 21 not in r + + r = Range(20, 10) + assert 9 not in r + assert 10 not in r + assert 11 not in r + assert 19 not in r + assert 20 not in r + assert 21 not in r + + def test_nonzero(self): + assert Range() + assert Range(10, 20) + assert not Range(empty=True) + + def test_eq_hash(self): + def assert_equal(r1, r2): + assert r1 == r2 + assert hash(r1) == hash(r2) + + assert_equal(Range(empty=True), Range(empty=True)) + assert_equal(Range(), Range()) + assert_equal(Range(10, None), Range(10, None)) + assert_equal(Range(10, 20), Range(10, 20)) + assert_equal(Range(10, 20), Range(10, 20, "[)")) + assert_equal(Range(10, 20, "[]"), Range(10, 20, "[]")) + + def assert_not_equal(r1, r2): + assert r1 != r2 + assert hash(r1) != hash(r2) + + assert_not_equal(Range(10, 20), Range(10, 21)) + assert_not_equal(Range(10, 20), Range(11, 20)) + assert_not_equal(Range(10, 20, "[)"), Range(10, 20, "[]")) + + def test_eq_wrong_type(self): + assert Range(10, 20) != () + + # as the postgres docs describe for the server-side stuff, + # ordering is rather arbitrary, but will remain stable + # and consistent. 
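To make those comparison rules concrete, here is a minimal sketch using only behaviour exercised by the ordering tests below (Range imported from psycopg.types.range, as at the top of this file; no database connection is needed):

    from psycopg.types.range import Range

    # Per the tests below: the empty range compares less than a bounded range,
    # a range with an unbounded lower bound compares less than one with a
    # finite lower bound, and equal ranges are not strictly ordered.
    assert Range(empty=True) < Range(0, 4)
    assert Range() < Range(1, 2)
    assert Range(0, 4) < Range(1, 2)
    assert not Range(1, 2) < Range(1, 2)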
+ + def test_lt_ordering(self): + assert Range(empty=True) < Range(0, 4) + assert not Range(1, 2) < Range(0, 4) + assert Range(0, 4) < Range(1, 2) + assert not Range(1, 2) < Range() + assert Range() < Range(1, 2) + assert not Range(1) < Range(upper=1) + assert not Range() < Range() + assert not Range(empty=True) < Range(empty=True) + assert not Range(1, 2) < Range(1, 2) + with pytest.raises(TypeError): + assert 1 < Range(1, 2) + with pytest.raises(TypeError): + assert not Range(1, 2) < 1 + + def test_gt_ordering(self): + assert not Range(empty=True) > Range(0, 4) + assert Range(1, 2) > Range(0, 4) + assert not Range(0, 4) > Range(1, 2) + assert Range(1, 2) > Range() + assert not Range() > Range(1, 2) + assert Range(1) > Range(upper=1) + assert not Range() > Range() + assert not Range(empty=True) > Range(empty=True) + assert not Range(1, 2) > Range(1, 2) + with pytest.raises(TypeError): + assert not 1 > Range(1, 2) + with pytest.raises(TypeError): + assert Range(1, 2) > 1 + + def test_le_ordering(self): + assert Range(empty=True) <= Range(0, 4) + assert not Range(1, 2) <= Range(0, 4) + assert Range(0, 4) <= Range(1, 2) + assert not Range(1, 2) <= Range() + assert Range() <= Range(1, 2) + assert not Range(1) <= Range(upper=1) + assert Range() <= Range() + assert Range(empty=True) <= Range(empty=True) + assert Range(1, 2) <= Range(1, 2) + with pytest.raises(TypeError): + assert 1 <= Range(1, 2) + with pytest.raises(TypeError): + assert not Range(1, 2) <= 1 + + def test_ge_ordering(self): + assert not Range(empty=True) >= Range(0, 4) + assert Range(1, 2) >= Range(0, 4) + assert not Range(0, 4) >= Range(1, 2) + assert Range(1, 2) >= Range() + assert not Range() >= Range(1, 2) + assert Range(1) >= Range(upper=1) + assert Range() >= Range() + assert Range(empty=True) >= Range(empty=True) + assert Range(1, 2) >= Range(1, 2) + with pytest.raises(TypeError): + assert not 1 >= Range(1, 2) + with pytest.raises(TypeError): + (Range(1, 2) >= 1) + + def test_pickling(self): + r = Range(0, 4) + assert pickle.loads(pickle.dumps(r)) == r + + def test_str(self): + """ + Range types should have a short and readable ``str`` implementation. + """ + expected = [ + "(0, 4)", + "[0, 4]", + "(0, 4]", + "[0, 4)", + "empty", + ] + results = [] + + for bounds in ("()", "[]", "(]", "[)"): + r = Range(0, 4, bounds=bounds) + results.append(str(r)) + + r = Range(empty=True) + results.append(str(r)) + assert results == expected + + def test_str_datetime(self): + """ + Date-Time ranges should return a human-readable string as well on + string conversion. 
+ """ + tz = dt.timezone(dt.timedelta(hours=-5)) + r = Range( + dt.datetime(2010, 1, 1, tzinfo=tz), + dt.datetime(2011, 1, 1, tzinfo=tz), + ) + expected = "[2010-01-01 00:00:00-05:00, 2011-01-01 00:00:00-05:00)" + result = str(r) + assert result == expected + + def test_exclude_inf_bounds(self): + r = Range(None, 10, "[]") + assert r.lower is None + assert not r.lower_inc + assert r.bounds == "(]" + + r = Range(10, None, "[]") + assert r.upper is None + assert not r.upper_inc + assert r.bounds == "[)" + + r = Range(None, None, "[]") + assert r.lower is None + assert not r.lower_inc + assert r.upper is None + assert not r.upper_inc + assert r.bounds == "()" + + +def test_no_info_error(conn): + with pytest.raises(TypeError, match="range"): + register_range(None, conn) # type: ignore[arg-type] + + +@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"]) +def test_literal_invalid_name(conn, name): + conn.execute("set client_encoding to utf8") + conn.execute(f'create type "{name}" as range (subtype = text)') + info = RangeInfo.fetch(conn, f'"{name}"') + register_range(info, conn) + obj = Range("a", "z", "[]") + assert sql.Literal(obj).as_string(conn) == f"'[a,z]'::\"{name}\"" + cur = conn.execute(sql.SQL("select {}").format(obj)) + assert cur.fetchone()[0] == obj diff --git a/tests/types/test_shapely.py b/tests/types/test_shapely.py new file mode 100644 index 0000000..0f7007e --- /dev/null +++ b/tests/types/test_shapely.py @@ -0,0 +1,152 @@ +import pytest + +import psycopg +from psycopg.pq import Format +from psycopg.types import TypeInfo +from psycopg.adapt import PyFormat + +pytest.importorskip("shapely") + +from shapely.geometry import Point, Polygon, MultiPolygon # noqa: E402 +from psycopg.types.shapely import register_shapely # noqa: E402 + +pytestmark = [ + pytest.mark.postgis, + pytest.mark.crdb("skip"), +] + +# real example, with CRS and "holes" +MULTIPOLYGON_GEOJSON = """ +{ + "type":"MultiPolygon", + "crs":{ + "type":"name", + "properties":{ + "name":"EPSG:3857" + } + }, + "coordinates":[ + [ + [ + [89574.61111389, 6894228.638802719], + [89576.815239808, 6894208.60747024], + [89576.904295401, 6894207.820852726], + [89577.99522641, 6894208.022080451], + [89577.961830563, 6894209.229446936], + [89589.227363031, 6894210.601454523], + [89594.615226386, 6894161.849595264], + [89600.314784314, 6894111.37846976], + [89651.187791607, 6894116.774968589], + [89648.49385993, 6894140.226914071], + [89642.92788539, 6894193.423936413], + [89639.721884055, 6894224.08372821], + [89589.283022777, 6894218.431048969], + [89588.192091767, 6894230.248628867], + [89574.61111389, 6894228.638802719] + ], + [ + [89610.344670435, 6894182.466199101], + [89625.985058891, 6894184.258949757], + [89629.547282597, 6894153.270030369], + [89613.918026089, 6894151.458993318], + [89610.344670435, 6894182.466199101] + ] + ] + ] +}""" + +SAMPLE_POINT_GEOJSON = '{"type":"Point","coordinates":[1.2, 3.4]}' + + +@pytest.fixture +def shapely_conn(conn, svcconn): + try: + with svcconn.transaction(): + svcconn.execute("create extension if not exists postgis") + except psycopg.Error as e: + pytest.skip(f"can't create extension postgis: {e}") + + info = TypeInfo.fetch(conn, "geometry") + assert info + register_shapely(info, conn) + return conn + + +def test_no_adapter(conn): + point = Point(1.2, 3.4) + with pytest.raises(psycopg.ProgrammingError, match="cannot adapt type 'Point'"): + conn.execute("SELECT pg_typeof(%s)", [point]).fetchone()[0] + + +def test_no_info_error(conn): + from psycopg.types.shapely import 
register_shapely + + with pytest.raises(TypeError, match="postgis.*extension"): + register_shapely(None, conn) # type: ignore[arg-type] + + +def test_with_adapter(shapely_conn): + SAMPLE_POINT = Point(1.2, 3.4) + SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)]) + + assert ( + shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POINT]).fetchone()[0] + == "geometry" + ) + + assert ( + shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POLYGON]).fetchone()[0] + == "geometry" + ) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", Format) +def test_write_read_shape(shapely_conn, fmt_in, fmt_out): + SAMPLE_POINT = Point(1.2, 3.4) + SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)]) + + with shapely_conn.cursor(binary=fmt_out) as cur: + cur.execute( + """ + create table sample_geoms( + id INTEGER PRIMARY KEY, + geom geometry + ) + """ + ) + cur.execute( + f"insert into sample_geoms(id, geom) VALUES(1, %{fmt_in})", + (SAMPLE_POINT,), + ) + cur.execute( + f"insert into sample_geoms(id, geom) VALUES(2, %{fmt_in})", + (SAMPLE_POLYGON,), + ) + + cur.execute("select geom from sample_geoms where id=1") + result = cur.fetchone()[0] + assert result == SAMPLE_POINT + + cur.execute("select geom from sample_geoms where id=2") + result = cur.fetchone()[0] + assert result == SAMPLE_POLYGON + + +@pytest.mark.parametrize("fmt_out", Format) +def test_match_geojson(shapely_conn, fmt_out): + SAMPLE_POINT = Point(1.2, 3.4) + with shapely_conn.cursor(binary=fmt_out) as cur: + cur.execute( + """ + select ST_GeomFromGeoJSON(%s) + """, + (SAMPLE_POINT_GEOJSON,), + ) + result = cur.fetchone()[0] + # clone the coordinates to have a list instead of a shapely wrapper + assert result.coords[:] == SAMPLE_POINT.coords[:] + # + cur.execute("select ST_GeomFromGeoJSON(%s)", (MULTIPOLYGON_GEOJSON,)) + result = cur.fetchone()[0] + assert isinstance(result, MultiPolygon) diff --git a/tests/types/test_string.py b/tests/types/test_string.py new file mode 100644 index 0000000..d23e5e0 --- /dev/null +++ b/tests/types/test_string.py @@ -0,0 +1,307 @@ +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg import errors as e +from psycopg.adapt import PyFormat +from psycopg import Binary + +from ..utils import eur +from ..fix_crdb import crdb_encoding, crdb_scs_off + +# +# tests with text +# + + +def crdb_bpchar(*args): + return pytest.param(*args, marks=pytest.mark.crdb("skip", reason="bpchar")) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_1char(conn, fmt_in): + cur = conn.cursor() + for i in range(1, 256): + cur.execute(f"select %{fmt_in.value} = chr(%s)", (chr(i), i)) + assert cur.fetchone()[0] is True, chr(i) + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_quote_1char(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + conn.execute("set escape_string_warning to on") + + cur = conn.cursor() + query = sql.SQL("select {ch} = chr(%s)") + for i in range(1, 256): + if chr(i) == "%": + continue + cur.execute(query.format(ch=sql.Literal(chr(i))), (i,)) + assert cur.fetchone()[0] is True, chr(i) + + # No "nonstandard use of \\ in a string literal" warning + assert not messages + + +@pytest.mark.crdb("skip", reason="can deal with 0 strings") +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_zero(conn, fmt_in): + cur = conn.cursor() + s = "foo\x00bar" + with pytest.raises(psycopg.DataError): + 
cur.execute(f"select %{fmt_in.value}::text", (s,)) + + +def test_quote_zero(conn): + cur = conn.cursor() + s = "foo\x00bar" + with pytest.raises(psycopg.DataError): + cur.execute(sql.SQL("select {}").format(sql.Literal(s))) + + +# the only way to make this pass is to reduce %% -> % every time +# not only when there are query arguments +# see https://github.com/psycopg/psycopg2/issues/825 +@pytest.mark.xfail +def test_quote_percent(conn): + cur = conn.cursor() + cur.execute(sql.SQL("select {ch}").format(ch=sql.Literal("%"))) + assert cur.fetchone()[0] == "%" + + cur.execute( + sql.SQL("select {ch} = chr(%s)").format(ch=sql.Literal("%")), + (ord("%"),), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "typename", ["text", "varchar", "name", crdb_bpchar("bpchar"), '"char"'] +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_1char(conn, typename, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(1, 256): + if typename == '"char"' and i > 127: + # for char > 128 the client receives only 194 or 195. + continue + + cur.execute(f"select chr(%s)::{typename}", (i,)) + res = cur.fetchone()[0] + assert res == chr(i) + + assert cur.pgresult.fformat(0) == fmt_out + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize( + "encoding", ["utf8", crdb_encoding("latin9"), crdb_encoding("sql_ascii")] +) +def test_dump_enc(conn, fmt_in, encoding): + cur = conn.cursor() + + conn.execute(f"set client_encoding to {encoding}") + (res,) = cur.execute(f"select ascii(%{fmt_in.value})", (eur,)).fetchone() + assert res == ord(eur) + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_badenc(conn, fmt_in): + cur = conn.cursor() + + conn.execute("set client_encoding to latin1") + with pytest.raises(UnicodeEncodeError): + cur.execute(f"select %{fmt_in.value}::bytea", (eur,)) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_utf8_badenc(conn, fmt_in): + cur = conn.cursor() + + conn.execute("set client_encoding to utf8") + with pytest.raises(UnicodeEncodeError): + cur.execute(f"select %{fmt_in.value}", ("\uddf8",)) + + +@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT]) +def test_dump_enum(conn, fmt_in): + from enum import Enum + + class MyEnum(str, Enum): + foo = "foo" + bar = "bar" + + cur = conn.cursor() + cur.execute("create type myenum as enum ('foo', 'bar')") + cur.execute("create table with_enum (e myenum)") + cur.execute(f"insert into with_enum (e) values (%{fmt_in.value})", (MyEnum.foo,)) + (res,) = cur.execute("select e from with_enum").fetchone() + assert res == "foo" + + +@pytest.mark.crdb("skip") +@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT]) +def test_dump_text_oid(conn, fmt_in): + conn.autocommit = True + + with pytest.raises(e.IndeterminateDatatype): + conn.execute(f"select concat(%{fmt_in.value}, %{fmt_in.value})", ["foo", "bar"]) + conn.adapters.register_dumper(str, psycopg.types.string.StrDumper) + cur = conn.execute( + f"select concat(%{fmt_in.value}, %{fmt_in.value})", ["foo", "bar"] + ) + assert cur.fetchone()[0] == "foobar" + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_load_enc(conn, typename, encoding, fmt_out): + cur = conn.cursor(binary=fmt_out) + + conn.execute(f"set client_encoding to {encoding}") + (res,) = cur.execute(f"select 
chr(%s)::{typename}", [ord(eur)]).fetchone() + assert res == eur + + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + (res,) = copy.read_row() + + assert res == eur + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_load_badenc(conn, typename, fmt_out): + conn.autocommit = True + cur = conn.cursor(binary=fmt_out) + + conn.execute("set client_encoding to latin1") + with pytest.raises(psycopg.DataError): + cur.execute(f"select chr(%s)::{typename}", [ord(eur)]) + + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + with pytest.raises(psycopg.DataError): + copy.read_row() + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_load_ascii(conn, typename, fmt_out): + cur = conn.cursor(binary=fmt_out) + + conn.execute("set client_encoding to sql_ascii") + cur.execute(f"select chr(%s)::{typename}", [ord(eur)]) + assert cur.fetchone()[0] == eur.encode() + + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + (res,) = copy.read_row() + + assert res == eur.encode() + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", crdb_bpchar("bpchar")]) +def test_text_array(conn, typename, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + a = list(map(chr, range(1, 256))) + [eur] + + (res,) = cur.execute(f"select %{fmt_in.value}::{typename}[]", (a,)).fetchone() + assert res == a + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_text_array_ascii(conn, fmt_in, fmt_out): + conn.execute("set client_encoding to sql_ascii") + cur = conn.cursor(binary=fmt_out) + a = list(map(chr, range(1, 256))) + [eur] + exp = [s.encode() for s in a] + (res,) = cur.execute(f"select %{fmt_in.value}::text[]", (a,)).fetchone() + assert res == exp + + +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("typename", ["text", "varchar", "name"]) +def test_oid_lookup(conn, typename, fmt_out): + dumper = conn.adapters.get_dumper_by_oid(conn.adapters.types[typename].oid, fmt_out) + assert dumper.oid == conn.adapters.types[typename].oid + assert dumper.format == fmt_out + + +# +# tests with bytea +# + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary]) +def test_dump_1byte(conn, fmt_in, pytype): + cur = conn.cursor() + for i in range(0, 256): + obj = pytype(bytes([i])) + cur.execute(f"select %{fmt_in.value} = set_byte('x', 0, %s)", (obj, i)) + assert cur.fetchone()[0] is True, i + + cur.execute(f"select %{fmt_in.value} = array[set_byte('x', 0, %s)]", ([obj], i)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary]) +def test_quote_1byte(conn, scs, pytype): + messages = [] + conn.add_notice_handler(lambda msg: 
messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + conn.execute("set escape_string_warning to on") + + cur = conn.cursor() + query = sql.SQL("select {ch} = set_byte('x', 0, %s)") + for i in range(0, 256): + obj = pytype(bytes([i])) + cur.execute(query.format(ch=sql.Literal(obj)), (i,)) + assert cur.fetchone()[0] is True, i + + # No "nonstandard use of \\ in a string literal" warning + assert not messages + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_1byte(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(0, 256): + cur.execute("select set_byte('x', 0, %s)", (i,)) + val = cur.fetchone()[0] + assert val == bytes([i]) + + assert isinstance(val, bytes) + assert cur.pgresult.fformat(0) == fmt_out + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_bytea_array(conn, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + a = [bytes(range(0, 256))] + (res,) = cur.execute(f"select %{fmt_in.value}::bytea[]", (a,)).fetchone() + assert res == a diff --git a/tests/types/test_uuid.py b/tests/types/test_uuid.py new file mode 100644 index 0000000..f86f066 --- /dev/null +++ b/tests/types/test_uuid.py @@ -0,0 +1,56 @@ +import sys +from uuid import UUID +import subprocess as sp + +import pytest + +from psycopg import pq +from psycopg import sql +from psycopg.adapt import PyFormat + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_uuid_dump(conn, fmt_in): + val = "12345678123456781234567812345679" + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value} = %s::uuid", (UUID(val), val)) + assert cur.fetchone()[0] is True + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_uuid_load(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + val = "12345678123456781234567812345679" + cur.execute("select %s::uuid", (val,)) + assert cur.fetchone()[0] == UUID(val) + + stmt = sql.SQL("copy (select {}::uuid) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["uuid"]) + (res,) = copy.read_row() + + assert res == UUID(val) + + +@pytest.mark.slow +@pytest.mark.subprocess +def test_lazy_load(dsn): + script = f"""\ +import sys +import psycopg + +assert 'uuid' not in sys.modules + +conn = psycopg.connect({dsn!r}) +with conn.cursor() as cur: + cur.execute("select repeat('1', 32)::uuid") + cur.fetchone() + +conn.close() +assert 'uuid' in sys.modules +""" + + sp.check_call([sys.executable, "-c", script]) diff --git a/tests/typing_example.py b/tests/typing_example.py new file mode 100644 index 0000000..a26ca49 --- /dev/null +++ b/tests/typing_example.py @@ -0,0 +1,176 @@ +# flake8: builtins=reveal_type + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union + +from psycopg import Connection, Cursor, ServerCursor, connect, rows +from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor + + +def int_row_factory( + cursor: Union[Cursor[Any], AsyncCursor[Any]] +) -> Callable[[Sequence[int]], int]: + return lambda values: values[0] if values else 42 + + +@dataclass +class Person: + name: str + address: str + + @classmethod + def row_factory( + cls, cursor: Union[Cursor[Any], AsyncCursor[Any]] + ) -> Callable[[Sequence[str]], Person]: + def mkrow(values: Sequence[str]) -> Person: + name, address = values + return cls(name, address) + + return mkrow + + +def kwargsf(*, foo: int, 
bar: int, baz: int) -> int: + return 42 + + +def argsf(foo: int, bar: int, baz: int) -> float: + return 42.0 + + +def check_row_factory_cursor() -> None: + """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case.""" + conn = connect() + + cur1: Cursor[Any] + cur1 = conn.cursor() + r1: Optional[Any] + r1 = cur1.fetchone() + r1 is not None + + cur2: Cursor[int] + r2: Optional[int] + with conn.cursor(row_factory=int_row_factory) as cur2: + cur2.execute("select 1") + r2 = cur2.fetchone() + r2 and r2 > 0 + + cur3: ServerCursor[Person] + persons: Sequence[Person] + with conn.cursor(name="s", row_factory=Person.row_factory) as cur3: + cur3.execute("select * from persons where name like 'al%'") + persons = cur3.fetchall() + persons[0].address + + +async def async_check_row_factory_cursor() -> None: + """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case.""" + conn = await AsyncConnection.connect() + + cur1: AsyncCursor[Any] + cur1 = conn.cursor() + r1: Optional[Any] + r1 = await cur1.fetchone() + r1 is not None + + cur2: AsyncCursor[int] + r2: Optional[int] + async with conn.cursor(row_factory=int_row_factory) as cur2: + await cur2.execute("select 1") + r2 = await cur2.fetchone() + r2 and r2 > 0 + + cur3: AsyncServerCursor[Person] + persons: Sequence[Person] + async with conn.cursor(name="s", row_factory=Person.row_factory) as cur3: + await cur3.execute("select * from persons where name like 'al%'") + persons = await cur3.fetchall() + persons[0].address + + +def check_row_factory_connection() -> None: + """Type-check connect(..., row_factory=<MyRowFactory>) or + Connection.row_factory cases. + """ + conn1: Connection[int] + cur1: Cursor[int] + r1: Optional[int] + conn1 = connect(row_factory=int_row_factory) + cur1 = conn1.execute("select 1") + r1 = cur1.fetchone() + r1 != 0 + with conn1.cursor() as cur1: + cur1.execute("select 2") + + conn2: Connection[Person] + cur2: Cursor[Person] + r2: Optional[Person] + conn2 = connect(row_factory=Person.row_factory) + cur2 = conn2.execute("select * from persons") + r2 = cur2.fetchone() + r2 and r2.name + with conn2.cursor() as cur2: + cur2.execute("select 2") + + cur3: Cursor[Tuple[Any, ...]] + r3: Optional[Tuple[Any, ...]] + conn3 = connect() + cur3 = conn3.execute("select 3") + with conn3.cursor() as cur3: + cur3.execute("select 42") + r3 = cur3.fetchone() + r3 and len(r3) + + +async def async_check_row_factory_connection() -> None: + """Type-check connect(..., row_factory=<MyRowFactory>) or + Connection.row_factory cases. + """ + conn1: AsyncConnection[int] + cur1: AsyncCursor[int] + r1: Optional[int] + conn1 = await AsyncConnection.connect(row_factory=int_row_factory) + cur1 = await conn1.execute("select 1") + r1 = await cur1.fetchone() + r1 != 0 + async with conn1.cursor() as cur1: + await cur1.execute("select 2") + + conn2: AsyncConnection[Person] + cur2: AsyncCursor[Person] + r2: Optional[Person] + conn2 = await AsyncConnection.connect(row_factory=Person.row_factory) + cur2 = await conn2.execute("select * from persons") + r2 = await cur2.fetchone() + r2 and r2.name + async with conn2.cursor() as cur2: + await cur2.execute("select 2") + + cur3: AsyncCursor[Tuple[Any, ...]] + r3: Optional[Tuple[Any, ...]] + conn3 = await AsyncConnection.connect() + cur3 = await conn3.execute("select 3") + async with conn3.cursor() as cur3: + await cur3.execute("select 42") + r3 = await cur3.fetchone() + r3 and len(r3) + + +def check_row_factories() -> None: + conn1 = connect(row_factory=rows.tuple_row) + v1: Tuple[Any, ...] 
= conn1.execute("").fetchall()[0] + + conn2 = connect(row_factory=rows.dict_row) + v2: Dict[str, Any] = conn2.execute("").fetchall()[0] + + conn3 = connect(row_factory=rows.class_row(Person)) + v3: Person = conn3.execute("").fetchall()[0] + + conn4 = connect(row_factory=rows.args_row(argsf)) + v4: float = conn4.execute("").fetchall()[0] + + conn5 = connect(row_factory=rows.kwargs_row(kwargsf)) + v5: int = conn5.execute("").fetchall()[0] + + v1, v2, v3, v4, v5 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..871f65d --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,179 @@ +import gc +import re +import sys +import operator +from typing import Callable, Optional, Tuple + +import pytest + +eur = "\u20ac" + + +def check_libpq_version(got, want): + """ + Verify if the libpq version is a version accepted. + + This function is called on the tests marked with something like:: + + @pytest.mark.libpq(">= 12") + + and skips the test if the requested version doesn't match what's loaded. + """ + return check_version(got, want, "libpq", postgres_rule=True) + + +def check_postgres_version(got, want): + """ + Verify if the server version is a version accepted. + + This function is called on the tests marked with something like:: + + @pytest.mark.pg(">= 12") + + and skips the test if the server version doesn't match what expected. + """ + return check_version(got, want, "PostgreSQL", postgres_rule=True) + + +def check_version(got, want, whose_version, postgres_rule=True): + pred = VersionCheck.parse(want, postgres_rule=postgres_rule) + pred.whose = whose_version + return pred.get_skip_message(got) + + +class VersionCheck: + """ + Helper to compare a version number with a test spec. + """ + + def __init__( + self, + *, + skip: bool = False, + op: Optional[str] = None, + version_tuple: Tuple[int, ...] = (), + whose: str = "(wanted)", + postgres_rule: bool = False, + ): + self.skip = skip + self.op = op or "==" + self.version_tuple = version_tuple + self.whose = whose + # Treat 10.1 as 10.0.1 + self.postgres_rule = postgres_rule + + @classmethod + def parse(cls, spec: str, *, postgres_rule: bool = False) -> "VersionCheck": + # Parse a spec like "> 9.6", "skip < 21.2.0" + m = re.match( + r"""(?ix) + ^\s* (skip|only)? + \s* (==|!=|>=|<=|>|<)? + \s* (?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)? 
+ \s* $ + """, + spec, + ) + if m is None: + pytest.fail(f"bad wanted version spec: {spec}") + + skip = (m.group(1) or "only").lower() == "skip" + op = m.group(2) + version_tuple = tuple(int(n) for n in m.groups()[2:] if n) + + return cls( + skip=skip, op=op, version_tuple=version_tuple, postgres_rule=postgres_rule + ) + + def get_skip_message(self, version: Optional[int]) -> Optional[str]: + got_tuple = self._parse_int_version(version) + + msg: Optional[str] = None + if self.skip: + if got_tuple: + if not self.version_tuple: + msg = f"skip on {self.whose}" + elif self._match_version(got_tuple): + msg = ( + f"skip on {self.whose} {self.op}" + f" {'.'.join(map(str, self.version_tuple))}" + ) + else: + if not got_tuple: + msg = f"only for {self.whose}" + elif not self._match_version(got_tuple): + if self.version_tuple: + msg = ( + f"only for {self.whose} {self.op}" + f" {'.'.join(map(str, self.version_tuple))}" + ) + else: + msg = f"only for {self.whose}" + + return msg + + _OP_NAMES = {">=": "ge", "<=": "le", ">": "gt", "<": "lt", "==": "eq", "!=": "ne"} + + def _match_version(self, got_tuple: Tuple[int, ...]) -> bool: + if not self.version_tuple: + return True + + version_tuple = self.version_tuple + if self.postgres_rule and version_tuple and version_tuple[0] >= 10: + assert len(version_tuple) <= 2 + version_tuple = version_tuple[:1] + (0,) + version_tuple[1:] + + op: Callable[[Tuple[int, ...], Tuple[int, ...]], bool] + op = getattr(operator, self._OP_NAMES[self.op]) + return op(got_tuple, version_tuple) + + def _parse_int_version(self, version: Optional[int]) -> Tuple[int, ...]: + if version is None: + return () + version, ver_fix = divmod(version, 100) + ver_maj, ver_min = divmod(version, 100) + return (ver_maj, ver_min, ver_fix) + + +def gc_collect(): + """ + gc.collect(), but more insisting. + """ + for i in range(3): + gc.collect() + + +NO_COUNT_TYPES: Tuple[type, ...] = () + +if sys.version_info[:2] == (3, 10): + # On my laptop there are occasional creations of a single one of these objects + # with empty content, which might be some Decimal caching. + # Keeping the guard as strict as possible, to be extended if other types + # or versions are necessary. + try: + from _contextvars import Context # type: ignore + except ImportError: + pass + else: + NO_COUNT_TYPES += (Context,) + + +def gc_count() -> int: + """ + len(gc.get_objects()), with subtleties. + """ + if not NO_COUNT_TYPES: + return len(gc.get_objects()) + + # Note: not using a list comprehension because it pollutes the objects list. + rv = 0 + for obj in gc.get_objects(): + if isinstance(obj, NO_COUNT_TYPES): + continue + rv += 1 + + return rv + + +async def alist(it): + return [i async for i in it] diff --git a/tools/build/build_libpq.sh b/tools/build/build_libpq.sh new file mode 100755 index 0000000..4cc79af --- /dev/null +++ b/tools/build/build_libpq.sh @@ -0,0 +1,173 @@ +#!/bin/bash + +# Build a modern version of libpq and depending libs from source on Centos 5 + +set -euo pipefail +set -x + +# Last release: https://www.postgresql.org/ftp/source/ +# IMPORTANT! 
Change the cache key in packages.yml when upgrading libraries +postgres_version="${LIBPQ_VERSION:-15.0}" + +# last release: https://www.openssl.org/source/ +openssl_version="${OPENSSL_VERSION:-1.1.1r}" + +# last release: https://openldap.org/software/download/ +ldap_version="2.6.3" + +# last release: https://github.com/cyrusimap/cyrus-sasl/releases +sasl_version="2.1.28" + +export LIBPQ_BUILD_PREFIX=${LIBPQ_BUILD_PREFIX:-/tmp/libpq.build} + +if [[ -f "${LIBPQ_BUILD_PREFIX}/lib/libpq.so" ]]; then + echo "libpq already available: build skipped" >&2 + exit 0 +fi + +source /etc/os-release + +case "$ID" in + centos) + yum update -y + yum install -y zlib-devel krb5-devel pam-devel + ;; + + alpine) + apk upgrade + apk add --no-cache zlib-dev krb5-dev linux-pam-dev openldap-dev + ;; + + *) + echo "$0: unexpected Linux distribution: '$ID'" >&2 + exit 1 + ;; +esac + +if [ "$ID" == "centos" ]; then + + # Build openssl if needed + openssl_tag="OpenSSL_${openssl_version//./_}" + openssl_dir="openssl-${openssl_tag}" + if [ ! -d "${openssl_dir}" ]; then curl -sL \ + https://github.com/openssl/openssl/archive/${openssl_tag}.tar.gz \ + | tar xzf - + + cd "${openssl_dir}" + + ./config --prefix=${LIBPQ_BUILD_PREFIX} --openssldir=${LIBPQ_BUILD_PREFIX} \ + zlib -fPIC shared + make depend + make + else + cd "${openssl_dir}" + fi + + # Install openssl + make install_sw + cd .. + +fi + + +if [ "$ID" == "centos" ]; then + + # Build libsasl2 if needed + # The system package (cyrus-sasl-devel) causes an amazing error on i686: + # "unsupported version 0 of Verneed record" + # https://github.com/pypa/manylinux/issues/376 + sasl_tag="cyrus-sasl-${sasl_version}" + sasl_dir="cyrus-sasl-${sasl_tag}" + if [ ! -d "${sasl_dir}" ]; then + curl -sL \ + https://github.com/cyrusimap/cyrus-sasl/archive/${sasl_tag}.tar.gz \ + | tar xzf - + + cd "${sasl_dir}" + + autoreconf -i + ./configure --prefix=${LIBPQ_BUILD_PREFIX} \ + CPPFLAGS=-I${LIBPQ_BUILD_PREFIX}/include/ LDFLAGS=-L${LIBPQ_BUILD_PREFIX}/lib + make + else + cd "${sasl_dir}" + fi + + # Install libsasl2 + # requires missing nroff to build + touch saslauthd/saslauthd.8 + make install + cd .. + +fi + + +if [ "$ID" == "centos" ]; then + + # Build openldap if needed + ldap_tag="${ldap_version}" + ldap_dir="openldap-${ldap_tag}" + if [ ! -d "${ldap_dir}" ]; then + curl -sL \ + https://www.openldap.org/software/download/OpenLDAP/openldap-release/openldap-${ldap_tag}.tgz \ + | tar xzf - + + cd "${ldap_dir}" + + ./configure --prefix=${LIBPQ_BUILD_PREFIX} --enable-backends=no --enable-null \ + CPPFLAGS=-I${LIBPQ_BUILD_PREFIX}/include/ LDFLAGS=-L${LIBPQ_BUILD_PREFIX}/lib + + make depend + make -C libraries/liblutil/ + make -C libraries/liblber/ + make -C libraries/libldap/ + else + cd "${ldap_dir}" + fi + + # Install openldap + make -C libraries/liblber/ install + make -C libraries/libldap/ install + make -C include/ install + chmod +x ${LIBPQ_BUILD_PREFIX}/lib/{libldap,liblber}*.so* + cd .. + +fi + + +# Build libpq if needed +postgres_tag="REL_${postgres_version//./_}" +postgres_dir="postgres-${postgres_tag}" +if [ ! 
-d "${postgres_dir}" ]; then + curl -sL \ + https://github.com/postgres/postgres/archive/${postgres_tag}.tar.gz \ + | tar xzf - + + cd "${postgres_dir}" + + # Match the default unix socket dir default with what defined on Ubuntu and + # Red Hat, which seems the most common location + sed -i 's|#define DEFAULT_PGSOCKET_DIR .*'\ +'|#define DEFAULT_PGSOCKET_DIR "/var/run/postgresql"|' \ + src/include/pg_config_manual.h + + # Often needed, but currently set by the workflow + # export LD_LIBRARY_PATH="${LIBPQ_BUILD_PREFIX}/lib" + + ./configure --prefix=${LIBPQ_BUILD_PREFIX} --sysconfdir=/etc/postgresql-common \ + --without-readline --with-gssapi --with-openssl --with-pam --with-ldap \ + CPPFLAGS=-I${LIBPQ_BUILD_PREFIX}/include/ LDFLAGS=-L${LIBPQ_BUILD_PREFIX}/lib + make -C src/interfaces/libpq + make -C src/bin/pg_config + make -C src/include +else + cd "${postgres_dir}" +fi + +# Install libpq +make -C src/interfaces/libpq install +make -C src/bin/pg_config install +make -C src/include install +cd .. + +find ${LIBPQ_BUILD_PREFIX} -name \*.so.\* -type f -exec strip --strip-unneeded {} \; diff --git a/tools/build/build_macos_arm64.sh b/tools/build/build_macos_arm64.sh new file mode 100755 index 0000000..f8c2fd7 --- /dev/null +++ b/tools/build/build_macos_arm64.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +# Build psycopg-binary wheel packages for Apple M1 (cpNNN-macosx_arm64) +# +# This script is designed to run on Scaleway Apple Silicon machines. +# +# The script cannot be run as sudo (installing brew fails), but requires sudo, +# so it can pretty much only be executed by a sudo user as it is. + +set -euo pipefail +set -x + +python_versions="3.8.10 3.9.13 3.10.5 3.11.0" +pg_version=15 + +# Move to the root of the project +dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "${dir}/../../" + +# Add /usr/local/bin to the path. It seems it's not, in non-interactive sessions +if ! (echo $PATH | grep -q '/usr/local/bin'); then + export PATH=/usr/local/bin:$PATH +fi + +# Install brew, if necessary. Otherwise just make sure it's in the path +if [[ -x /opt/homebrew/bin/brew ]]; then + eval "$(/opt/homebrew/bin/brew shellenv)" +else + command -v brew > /dev/null || ( + # Not necessary: already installed + # xcode-select --install + NONINTERACTIVE=1 /bin/bash -c "$(curl -fsSL \ + https://raw.githubusercontent.com/Homebrew/install/master/install.sh)" + ) + eval "$(/opt/homebrew/bin/brew shellenv)" +fi + +export PGDATA=/opt/homebrew/var/postgresql@${pg_version} + +# Install PostgreSQL, if necessary +command -v pg_config > /dev/null || ( + brew install postgresql@${pg_version} +) + +# After PostgreSQL 15, the bin path is not in the path. +export PATH=$(ls -d1 /opt/homebrew/Cellar/postgresql@${pg_version}/*/bin):$PATH + +# Make sure the server is running + +# Currently not working +# brew services start postgresql@${pg_version} + +if ! pg_ctl status; then + pg_ctl -l /opt/homebrew/var/log/postgresql@${pg_version}.log start +fi + + +# Install the Python versions we want to build +for ver3 in $python_versions; do + ver2=$(echo $ver3 | sed 's/\([^\.]*\)\(\.[^\.]*\)\(.*\)/\1\2/') + command -v python${ver2} > /dev/null || ( + (cd /tmp && + curl -fsSl -O \ + https://www.python.org/ftp/python/${ver3}/python-${ver3}-macos11.pkg) + sudo installer -pkg /tmp/python-${ver3}-macos11.pkg -target / + ) +done + +# Create a virtualenv where to work +if [[ ! 
-x .venv/bin/python ]]; then + python3 -m venv .venv +fi + +source .venv/bin/activate +pip install cibuildwheel + +# Create the psycopg_binary source package +rm -rf psycopg_binary +python tools/build/copy_to_binary.py + +# Build the binary packages +export CIBW_PLATFORM=macos +export CIBW_ARCHS=arm64 +export CIBW_BUILD='cp{38,39,310,311}-*' +export CIBW_TEST_REQUIRES="./psycopg[test] ./psycopg_pool" +export CIBW_TEST_COMMAND="pytest {project}/tests -m 'not slow and not flakey' --color yes" + +export PSYCOPG_IMPL=binary +export PSYCOPG_TEST_DSN="dbname=postgres" +export PSYCOPG_TEST_WANT_LIBPQ_BUILD=">= ${pg_version}" +export PSYCOPG_TEST_WANT_LIBPQ_IMPORT=">= ${pg_version}" + +cibuildwheel psycopg_binary diff --git a/tools/build/ci_test.sh b/tools/build/ci_test.sh new file mode 100755 index 0000000..d1d2ee4 --- /dev/null +++ b/tools/build/ci_test.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Run the tests in Github Action +# +# Failed tests run up to three times, to take into account flakey tests. +# Of course the random generator is not re-seeded between runs, in order to +# repeat the same result. + +set -euo pipefail +set -x + +# Assemble a markers expression from the MARKERS and NOT_MARKERS env vars +markers="" +for m in ${MARKERS:-}; do + [[ "$markers" != "" ]] && markers="$markers and" + markers="$markers $m" +done +for m in ${NOT_MARKERS:-}; do + [[ "$markers" != "" ]] && markers="$markers and" + markers="$markers not $m" +done + +pytest="python -bb -m pytest --color=yes" + +$pytest -m "$markers" "$@" && exit 0 + +$pytest -m "$markers" --lf --randomly-seed=last "$@" && exit 0 + +$pytest -m "$markers" --lf --randomly-seed=last "$@" diff --git a/tools/build/copy_to_binary.py b/tools/build/copy_to_binary.py new file mode 100755 index 0000000..7cab25c --- /dev/null +++ b/tools/build/copy_to_binary.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +# Create the psycopg-binary package by renaming and patching psycopg-c + +import os +import re +import shutil +from pathlib import Path +from typing import Union + +curdir = Path(__file__).parent +pdir = curdir / "../.." +target = pdir / "psycopg_binary" + +if target.exists(): + raise Exception(f"path {target} already exists") + + +def sed_i(pattern: str, repl: str, filename: Union[str, Path]) -> None: + with open(filename, "rb") as f: + data = f.read() + newdata = re.sub(pattern.encode("utf8"), repl.encode("utf8"), data) + if newdata != data: + with open(filename, "wb") as f: + f.write(newdata) + + +shutil.copytree(pdir / "psycopg_c", target) +shutil.move(str(target / "psycopg_c"), str(target / "psycopg_binary")) +shutil.move(str(target / "README-binary.rst"), str(target / "README.rst")) +sed_i("psycopg-c", "psycopg-binary", target / "setup.cfg") +sed_i( + r"__impl__\s*=.*", '__impl__ = "binary"', target / "psycopg_binary/pq.pyx" +) +for dirpath, dirnames, filenames in os.walk(target): + for filename in filenames: + if os.path.splitext(filename)[1] not in (".pyx", ".pxd", ".py"): + continue + sed_i(r"\bpsycopg_c\b", "psycopg_binary", Path(dirpath) / filename) diff --git a/tools/build/print_so_versions.sh b/tools/build/print_so_versions.sh new file mode 100755 index 0000000..a3c4ecd --- /dev/null +++ b/tools/build/print_so_versions.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# Take a .so file as input and print the Debian packages and versions of the +# libraries it links. 
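As a hedged usage sketch (the .so filename shown is hypothetical), the script takes the shared object to inspect as its only argument, which is how strip_wheel.sh further below invokes it on the libraries found inside an extracted wheel:

    # Print which distro packages provide the shared libraries that a built
    # extension module links against.
    tools/build/print_so_versions.sh ./psycopg_binary/_psycopg.cpython-310-x86_64-linux-gnu.so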
+ +set -euo pipefail +# set -x + +source /etc/os-release + +sofile="$1" + +case "$ID" in + alpine) + depfiles=$( (ldd "$sofile" 2>/dev/null || true) | grep '=>' | sed 's/.*=> \(.*\) (.*)/\1/') + (for depfile in $depfiles; do + echo "$(basename "$depfile") => $(apk info --who-owns "${depfile}" | awk '{print $(NF)}')" + done) | sort | uniq + ;; + + debian) + depfiles=$(ldd "$sofile" | grep '=>' | sed 's/.*=> \(.*\) (.*)/\1/') + (for depfile in $depfiles; do + pkgname=$(dpkg -S "${depfile}" | sed 's/\(\): .*/\1/') + dpkg -l "${pkgname}" | grep '^ii' | awk '{print $2 " => " $3}' + done) | sort | uniq + ;; + + centos) + echo "TODO!" + ;; + + *) + echo "$0: unexpected Linux distribution: '$ID'" >&2 + exit 1 + ;; +esac diff --git a/tools/build/run_build_macos_arm64.sh b/tools/build/run_build_macos_arm64.sh new file mode 100755 index 0000000..f5ae617 --- /dev/null +++ b/tools/build/run_build_macos_arm64.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Build psycopg-binary wheel packages for Apple M1 (cpNNN-macosx_arm64) +# +# This script is designed to run on a local machine: it will clone the repos +# remotely and execute the `build_macos_arm64.sh` script remotely, then will +# download the built packages. A tag to build must be specified. +# +# In order to run the script, the `m1` host must be specified in +# `~/.ssh/config`; for instance: +# +# Host m1 +# User m1 +# HostName 1.2.3.4 + +set -euo pipefail +# set -x + +tag=${1:-} + +if [[ ! "${tag}" ]]; then + echo "Usage: $0 TAG" >&2 + exit 2 +fi + +rdir=psycobuild + +# Clone the repos +ssh m1 rm -rf "${rdir}" +ssh m1 git clone https://github.com/psycopg/psycopg.git --branch ${tag} "${rdir}" + +# Allow sudoing without password, to allow brew to install +ssh -t m1 bash -c \ + 'test -f /etc/sudoers.d/m1 || echo "m1 ALL=(ALL) NOPASSWD:ALL" | sudo tee /etc/sudoers.d/m1' + +# Build the wheel packages +ssh m1 "${rdir}/tools/build/build_macos_arm64.sh" + +# Transfer the packages locally +scp -r "m1:${rdir}/wheelhouse" . diff --git a/tools/build/strip_wheel.sh b/tools/build/strip_wheel.sh new file mode 100755 index 0000000..bfcd302 --- /dev/null +++ b/tools/build/strip_wheel.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Strip symbols inplace from the libraries in a zip archive. +# +# Stripping symbols is beneficial (reduction of 30% of the final package, > +# %90% of the installed libraries. However just running `auditwheel repair +# --strip` breaks some of the libraries included from the system, which fail at +# import with errors such as "ELF load command address/offset not properly +# aligned". +# +# System libraries are already pretty stripped. Ours go around 24Mb -> 1.5Mb... +# +# This script is designed to run on a wheel archive before auditwheel. + +set -euo pipefail +# set -x + +source /etc/os-release +dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +wheel=$(realpath "$1") +shift + +tmpdir=$(mktemp -d) +trap "rm -r ${tmpdir}" EXIT + +cd "${tmpdir}" +python -m zipfile -e "${wheel}" . + +echo " +Libs before:" +# Busybox doesn't have "find -ls" +find . -name \*.so | xargs ls -l + +# On Debian, print the package versions libraries come from +echo " +Dependencies versions of '_psycopg.so' library:" +"${dir}/print_so_versions.sh" "$(find . -name \*_psycopg\*.so)" + +find . -name \*.so -exec strip "$@" {} \; + +echo " +Libs after:" +find . 
-name \*.so | xargs ls -l + +python -m zipfile -c ${wheel} * + +cd - diff --git a/tools/build/wheel_linux_before_all.sh b/tools/build/wheel_linux_before_all.sh new file mode 100755 index 0000000..663e3ef --- /dev/null +++ b/tools/build/wheel_linux_before_all.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Configure the libraries needed to build wheel packages on linux. +# This script is designed to be used by cibuildwheel as CIBW_BEFORE_ALL_LINUX + +set -euo pipefail +set -x + +dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +source /etc/os-release + +# Install PostgreSQL development files. +case "$ID" in + alpine) + # tzdata is required for datetime tests. + apk update + apk add --no-cache tzdata + "${dir}/build_libpq.sh" > /dev/null + ;; + + debian) + # Note that the pgdg doesn't have an aarch64 repository so wheels are + # build with the libpq packaged with Debian 9, which is 9.6. + if [ "$AUDITWHEEL_ARCH" != 'aarch64' ]; then + echo "deb http://apt.postgresql.org/pub/repos/apt $VERSION_CODENAME-pgdg main" \ + > /etc/apt/sources.list.d/pgdg.list + # TODO: On 2021-11-09 curl fails on 'ppc64le' with: + # curl: (60) SSL certificate problem: certificate has expired + # Test again later if -k can be removed. + curl -skf https://www.postgresql.org/media/keys/ACCC4CF8.asc \ + > /etc/apt/trusted.gpg.d/postgresql.asc + fi + + apt-get update + apt-get -y upgrade + apt-get -y install libpq-dev + ;; + + centos) + "${dir}/build_libpq.sh" > /dev/null + ;; + + *) + echo "$0: unexpected Linux distribution: '$ID'" >&2 + exit 1 + ;; +esac diff --git a/tools/build/wheel_macos_before_all.sh b/tools/build/wheel_macos_before_all.sh new file mode 100755 index 0000000..285a063 --- /dev/null +++ b/tools/build/wheel_macos_before_all.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# Configure the environment needed to build wheel packages on Mac OS. +# This script is designed to be used by cibuildwheel as CIBW_BEFORE_ALL_MACOS + +set -euo pipefail +set -x + +dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +brew update +brew install gnu-sed postgresql@14 +# Fetch 14.1 if 14.0 is still the default version +brew reinstall postgresql + +# Start the database for testing +brew services start postgresql + +# Wait for postgres to come up +for i in $(seq 10 -1 0); do + eval pg_isready && break + if [ $i == 0 ]; then + echo "PostgreSQL service not ready, giving up" + exit 1 + fi + echo "PostgreSQL service not ready, waiting a bit, attempts left: $i" + sleep 5 +done diff --git a/tools/build/wheel_win32_before_build.bat b/tools/build/wheel_win32_before_build.bat new file mode 100644 index 0000000..fd35f5d --- /dev/null +++ b/tools/build/wheel_win32_before_build.bat @@ -0,0 +1,3 @@ +@echo on +pip install delvewheel +choco upgrade postgresql diff --git a/tools/bump_version.py b/tools/bump_version.py new file mode 100755 index 0000000..50dbe0b --- /dev/null +++ b/tools/bump_version.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python +"""Bump the version number of the project. 
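+
+Illustrative invocations (flag names come from the parser defined below; the
+paths assume the script is run from the root of a project checkout):
+
+    python tools/bump_version.py --level patch --package psycopg
+    python tools/bump_version.py --level dev --package psycopg_pool --dry-run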
+""" + +from __future__ import annotations + +import re +import sys +import logging +import subprocess as sp +from enum import Enum +from pathlib import Path +from argparse import ArgumentParser, Namespace +from functools import cached_property +from dataclasses import dataclass + +from packaging.version import parse as parse_version, Version + +PROJECT_DIR = Path(__file__).parent.parent + +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +@dataclass +class Package: + name: str + version_files: list[Path] + history_file: Path + tag_format: str + + def __post_init__(self) -> None: + packages[self.name] = self + + +packages: dict[str, Package] = {} + +Package( + name="psycopg", + version_files=[ + PROJECT_DIR / "psycopg/psycopg/version.py", + PROJECT_DIR / "psycopg_c/psycopg_c/version.py", + ], + history_file=PROJECT_DIR / "docs/news.rst", + tag_format="{version}", +) + +Package( + name="psycopg_pool", + version_files=[PROJECT_DIR / "psycopg_pool/psycopg_pool/version.py"], + history_file=PROJECT_DIR / "docs/news_pool.rst", + tag_format="pool-{version}", +) + + +class Bumper: + def __init__(self, package: Package, *, bump_level: str | BumpLevel): + self.package = package + self.bump_level = BumpLevel(bump_level) + + self._version_regex = re.compile( + r"""(?ix) + ^ + (?P<pre>__version__\s*=\s*(?P<quote>["'])) + (?P<ver>[^'"]+) + (?P<post>(?P=quote)\s*(?:\#.*)?) + $ + """ + ) + + @cached_property + def current_version(self) -> Version: + versions = set( + self._parse_version_from_file(f) for f in self.package.version_files + ) + if len(versions) > 1: + raise ValueError( + f"inconsistent versions ({', '.join(map(str, sorted(versions)))})" + f" in {self.package.version_files}" + ) + + return versions.pop() + + @cached_property + def want_version(self) -> Version: + current = self.current_version + parts = [current.major, current.minor, current.micro, current.dev or 0] + + match self.bump_level: + case BumpLevel.MAJOR: + # 1.2.3 -> 2.0.0 + parts[0] += 1 + parts[1] = parts[2] = parts[3] = 0 + case BumpLevel.MINOR: + # 1.2.3 -> 1.3.0 + parts[1] += 1 + parts[2] = parts[3] = 0 + case BumpLevel.PATCH: + # 1.2.3 -> 1.2.4 + # 1.2.3.dev4 -> 1.2.3 + if parts[3] == 0: + parts[2] += 1 + else: + parts[3] = 0 + case BumpLevel.DEV: + # 1.2.3 -> 1.2.4.dev1 + # 1.2.3.dev1 -> 1.2.3.dev2 + if parts[3] == 0: + parts[2] += 1 + parts[3] += 1 + + sparts = [str(part) for part in parts[:3]] + if parts[3]: + sparts.append(f"dev{parts[3]}") + return Version(".".join(sparts)) + + def update_files(self) -> None: + for f in self.package.version_files: + self._update_version_in_file(f, self.want_version) + + if self.bump_level != BumpLevel.DEV: + self._update_history_file(self.package.history_file, self.want_version) + + def commit(self) -> None: + logger.debug("committing version changes") + msg = f"""\ +chore: bump {self.package.name} package version to {self.want_version} +""" + files = self.package.version_files + [self.package.history_file] + cmdline = ["git", "commit", "-m", msg] + list(map(str, files)) + sp.check_call(cmdline) + + def create_tag(self) -> None: + logger.debug("tagging version %s", self.want_version) + tag_name = self.package.tag_format.format(version=self.want_version) + changes = self._get_changes_lines( + self.package.history_file, + self.want_version, + ) + msg = f"""\ +{self.package.name} {self.want_version} released + +{''.join(changes)} +""" + cmdline = ["git", "tag", "-a", "-s", "-m", msg, tag_name] + sp.check_call(cmdline) 
+ + def _parse_version_from_file(self, fp: Path) -> Version: + logger.debug("looking for version in %s", fp) + matches = [] + with fp.open() as f: + for line in f: + m = self._version_regex.match(line) + if m: + matches.append(m) + + if not matches: + raise ValueError(f"no version found in {fp}") + elif len(matches) > 1: + raise ValueError(f"more than one version found in {fp}") + + vs = parse_version(matches[0].group("ver")) + assert isinstance(vs, Version) + return vs + + def _update_version_in_file(self, fp: Path, version: Version) -> None: + logger.debug("upgrading version to %s in %s", version, fp) + lines = [] + with fp.open() as f: + for line in f: + if self._version_regex.match(line): + line = self._version_regex.sub(f"\\g<pre>{version}\\g<post>", line) + lines.append(line) + + with fp.open("w") as f: + for line in lines: + f.write(line) + + def _update_history_file(self, fp: Path, version: Version) -> None: + logger.debug("upgrading history file %s", fp) + with fp.open() as f: + lines = f.readlines() + + vln: int = -1 + lns = self._find_lines( + r"^[^\s]+ " + re.escape(str(version)) + r"\s*\(unreleased\)?$", lines + ) + assert len(lns) <= 1 + if len(lns) == 1: + vln = lns[0] + lines[vln] = lines[vln].rsplit(None, 1)[0] + lines[vln + 1] = lines[vln + 1][0] * len(lines[lns[0]]) + + lns = self._find_lines("^Future", lines) + assert len(lns) <= 1 + if len(lns) == 1: + del lines[lns[0] : lns[0] + 3] + if vln > lns[0]: + vln -= 3 + + lns = self._find_lines("^Current", lines) + assert len(lns) <= 1 + if len(lns) == 1 and vln >= 0: + clines = lines[lns[0] : lns[0] + 3] + del lines[lns[0] : lns[0] + 3] + if vln > lns[0]: + vln -= 3 + lines[vln:vln] = clines + + with fp.open("w") as f: + for line in lines: + f.write(line) + if not line.endswith("\n"): + f.write("\n") + + def _get_changes_lines(self, fp: Path, version: Version) -> list[str]: + with fp.open() as f: + lines = f.readlines() + + lns = self._find_lines(r"^[^\s]+ " + re.escape(str(version)), lines) + assert len(lns) == 1 + start = end = lns[0] + 3 + while lines[end].rstrip(): + end += 1 + + return lines[start:end] + + def _find_lines(self, pattern: str, lines: list[str]) -> list[int]: + rv = [] + rex = re.compile(pattern) + for i, line in enumerate(lines): + if rex.match(line): + rv.append(i) + + return rv + + +def main() -> int | None: + opt = parse_cmdline() + logger.setLevel(opt.loglevel) + bumper = Bumper(packages[opt.package], bump_level=opt.level) + logger.info("current version: %s", bumper.current_version) + logger.info("bumping to version: %s", bumper.want_version) + if not opt.dry_run: + bumper.update_files() + bumper.commit() + if opt.level != BumpLevel.DEV: + bumper.create_tag() + + return 0 + + +class BumpLevel(str, Enum): + MAJOR = "major" + MINOR = "minor" + PATCH = "patch" + DEV = "dev" + + +def parse_cmdline() -> Namespace: + parser = ArgumentParser(description=__doc__) + + parser.add_argument( + "--level", + choices=[level.value for level in BumpLevel], + default=BumpLevel.PATCH.value, + type=BumpLevel, + help="the level to bump [default: %(default)s]", + ) + + parser.add_argument( + "--package", + choices=list(packages.keys()), + default="psycopg", + help="the package to bump version [default: %(default)s]", + ) + + parser.add_argument( + "-n", + "--dry-run", + help="Just pretend", + action="store_true", + ) + + g = parser.add_mutually_exclusive_group() + g.add_argument( + "-q", + "--quiet", + help="Talk less", + dest="loglevel", + action="store_const", + const=logging.WARN, + default=logging.INFO, + ) + 
g.add_argument( + "-v", + "--verbose", + help="Talk more", + dest="loglevel", + action="store_const", + const=logging.DEBUG, + default=logging.INFO, + ) + opt = parser.parse_args() + + return opt + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/update_backer.py b/tools/update_backer.py new file mode 100755 index 0000000..0088527 --- /dev/null +++ b/tools/update_backer.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +r"""Add or edit github users in the backers file +""" + +import sys +import logging +import requests +from pathlib import Path +from ruamel.yaml import YAML # pip install ruamel.yaml + +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +def fetch_user(username): + logger.info("fetching %s", username) + resp = requests.get( + f"https://api.github.com/users/{username}", + headers={"Accept": "application/vnd.github.v3+json"}, + ) + resp.raise_for_status() + return resp.json() + + +def get_user_data(data): + """ + Get the data to save from the request data + """ + out = { + "username": data["login"], + "avatar": data["avatar_url"], + "name": data["name"], + } + if data["blog"]: + website = data["blog"] + if not website.startswith("http"): + website = "http://" + website + + out["website"] = website + + return out + + +def add_entry(opt, filedata, username): + userdata = get_user_data(fetch_user(username)) + if opt.top: + userdata["tier"] = "top" + + filedata.append(userdata) + + +def update_entry(opt, filedata, entry): + # entry is an username or an user entry daat + if isinstance(entry, str): + username = entry + entry = [e for e in filedata if e["username"] == username] + if not entry: + raise Exception(f"{username} not found") + entry = entry[0] + else: + username = entry["username"] + + userdata = get_user_data(fetch_user(username)) + for k, v in userdata.items(): + if entry.get("keep_" + k): + continue + entry[k] = v + + +def main(): + opt = parse_cmdline() + logger.info("reading %s", opt.file) + yaml = YAML(typ="rt") + filedata = yaml.load(opt.file) + + for username in opt.add or (): + add_entry(opt, filedata, username) + + for username in opt.update or (): + update_entry(opt, filedata, username) + + if opt.update_all: + for entry in filedata: + update_entry(opt, filedata, entry) + + # yamllint happy + yaml.explicit_start = True + logger.info("writing %s", opt.file) + yaml.dump(filedata, opt.file) + + +def parse_cmdline(): + from argparse import ArgumentParser + + parser = ArgumentParser(description=__doc__) + parser.add_argument( + "--file", + help="the file to update [default: %(default)s]", + default=Path(__file__).parent.parent / "BACKERS.yaml", + type=Path, + ) + parser.add_argument( + "--add", + metavar="USERNAME", + nargs="+", + help="add USERNAME to the backers", + ) + + parser.add_argument( + "--top", + action="store_true", + help="add to the top tier", + ) + + parser.add_argument( + "--update", + metavar="USERNAME", + nargs="+", + help="update USERNAME data", + ) + + parser.add_argument( + "--update-all", + action="store_true", + help="update all the existing backers data", + ) + + opt = parser.parse_args() + + return opt + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/update_errors.py b/tools/update_errors.py new file mode 100755 index 0000000..638d352 --- /dev/null +++ b/tools/update_errors.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python +# type: ignore +""" +Generate per-sqlstate errors from PostgreSQL source code. 
+ +The script can be run at a new PostgreSQL release to refresh the module. +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +import sys +import logging +from urllib.request import urlopen +from collections import defaultdict, namedtuple + +from psycopg.errors import get_base_exception + +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +def main(): + classes, errors = fetch_errors("9.6 10 11 12 13 14 15".split()) + + fn = os.path.dirname(__file__) + "/../psycopg/psycopg/errors.py" + update_file(fn, generate_module_data(classes, errors)) + + fn = os.path.dirname(__file__) + "/../docs/api/errors.rst" + update_file(fn, generate_docs_data(classes, errors)) + + +def parse_errors_txt(url): + classes = {} + errors = defaultdict(dict) + + page = urlopen(url) + for line in page.read().decode("ascii").splitlines(): + # Strip comments and skip blanks + line = line.split("#")[0].strip() + if not line: + continue + + # Parse a section + m = re.match(r"Section: (Class (..) - .+)", line) + if m: + label, class_ = m.groups() + classes[class_] = label + continue + + # Parse an error + m = re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line) + if m: + sqlstate, macro, spec = m.groups() + # skip sqlstates without specs as they are not publicly visible + if not spec: + continue + errlabel = spec.upper() + errors[class_][sqlstate] = errlabel + continue + + # We don't expect anything else + raise ValueError("unexpected line:\n%s" % line) + + return classes, errors + + +errors_txt_url = ( + "http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob_plain;" + "f=src/backend/utils/errcodes.txt;hb=%s" +) + + +Error = namedtuple("Error", "sqlstate errlabel clsname basename") + + +def fetch_errors(versions): + classes = {} + errors = defaultdict(dict) + + for version in versions: + logger.info("fetching errors from version %s", version) + tver = tuple(map(int, version.split()[0].split("."))) + tag = "%s%s_STABLE" % ( + (tver[0] >= 10 and "REL_" or "REL"), + version.replace(".", "_"), + ) + c1, e1 = parse_errors_txt(errors_txt_url % tag) + classes.update(c1) + + for c, cerrs in e1.items(): + errors[c].update(cerrs) + + # clean up data + + # success and warning - never raised + del classes["00"] + del classes["01"] + del errors["00"] + del errors["01"] + + specific = { + "38002": "ModifyingSqlDataNotPermittedExt", + "38003": "ProhibitedSqlStatementAttemptedExt", + "38004": "ReadingSqlDataNotPermittedExt", + "39004": "NullValueNotAllowedExt", + "XX000": "InternalError_", + } + + seen = set( + """ + Error Warning InterfaceError DataError DatabaseError ProgrammingError + IntegrityError InternalError NotSupportedError OperationalError + """.split() + ) + + for c, cerrs in errors.items(): + for sqstate, errlabel in list(cerrs.items()): + if sqstate in specific: + clsname = specific[sqstate] + else: + clsname = errlabel.title().replace("_", "") + if clsname in seen: + raise Exception("class already existing: %s" % clsname) + seen.add(clsname) + + basename = get_base_exception(sqstate).__name__ + cerrs[sqstate] = Error(sqstate, errlabel, clsname, basename) + + return classes, errors + + +def generate_module_data(classes, errors): + yield "" + + for clscode, clslabel in sorted(classes.items()): + yield f""" +# {clslabel} +""" + for _, e in sorted(errors[clscode].items()): + yield f"""\ +class {e.clsname}({e.basename}, + code={e.sqlstate!r}, name={e.errlabel!r}): + pass +""" + yield "" + + +def 
generate_docs_data(classes, errors): + Line = namedtuple("Line", "colstate colexc colbase, sqlstate") + lines = [Line("SQLSTATE", "Exception", "Base exception", None)] + + for clscode in sorted(classes): + for _, error in sorted(errors[clscode].items()): + lines.append( + Line( + f"``{error.sqlstate}``", + f"`!{error.clsname}`", + f"`!{error.basename}`", + error.sqlstate, + ) + ) + + widths = [max(len(line[c]) for line in lines) for c in range(3)] + h = Line(*(["=" * w for w in widths] + [None])) + lines.insert(0, h) + lines.insert(2, h) + lines.append(h) + + h1 = "-" * (sum(widths) + len(widths) - 1) + sqlclass = None + + yield "" + for line in lines: + cls = line.sqlstate[:2] if line.sqlstate else None + if cls and cls != sqlclass: + yield re.sub(r"(Class\s+[^\s]+)", r"**\1**", classes[cls]) + yield h1 + sqlclass = cls + + yield ( + "%-*s %-*s %-*s" + % ( + widths[0], + line.colstate, + widths[1], + line.colexc, + widths[2], + line.colbase, + ) + ).rstrip() + + yield "" + + +def update_file(fn, new_lines): + logger.info("updating %s", fn) + + with open(fn, "r") as f: + lines = f.read().splitlines() + + istart, iend = [ + i + for i, line in enumerate(lines) + if re.match(r"\s*(#|\.\.)\s*autogenerated:\s+(start|end)", line) + ] + + lines[istart + 1 : iend] = new_lines + + with open(fn, "w") as f: + for line in lines: + f.write(line + "\n") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/update_oids.py b/tools/update_oids.py new file mode 100755 index 0000000..df4f969 --- /dev/null +++ b/tools/update_oids.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python +""" +Update the maps of builtin types and names. + +This script updates some of the files in psycopg source code with data read +from a database catalog. + +Hint: use docker to upgrade types from a new version in isolation. 
Run: + + docker run --rm -p 11111:5432 --name pg -e POSTGRES_PASSWORD=password postgres:TAG + +with a specified version tag, and then query it using: + + %(prog)s "host=localhost port=11111 user=postgres password=password" +""" + +import re +import argparse +import subprocess as sp +from typing import List +from pathlib import Path +from typing_extensions import TypeAlias + +import psycopg +from psycopg.rows import TupleRow +from psycopg.crdb import CrdbConnection + +Connection: TypeAlias = psycopg.Connection[TupleRow] + +ROOT = Path(__file__).parent.parent + + +def main() -> None: + opt = parse_cmdline() + conn = psycopg.connect(opt.dsn, autocommit=True) + + if CrdbConnection.is_crdb(conn): + conn = CrdbConnection.connect(opt.dsn, autocommit=True) + update_crdb_python_oids(conn) + else: + update_python_oids(conn) + update_cython_oids(conn) + + +def update_python_oids(conn: Connection) -> None: + fn = ROOT / "psycopg/psycopg/postgres.py" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_py_types(conn)) + lines.extend(get_py_ranges(conn)) + lines.extend(get_py_multiranges(conn)) + + update_file(fn, lines) + sp.check_call(["black", "-q", fn]) + + +def update_cython_oids(conn: Connection) -> None: + fn = ROOT / "psycopg_c/psycopg_c/_psycopg/oids.pxd" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_cython_oids(conn)) + + update_file(fn, lines) + + +def update_crdb_python_oids(conn: Connection) -> None: + fn = ROOT / "psycopg/psycopg/crdb/_types.py" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_py_types(conn)) + + update_file(fn, lines) + sp.check_call(["black", "-q", fn]) + + +def get_version_comment(conn: Connection) -> List[str]: + if conn.info.vendor == "PostgreSQL": + # Assume PG > 10 + num = conn.info.server_version + version = f"{num // 10000}.{num % 100}" + elif conn.info.vendor == "CockroachDB": + assert isinstance(conn, CrdbConnection) + num = conn.info.server_version + version = f"{num // 10000}.{num % 10000 // 100}.{num % 100}" + else: + raise NotImplementedError(f"unexpected vendor: {conn.info.vendor}") + return ["", f" # Generated from {conn.info.vendor} {version}", ""] + + +def get_py_types(conn: Connection) -> List[str]: + # Note: "record" is a pseudotype but still a useful one to have. 
+ # "pg_lsn" is a documented public type and useful in streaming replication + lines = [] + for (typname, oid, typarray, regtype, typdelim) in conn.execute( + """ +select typname, oid, typarray, + -- CRDB might have quotes in the regtype representation + replace(typname::regtype::text, '''', '') as regtype, + typdelim +from pg_type t +where + oid < 10000 + and oid != '"char"'::regtype + and (typtype = 'b' or typname = 'record') + and (typname !~ '^(_|pg_)' or typname = 'pg_lsn') +order by typname +""" + ): + # Weird legacy type in postgres catalog + if typname == "char": + typname = regtype = '"char"' + + # https://github.com/cockroachdb/cockroach/issues/81645 + if typname == "int4" and conn.info.vendor == "CockroachDB": + regtype = typname + + params = [f"{typname!r}, {oid}, {typarray}"] + if regtype != typname: + params.append(f"regtype={regtype!r}") + if typdelim != ",": + params.append(f"delimiter={typdelim!r}") + lines.append(f"TypeInfo({','.join(params)}),") + + return lines + + +def get_py_ranges(conn: Connection) -> List[str]: + lines = [] + for (typname, oid, typarray, rngsubtype) in conn.execute( + """ +select typname, oid, typarray, rngsubtype +from + pg_type t + join pg_range r on t.oid = rngtypid +where + oid < 10000 + and typtype = 'r' +order by typname +""" + ): + params = [f"{typname!r}, {oid}, {typarray}, subtype_oid={rngsubtype}"] + lines.append(f"RangeInfo({','.join(params)}),") + + return lines + + +def get_py_multiranges(conn: Connection) -> List[str]: + lines = [] + for (typname, oid, typarray, rngtypid, rngsubtype) in conn.execute( + """ +select typname, oid, typarray, rngtypid, rngsubtype +from + pg_type t + join pg_range r on t.oid = rngmultitypid +where + oid < 10000 + and typtype = 'm' +order by typname +""" + ): + params = [ + f"{typname!r}, {oid}, {typarray}," + f" range_oid={rngtypid}, subtype_oid={rngsubtype}" + ] + lines.append(f"MultirangeInfo({','.join(params)}),") + + return lines + + +def get_cython_oids(conn: Connection) -> List[str]: + lines = [] + for (typname, oid) in conn.execute( + """ +select typname, oid +from pg_type +where + oid < 10000 + and (typtype = any('{b,r,m}') or typname = 'record') + and (typname !~ '^(_|pg_)' or typname = 'pg_lsn') +order by typname +""" + ): + const_name = typname.upper() + "_OID" + lines.append(f" {const_name} = {oid}") + + return lines + + +def update_file(fn: Path, new: List[str]) -> None: + with fn.open("r") as f: + lines = f.read().splitlines() + istart, iend = [ + i + for i, line in enumerate(lines) + if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line) + ] + lines[istart + 1 : iend] = new + + with fn.open("w") as f: + f.write("\n".join(lines)) + f.write("\n") + + +def parse_cmdline() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("dsn", help="where to connect to") + opt = parser.parse_args() + return opt + + +if __name__ == "__main__": + main() |