From 708c091a8b4db6a55be1c96ae33ee0da632b269f Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 19 Apr 2024 05:06:41 +0200 Subject: Adding upstream version 4.0.1. Signed-off-by: Daniel Baumann --- .coveragerc | 3 + .editorconfig | 15 + .git-blame-ignore-revs | 0 .github/ISSUE_TEMPLATE.md | 9 + .github/PULL_REQUEST_TEMPLATE.md | 12 + .github/workflows/ci.yml | 102 ++ .github/workflows/codeql.yml | 41 + .gitignore | 74 + .pre-commit-config.yaml | 5 + AUTHORS | 136 ++ DEVELOP.rst | 220 +++ Dockerfile | 6 + LICENSE.txt | 26 + MANIFEST.in | 2 + README.rst | 392 +++++ RELEASES.md | 24 + TODO | 12 + Vagrantfile | 138 ++ changelog.rst | 1203 ++++++++++++++ pgcli-completion.bash | 61 + pgcli/__init__.py | 1 + pgcli/__main__.py | 9 + pgcli/auth.py | 60 + pgcli/completion_refresher.py | 152 ++ pgcli/config.py | 99 ++ pgcli/explain_output_formatter.py | 19 + pgcli/key_bindings.py | 133 ++ pgcli/magic.py | 71 + pgcli/main.py | 1728 +++++++++++++++++++++ pgcli/packages/__init__.py | 0 pgcli/packages/formatter/__init__.py | 1 + pgcli/packages/formatter/sqlformatter.py | 74 + pgcli/packages/parseutils/__init__.py | 58 + pgcli/packages/parseutils/ctes.py | 141 ++ pgcli/packages/parseutils/meta.py | 170 ++ pgcli/packages/parseutils/tables.py | 165 ++ pgcli/packages/parseutils/utils.py | 140 ++ pgcli/packages/pgliterals/__init__.py | 0 pgcli/packages/pgliterals/main.py | 15 + pgcli/packages/pgliterals/pgliterals.json | 630 ++++++++ pgcli/packages/prioritization.py | 51 + pgcli/packages/prompt_utils.py | 37 + pgcli/packages/sqlcompletion.py | 605 ++++++++ pgcli/pgbuffer.py | 61 + pgcli/pgclirc | 236 +++ pgcli/pgcompleter.py | 1072 +++++++++++++ pgcli/pgexecute.py | 868 +++++++++++ pgcli/pgstyle.py | 116 ++ pgcli/pgtoolbar.py | 71 + pgcli/pyev.py | 439 ++++++ post-install | 4 + post-remove | 4 + pylintrc | 2 + pyproject.toml | 21 + release.py | 135 ++ requirements-dev.txt | 13 + sanity_checks.txt | 37 + screenshots/image01.png | Bin 0 -> 82111 bytes screenshots/image02.png | Bin 0 -> 11767 bytes screenshots/kharkiv-destroyed.jpg | Bin 0 -> 187337 bytes screenshots/pgcli.gif | Bin 0 -> 238421 bytes setup.py | 76 + tests/conftest.py | 52 + tests/features/__init__.py | 0 tests/features/auto_vertical.feature | 12 + tests/features/basic_commands.feature | 81 + tests/features/crud_database.feature | 17 + tests/features/crud_table.feature | 45 + tests/features/db_utils.py | 87 ++ tests/features/environment.py | 227 +++ tests/features/expanded.feature | 29 + tests/features/fixture_data/help.txt | 25 + tests/features/fixture_data/help_commands.txt | 64 + tests/features/fixture_data/mock_pg_service.conf | 4 + tests/features/fixture_utils.py | 28 + tests/features/iocommands.feature | 17 + tests/features/named_queries.feature | 10 + tests/features/pgbouncer.feature | 12 + tests/features/specials.feature | 6 + tests/features/steps/__init__.py | 0 tests/features/steps/auto_vertical.py | 99 ++ tests/features/steps/basic_commands.py | 231 +++ tests/features/steps/crud_database.py | 93 ++ tests/features/steps/crud_table.py | 185 +++ tests/features/steps/expanded.py | 70 + tests/features/steps/iocommands.py | 80 + tests/features/steps/named_queries.py | 57 + tests/features/steps/pgbouncer.py | 22 + tests/features/steps/specials.py | 31 + tests/features/steps/wrappers.py | 71 + tests/features/wrappager.py | 16 + tests/formatter/__init__.py | 1 + tests/formatter/test_sqlformatter.py | 111 ++ tests/metadata.py | 255 +++ tests/parseutils/test_ctes.py | 137 ++ tests/parseutils/test_function_metadata.py | 19 + tests/parseutils/test_parseutils.py | 310 ++++ tests/pytest.ini | 2 + tests/test_auth.py | 40 + tests/test_completion_refresher.py | 95 ++ tests/test_config.py | 43 + tests/test_fuzzy_completion.py | 87 ++ tests/test_main.py | 490 ++++++ tests/test_naive_completion.py | 133 ++ tests/test_pgcompleter.py | 76 + tests/test_pgexecute.py | 773 +++++++++ tests/test_pgspecial.py | 78 + tests/test_prioritization.py | 20 + tests/test_prompt_utils.py | 17 + tests/test_rowlimit.py | 79 + tests/test_smart_completion_multiple_schemata.py | 757 +++++++++ tests/test_smart_completion_public_schema_only.py | 1112 +++++++++++++ tests/test_sqlcompletion.py | 964 ++++++++++++ tests/test_ssh_tunnel.py | 188 +++ tests/utils.py | 92 ++ tox.ini | 14 + 116 files changed, 17559 insertions(+) create mode 100644 .coveragerc create mode 100644 .editorconfig create mode 100644 .git-blame-ignore-revs create mode 100644 .github/ISSUE_TEMPLATE.md create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/codeql.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 AUTHORS create mode 100644 DEVELOP.rst create mode 100644 Dockerfile create mode 100644 LICENSE.txt create mode 100644 MANIFEST.in create mode 100644 README.rst create mode 100644 RELEASES.md create mode 100644 TODO create mode 100644 Vagrantfile create mode 100644 changelog.rst create mode 100644 pgcli-completion.bash create mode 100644 pgcli/__init__.py create mode 100644 pgcli/__main__.py create mode 100644 pgcli/auth.py create mode 100644 pgcli/completion_refresher.py create mode 100644 pgcli/config.py create mode 100644 pgcli/explain_output_formatter.py create mode 100644 pgcli/key_bindings.py create mode 100644 pgcli/magic.py create mode 100644 pgcli/main.py create mode 100644 pgcli/packages/__init__.py create mode 100644 pgcli/packages/formatter/__init__.py create mode 100644 pgcli/packages/formatter/sqlformatter.py create mode 100644 pgcli/packages/parseutils/__init__.py create mode 100644 pgcli/packages/parseutils/ctes.py create mode 100644 pgcli/packages/parseutils/meta.py create mode 100644 pgcli/packages/parseutils/tables.py create mode 100644 pgcli/packages/parseutils/utils.py create mode 100644 pgcli/packages/pgliterals/__init__.py create mode 100644 pgcli/packages/pgliterals/main.py create mode 100644 pgcli/packages/pgliterals/pgliterals.json create mode 100644 pgcli/packages/prioritization.py create mode 100644 pgcli/packages/prompt_utils.py create mode 100644 pgcli/packages/sqlcompletion.py create mode 100644 pgcli/pgbuffer.py create mode 100644 pgcli/pgclirc create mode 100644 pgcli/pgcompleter.py create mode 100644 pgcli/pgexecute.py create mode 100644 pgcli/pgstyle.py create mode 100644 pgcli/pgtoolbar.py create mode 100644 pgcli/pyev.py create mode 100644 post-install create mode 100644 post-remove create mode 100644 pylintrc create mode 100644 pyproject.toml create mode 100644 release.py create mode 100644 requirements-dev.txt create mode 100644 sanity_checks.txt create mode 100644 screenshots/image01.png create mode 100644 screenshots/image02.png create mode 100644 screenshots/kharkiv-destroyed.jpg create mode 100644 screenshots/pgcli.gif create mode 100644 setup.py create mode 100644 tests/conftest.py create mode 100644 tests/features/__init__.py create mode 100644 tests/features/auto_vertical.feature create mode 100644 tests/features/basic_commands.feature create mode 100644 tests/features/crud_database.feature create mode 100644 tests/features/crud_table.feature create mode 100644 tests/features/db_utils.py create mode 100644 tests/features/environment.py create mode 100644 tests/features/expanded.feature create mode 100644 tests/features/fixture_data/help.txt create mode 100644 tests/features/fixture_data/help_commands.txt create mode 100644 tests/features/fixture_data/mock_pg_service.conf create mode 100644 tests/features/fixture_utils.py create mode 100644 tests/features/iocommands.feature create mode 100644 tests/features/named_queries.feature create mode 100644 tests/features/pgbouncer.feature create mode 100644 tests/features/specials.feature create mode 100644 tests/features/steps/__init__.py create mode 100644 tests/features/steps/auto_vertical.py create mode 100644 tests/features/steps/basic_commands.py create mode 100644 tests/features/steps/crud_database.py create mode 100644 tests/features/steps/crud_table.py create mode 100644 tests/features/steps/expanded.py create mode 100644 tests/features/steps/iocommands.py create mode 100644 tests/features/steps/named_queries.py create mode 100644 tests/features/steps/pgbouncer.py create mode 100644 tests/features/steps/specials.py create mode 100644 tests/features/steps/wrappers.py create mode 100644 tests/features/wrappager.py create mode 100644 tests/formatter/__init__.py create mode 100644 tests/formatter/test_sqlformatter.py create mode 100644 tests/metadata.py create mode 100644 tests/parseutils/test_ctes.py create mode 100644 tests/parseutils/test_function_metadata.py create mode 100644 tests/parseutils/test_parseutils.py create mode 100644 tests/pytest.ini create mode 100644 tests/test_auth.py create mode 100644 tests/test_completion_refresher.py create mode 100644 tests/test_config.py create mode 100644 tests/test_fuzzy_completion.py create mode 100644 tests/test_main.py create mode 100644 tests/test_naive_completion.py create mode 100644 tests/test_pgcompleter.py create mode 100644 tests/test_pgexecute.py create mode 100644 tests/test_pgspecial.py create mode 100644 tests/test_prioritization.py create mode 100644 tests/test_prompt_utils.py create mode 100644 tests/test_rowlimit.py create mode 100644 tests/test_smart_completion_multiple_schemata.py create mode 100644 tests/test_smart_completion_public_schema_only.py create mode 100644 tests/test_sqlcompletion.py create mode 100644 tests/test_ssh_tunnel.py create mode 100644 tests/utils.py create mode 100644 tox.ini diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b2713c7 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +parallel=True +source=pgcli diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..bacb65c --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +# editorconfig.org +# Get your text editor plugin at: +# http://editorconfig.org/#download +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = true +trim_trailing_whitespace = true + +[travis.yml] +indent_size = 2 diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..e69de29 diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..b5cdbec --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,9 @@ +## Description + + +## Your environment + + +- [ ] Please provide your OS and version information. +- [ ] Please provide your CLI version. +- [ ] What is the output of ``pip freeze`` command. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..35e8486 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,12 @@ +## Description + + + + +## Checklist + +- [ ] I've added this contribution to the `changelog.rst`. +- [ ] I've added my name to the `AUTHORS` file (or it's already there). + +- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`), and ran `black` on my code. +- [x] Please squash merge this pull request (uncheck if you'd like us to merge as multiple commits) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..68a69ac --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,102 @@ +name: pgcli + +on: + push: + branches: + - main + pull_request: + paths-ignore: + - '**.rst' + +jobs: + build: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + + services: + postgres: + image: postgres:9.6 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install pgbouncer + run: | + sudo apt install pgbouncer -y + + sudo chmod 666 /etc/pgbouncer/*.* + + cat < /etc/pgbouncer/userlist.txt + "postgres" "postgres" + EOF + + cat < /etc/pgbouncer/pgbouncer.ini + [databases] + * = host=localhost port=5432 + [pgbouncer] + listen_port = 6432 + listen_addr = localhost + auth_type = trust + auth_file = /etc/pgbouncer/userlist.txt + logfile = pgbouncer.log + pidfile = pgbouncer.pid + admin_users = postgres + EOF + + sudo systemctl stop pgbouncer + + pgbouncer -d /etc/pgbouncer/pgbouncer.ini + + psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help' + + - name: Install beta version of pendulum + run: pip install pendulum==3.0.0b1 + if: matrix.python-version == '3.12' + + - name: Install requirements + run: | + pip install -U pip setuptools + pip install --no-cache-dir ".[sshtunnel]" + pip install -r requirements-dev.txt + pip install keyrings.alt>=3.1 + + - name: Run unit tests + run: coverage run --source pgcli -m pytest + + - name: Run integration tests + env: + PGUSER: postgres + PGPASSWORD: postgres + + run: behave tests/features --no-capture + + - name: Check changelog for ReST compliance + run: rst2html.py --halt=warning changelog.rst >/dev/null + + - name: Run Black + run: black --check . + if: matrix.python-version == '3.8' + + - name: Coverage + run: | + coverage combine + coverage report + codecov diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..c9232c7 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,41 @@ +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + schedule: + - cron: "29 13 * * 1" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ python ] + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + queries: +security-and-quality + + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{ matrix.language }}" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b993cb9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,74 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +pyvenv/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +.pytest_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# PyCharm +.idea/ +*.iml + +# Vagrant +.vagrant/ + +# Generated Packages +*.deb +*.rpm + +.vscode/ +venv/ + +.ropeproject/ + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8462cc2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,5 @@ +repos: +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000..5eff7db --- /dev/null +++ b/AUTHORS @@ -0,0 +1,136 @@ +Many thanks to the following contributors. + +Project Lead: +------------- + * Irina Truong + +Core Devs: +---------- + * Amjith Ramanujam + * Darik Gamble + * Stuart Quin + * Joakim Koljonen + * Daniel Rocco + * Karl-Aksel Puulmann + * Dick Marinus + +Contributors: +------------- + * Brett + * Étienne BERSAC (bersace) + * Daniel Schwarz + * inkn + * Jonathan Slenders + * xalley + * TamasNo1 + * François Pietka + * Michael Kaminsky + * Alexander Kukushkin + * Ludovic Gasc (GMLudo) + * Marc Abramowitz + * Nick Hahner + * Jay Zeng + * Dimitar Roustchev + * Dhaivat Pandit + * Matheus Rosa + * Ali Kargın + * Nathan Jhaveri + * David Celis + * Sven-Hendrik Haase + * Çağatay Yüksel + * Tiago Ribeiro + * Vignesh Anand + * Charlie Arnold + * dwalmsley + * Artur Dryomov + * rrampage + * while0pass + * Eric Workman + * xa + * Hans Roman + * Guewen Baconnier + * Dionysis Grigoropoulos + * Jacob Magnusson + * Johannes Hoff + * vinotheassassin + * Jacek Wielemborek + * Fabien Meghazi + * Manuel Barkhau + * Sergii V + * Emanuele Gaifas + * Owen Stephens + * Russell Davies + * AlexTes + * Hraban Luyat + * Jackson Popkin + * Gustavo Castro + * Alexander Schmolck + * Donnell Muse + * Andrew Speed + * Dmitry B + * Isank + * Marcin Sztolcman + * Bojan Delić + * Chris Vaughn + * Frederic Aoustin + * Pierre Giraud + * Andrew Kuchling + * Dan Clark + * Catherine Devlin + * Jason Ribeiro + * Rishi Ramraj + * Matthieu Guilbert + * Alexandr Korsak + * Saif Hakim + * Artur Balabanov + * Kenny Do + * Max Rothman + * Daniel Egger + * Ignacio Campabadal + * Mikhail Elovskikh (wronglink) + * Marcin Cieślak (saper) + * easteregg (verfriemelt-dot-org) + * Scott Brenstuhl (808sAndBR) + * Nathan Verzemnieks + * raylu + * Zhaolong Zhu + * Zane C. Bowers-Hadley + * Telmo "Trooper" (telmotrooper) + * Alexander Zawadzki + * Pablo A. Bianchi (pabloab) + * Sebastian Janko (sebojanko) + * Pedro Ferrari (petobens) + * Martin Matejek (mmtj) + * Jonas Jelten + * BrownShibaDog + * George Thomas(thegeorgeous) + * Yoni Nakache(lazydba247) + * Gantsev Denis + * Stephano Paraskeva + * Panos Mavrogiorgos (pmav99) + * Igor Kim (igorkim) + * Anthony DeBarros (anthonydb) + * Seungyong Kwak (GUIEEN) + * Tom Caruso (tomplex) + * Jan Brun Rasmussen (janbrunrasmussen) + * Kevin Marsh (kevinmarsh) + * Eero Ruohola (ruohola) + * Miroslav Šedivý (eumiro) + * Eric R Young (ERYoung11) + * Paweł Sacawa (psacawa) + * Bruno Inec (sweenu) + * Daniele Varrazzo + * Daniel Kukula (dkuku) + * Kian-Meng Ang (kianmeng) + * Liu Zhao (astroshot) + * Rigo Neri (rigoneri) + * Anna Glasgall (annathyst) + * Andy Schoenberger (andyscho) + * Damien Baty (dbaty) + * blag + * Rob Berry (rob-b) + * Sharon Yogev (sharonyogev) + +Creator: +-------- +Amjith Ramanujam diff --git a/DEVELOP.rst b/DEVELOP.rst new file mode 100644 index 0000000..aed2cf8 --- /dev/null +++ b/DEVELOP.rst @@ -0,0 +1,220 @@ +Development Guide +----------------- +This is a guide for developers who would like to contribute to this project. + +GitHub Workflow +--------------- + +If you're interested in contributing to pgcli, first of all my heart felt +thanks. `Fork the project `_ on github. Then +clone your fork into your computer (``git clone ``). Make +the changes and create the commits in your local machine. Then push those +changes to your fork. Then click on the pull request icon on github and create +a new pull request. Add a description about the change and send it along. I +promise to review the pull request in a reasonable window of time and get back +to you. + +In order to keep your fork up to date with any changes from mainline, add a new +git remote to your local copy called 'upstream' and point it to the main pgcli +repo. + +:: + + $ git remote add upstream git@github.com:dbcli/pgcli.git + +Once the 'upstream' end point is added you can then periodically do a ``git +pull upstream master`` to update your local copy and then do a ``git push +origin master`` to keep your own fork up to date. + +Check Github's `Understanding the GitHub flow guide +`_ for a more detailed +explanation of this process. + +Local Setup +----------- + +The installation instructions in the README file are intended for users of +pgcli. If you're developing pgcli, you'll need to install it in a slightly +different way so you can see the effects of your changes right away without +having to go through the install cycle every time you change the code. + +It is highly recommended to use virtualenv for development. If you don't know +what a virtualenv is, `this guide `_ +will help you get started. + +Create a virtualenv (let's call it pgcli-dev). Activate it: + +:: + + source ./pgcli-dev/bin/activate + + or + + .\pgcli-dev\scripts\activate (for Windows) + +Once the virtualenv is activated, `cd` into the local clone of pgcli folder +and install pgcli using pip as follows: + +:: + + $ pip install --editable . + + or + + $ pip install -e . + +This will install the necessary dependencies as well as install pgcli from the +working folder into the virtualenv. By installing it using `pip install -e` +we've linked the pgcli installation with the working copy. Any changes made +to the code are immediately available in the installed version of pgcli. This +makes it easy to change something in the code, launch pgcli and check the +effects of your changes. + +Adding PostgreSQL Special (Meta) Commands +----------------------------------------- + +If you want to work on adding new meta-commands (such as `\dp`, `\ds`, `dy`), +you need to contribute to `pgspecial `_ +project. + +Visual Studio Code Debugging +----------------------------- +To set up Visual Studio Code to debug pgcli requires a launch.json file. + +Within the project, create a file: .vscode\\launch.json like below. + +:: + + { + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Module", + "type": "python", + "request": "launch", + "module": "pgcli.main", + "justMyCode": false, + "console": "externalTerminal", + "env": { + "PGUSER": "postgres", + "PGPASS": "password", + "PGHOST": "localhost", + "PGPORT": "5432" + } + } + ] + } + +Building RPM and DEB packages +----------------------------- + +You will need Vagrant 1.7.2 or higher. In the project root there is a +Vagrantfile that is setup to do multi-vm provisioning. If you're setting things +up for the first time, then do: + +:: + + $ version=x.y.z vagrant up debian + $ version=x.y.z vagrant up centos + +If you already have those VMs setup and you're merely creating a new version of +DEB or RPM package, then you can do: + +:: + + $ version=x.y.z vagrant provision + +That will create a .deb file and a .rpm file. + +The deb package can be installed as follows: + +:: + + $ sudo dpkg -i pgcli*.deb # if dependencies are available. + + or + + $ sudo apt-get install -f pgcli*.deb # if dependencies are not available. + + +The rpm package can be installed as follows: + +:: + + $ sudo yum install pgcli*.rpm + +Running the integration tests +----------------------------- + +Integration tests use `behave package `_ and +pytest. +Configuration settings for this package are provided via a ``behave.ini`` file +in the ``tests`` directory. An example:: + + [behave] + stderr_capture = false + + [behave.userdata] + pg_test_user = dbuser + pg_test_host = db.example.com + pg_test_port = 30000 + +First, install the requirements for testing: + +:: + $ pip install -U pip setuptools + $ pip install --no-cache-dir ".[sshtunnel]" + $ pip install -r requirements-dev.txt + +Ensure that the database user has permissions to create and drop test databases +by checking your ``pg_hba.conf`` file. The default user should be ``postgres`` +at ``localhost``. Make sure the authentication method is set to ``trust``. If +you made any changes to your ``pg_hba.conf`` make sure to restart the postgres +service for the changes to take effect. + +:: + + # ONLY IF YOU MADE CHANGES TO YOUR pg_hba.conf FILE + $ sudo service postgresql restart + +After that, tests in the ``/pgcli/tests`` directory can be run with: +(Note that these ``behave`` tests do not currently work when developing on Windows due to pexpect incompatibility.) + +:: + + # on directory /pgcli/tests + $ behave + +And on the ``/pgcli`` directory: + +:: + + # on directory /pgcli + $ py.test + +To see stdout/stderr, use the following command: + +:: + + $ behave --no-capture + +Troubleshooting the integration tests +------------------------------------- + +- Make sure postgres instance on localhost is running +- Check your ``pg_hba.conf`` file to verify local connections are enabled +- Check `this issue `_ for relevant information. +- `File an issue `_. + +Coding Style +------------ + +``pgcli`` uses `black `_ to format the source code. Make sure to install black. + +Releases +-------- + +If you're the person responsible for releasing `pgcli`, `this guide `_ is for you. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..32d341a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,6 @@ +FROM python:3.8 + +COPY . /app +RUN cd /app && pip install -e . + +CMD pgcli diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..83226b7 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,26 @@ +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +* Neither the name of the {organization} nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..1c5f697 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include LICENSE.txt AUTHORS changelog.rst +recursive-include tests *.py *.txt *.feature *.ini diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..56695cc --- /dev/null +++ b/README.rst @@ -0,0 +1,392 @@ +We stand with Ukraine +--------------------- + +Ukrainian people are fighting for their country. A lot of civilians, women and children, are suffering. Hundreds were killed and injured, and thousands were displaced. + +This is an image from my home town, Kharkiv. This place is right in the old city center. + +.. image:: screenshots/kharkiv-destroyed.jpg + +Picture by @fomenko_ph (Telegram). + +Please consider donating or volunteering. + +* https://bank.gov.ua/en/ +* https://savelife.in.ua/en/donate/ +* https://www.comebackalive.in.ua/donate +* https://www.globalgiving.org/projects/ukraine-crisis-relief-fund/ +* https://www.savethechildren.org/us/where-we-work/ukraine +* https://www.facebook.com/donate/1137971146948461/ +* https://donate.wck.org/give/393234#!/donation/checkout +* https://atlantaforukraine.com/ + + +A REPL for Postgres +------------------- + +|Build Status| |CodeCov| |PyPI| |netlify| + +This is a postgres client that does auto-completion and syntax highlighting. + +Home Page: http://pgcli.com + +MySQL Equivalent: http://mycli.net + +.. image:: screenshots/pgcli.gif +.. image:: screenshots/image01.png + +Quick Start +----------- + +If you already know how to install python packages, then you can simply do: + +:: + + $ pip install -U pgcli + + or + + $ sudo apt-get install pgcli # Only on Debian based Linux (e.g. Ubuntu, Mint, etc) + $ brew install pgcli # Only on macOS + +If you don't know how to install python packages, please check the +`detailed instructions`_. + +.. _`detailed instructions`: https://github.com/dbcli/pgcli#detailed-installation-instructions + +Usage +----- + +:: + + $ pgcli [database_name] + + or + + $ pgcli postgresql://[user[:password]@][netloc][:port][/dbname][?extra=value[&other=other-value]] + +Examples: + +:: + + $ pgcli local_database + + $ pgcli postgres://amjith:pa$$w0rd@example.com:5432/app_db?sslmode=verify-ca&sslrootcert=/myrootcert + +For more details: + +:: + + $ pgcli --help + + Usage: pgcli [OPTIONS] [DBNAME] [USERNAME] + + Options: + -h, --host TEXT Host address of the postgres database. + -p, --port INTEGER Port number at which the postgres instance is + listening. + -U, --username TEXT Username to connect to the postgres database. + -u, --user TEXT Username to connect to the postgres database. + -W, --password Force password prompt. + -w, --no-password Never prompt for password. + --single-connection Do not use a separate connection for completions. + -v, --version Version of pgcli. + -d, --dbname TEXT database name to connect to. + --pgclirc FILE Location of pgclirc file. + -D, --dsn TEXT Use DSN configured into the [alias_dsn] section + of pgclirc file. + --list-dsn list of DSN configured into the [alias_dsn] + section of pgclirc file. + --row-limit INTEGER Set threshold for row limit prompt. Use 0 to + disable prompt. + --less-chatty Skip intro on startup and goodbye on exit. + --prompt TEXT Prompt format (Default: "\u@\h:\d> "). + --prompt-dsn TEXT Prompt format for connections using DSN aliases + (Default: "\u@\h:\d> "). + -l, --list list available databases, then exit. + --auto-vertical-output Automatically switch to vertical output mode if + the result is wider than the terminal width. + --warn [all|moderate|off] Warn before running a destructive query. + --help Show this message and exit. + +``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``). + +The SSL-related environment variables are also supported, so if you need to connect a postgres database via ssl connection, you can set set environment like this: + +:: + + export PGSSLMODE="verify-full" + export PGSSLCERT="/your-path-to-certs/client.crt" + export PGSSLKEY="/your-path-to-keys/client.key" + export PGSSLROOTCERT="/your-path-to-ca/ca.crt" + pgcli -h localhost -p 5432 -U username postgres + +.. _environment variables: https://www.postgresql.org/docs/current/libpq-envars.html + +Features +-------- + +The `pgcli` is written using prompt_toolkit_. + +* Auto-completes as you type for SQL keywords as well as tables and + columns in the database. +* Syntax highlighting using Pygments. +* Smart-completion (enabled by default) will suggest context-sensitive + completion. + + - ``SELECT * FROM `` will only show table names. + - ``SELECT * FROM users WHERE `` will only show column names. + +* Primitive support for ``psql`` back-slash commands. +* Pretty prints tabular data. + +.. _prompt_toolkit: https://github.com/jonathanslenders/python-prompt-toolkit +.. _tabulate: https://pypi.python.org/pypi/tabulate + +Config +------ +A config file is automatically created at ``~/.config/pgcli/config`` at first launch. +See the file itself for a description of all available options. + +Contributions: +-------------- + +If you're interested in contributing to this project, first of all I would like +to extend my heartfelt gratitude. I've written a small doc to describe how to +get this running in a development setup. + +https://github.com/dbcli/pgcli/blob/master/DEVELOP.rst + +Please feel free to reach out to us if you need help. +* Amjith, pgcli author: amjith.r@gmail.com, Twitter: `@amjithr `_ +* Irina, pgcli maintainer: i.chernyavska@gmail.com, Twitter: `@irinatruong `_ + +Detailed Installation Instructions: +----------------------------------- + +macOS: +====== + +The easiest way to install pgcli is using Homebrew. + +:: + + $ brew install pgcli + +Done! + +Alternatively, you can install ``pgcli`` as a python package using a package +manager called called ``pip``. You will need postgres installed on your system +for this to work. + +In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html. + +:: + + $ which pip + +If it is installed then you can do: + +:: + + $ pip install pgcli + +If that fails due to permission issues, you might need to run the command with +sudo permissions. + +:: + + $ sudo pip install pgcli + +If pip is not installed check if easy_install is available on the system. + +:: + + $ which easy_install + + $ sudo easy_install pgcli + +Linux: +====== + +In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html. + +Check if pip is already available in your system. + +:: + + $ which pip + +If it doesn't exist, use your linux package manager to install `pip`. This +might look something like: + +:: + + $ sudo apt-get install python-pip # Debian, Ubuntu, Mint etc + + or + + $ sudo yum install python-pip # RHEL, Centos, Fedora etc + +``pgcli`` requires python-dev, libpq-dev and libevent-dev packages. You can +install these via your operating system package manager. + + +:: + + $ sudo apt-get install python-dev libpq-dev libevent-dev + + or + + $ sudo yum install python-devel postgresql-devel + +Then you can install pgcli: + +:: + + $ sudo pip install pgcli + + +Docker +====== + +Pgcli can be run from within Docker. This can be useful to try pgcli without +installing it, or any dependencies, system-wide. + +To build the image: + +:: + + $ docker build -t pgcli . + +To create a container from the image: + +:: + + $ docker run --rm -ti pgcli pgcli + +To access postgresql databases listening on localhost, make sure to run the +docker in "host net mode". E.g. to access a database called "foo" on the +postgresql server running on localhost:5432 (the standard port): + +:: + + $ docker run --rm -ti --net host pgcli pgcli -h localhost foo + +To connect to a locally running instance over a unix socket, bind the socket to +the docker container: + +:: + + $ docker run --rm -ti -v /var/run/postgres:/var/run/postgres pgcli pgcli foo + + +IPython +======= + +Pgcli can be run from within `IPython `_ console. When working on a query, +it may be useful to drop into a pgcli session without leaving the IPython console, iterate on a +query, then quit pgcli to find the query results in your IPython workspace. + +Assuming you have IPython installed: + +:: + + $ pip install ipython-sql + +After that, run ipython and load the ``pgcli.magic`` extension: + +:: + + $ ipython + + In [1]: %load_ext pgcli.magic + + +Connect to a database and construct a query: + +:: + + In [2]: %pgcli postgres://someone@localhost:5432/world + Connected: someone@world + someone@localhost:world> select * from city c where countrycode = 'USA' and population > 1000000; + +------+--------------+---------------+--------------+--------------+ + | id | name | countrycode | district | population | + |------+--------------+---------------+--------------+--------------| + | 3793 | New York | USA | New York | 8008278 | + | 3794 | Los Angeles | USA | California | 3694820 | + | 3795 | Chicago | USA | Illinois | 2896016 | + | 3796 | Houston | USA | Texas | 1953631 | + | 3797 | Philadelphia | USA | Pennsylvania | 1517550 | + | 3798 | Phoenix | USA | Arizona | 1321045 | + | 3799 | San Diego | USA | California | 1223400 | + | 3800 | Dallas | USA | Texas | 1188580 | + | 3801 | San Antonio | USA | Texas | 1144646 | + +------+--------------+---------------+--------------+--------------+ + SELECT 9 + Time: 0.003s + + +Exit out of pgcli session with ``Ctrl + D`` and find the query results: + +:: + + someone@localhost:world> + Goodbye! + 9 rows affected. + Out[2]: + [(3793, u'New York', u'USA', u'New York', 8008278), + (3794, u'Los Angeles', u'USA', u'California', 3694820), + (3795, u'Chicago', u'USA', u'Illinois', 2896016), + (3796, u'Houston', u'USA', u'Texas', 1953631), + (3797, u'Philadelphia', u'USA', u'Pennsylvania', 1517550), + (3798, u'Phoenix', u'USA', u'Arizona', 1321045), + (3799, u'San Diego', u'USA', u'California', 1223400), + (3800, u'Dallas', u'USA', u'Texas', 1188580), + (3801, u'San Antonio', u'USA', u'Texas', 1144646)] + +The results are available in special local variable ``_``, and can be assigned to a variable of your +choice: + +:: + + In [3]: my_result = _ + +Pgcli dropped support for Python<3.8 as of 4.0.0. If you need it, install ``pgcli <= 4.0.0``. + +Thanks: +------- + +A special thanks to `Jonathan Slenders `_ for +creating `Python Prompt Toolkit `_, +which is quite literally the backbone library, that made this app possible. +Jonathan has also provided valuable feedback and support during the development +of this app. + +`Click `_ is used for command line option parsing +and printing error messages. + +Thanks to `psycopg `_ for providing a rock solid +interface to Postgres database. + +Thanks to all the beta testers and contributors for your time and patience. :) + + +.. |Build Status| image:: https://github.com/dbcli/pgcli/actions/workflows/ci.yml/badge.svg?branch=main + :target: https://github.com/dbcli/pgcli/actions/workflows/ci.yml + +.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg + :target: https://codecov.io/gh/dbcli/pgcli + :alt: Code coverage report + +.. |Landscape| image:: https://landscape.io/github/dbcli/pgcli/master/landscape.svg?style=flat + :target: https://landscape.io/github/dbcli/pgcli/master + :alt: Code Health + +.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg + :target: https://pypi.python.org/pypi/pgcli/ + :alt: Latest Version + +.. |netlify| image:: https://api.netlify.com/api/v1/badges/3a0a14dd-776d-445d-804c-3dd74fe31c4e/deploy-status + :target: https://app.netlify.com/sites/pgcli/deploys + :alt: Netlify diff --git a/RELEASES.md b/RELEASES.md new file mode 100644 index 0000000..526c260 --- /dev/null +++ b/RELEASES.md @@ -0,0 +1,24 @@ +Releasing pgcli +--------------- + +You have been made the maintainer of `pgcli`? Congratulations! We have a release script to help you: + +```sh +> python release.py --help +Usage: release.py [options] + +Options: + -h, --help show this help message and exit + -c, --confirm-steps Confirm every step. If the step is not confirmed, it + will be skipped. + -d, --dry-run Print out, but not actually run any steps. +``` + +The script can be run with `-c` to confirm or skip steps. There's also a `--dry-run` option that only prints out the steps. + +To release a new version of the package: + +* Create and merge a PR to bump the version in the changelog ([example PR](https://github.com/dbcli/pgcli/pull/1325)). +* Pull `main` and bump the version number inside `pgcli/__init__.py`. Do not check in - the release script will do that. +* Make sure you have the dev requirements installed: `pip install -r requirements-dev.txt -U --upgrade-strategy only-if-needed`. +* Finally, run the release script: `python release.py`. diff --git a/TODO b/TODO new file mode 100644 index 0000000..2173545 --- /dev/null +++ b/TODO @@ -0,0 +1,12 @@ +# vi: ft=vimwiki +* [ ] Add coverage. +* [ ] Refactor to sqlcompletion to consume the text from left to right and use a state machine to suggest cols or tables instead of relying on hacks. +* [ ] Add a few more special commands. (\l pattern, \dp, \ds, \dy, \z etc) +* [ ] Refactor pgspecial.py to a class. +* [ ] Show/hide docs for a statement using a keybinding. +* [ ] Check how to add the name of the table before printing the table. +* [ ] Add a new trigger for M-/ that does naive completion. +* [ ] New Feature List - Write the current version to config file. At launch if the version has changed, display the changelog between the two versions. +* [ ] Add a test for 'select * from custom.abc where custom.abc.' should suggest columns from abc. +* [ ] pgexecute columns(), tables() etc can be just cursors instead of fetchall() +* [ ] Add colorschemes in config file. diff --git a/Vagrantfile b/Vagrantfile new file mode 100644 index 0000000..297e70a --- /dev/null +++ b/Vagrantfile @@ -0,0 +1,138 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : +# +# + +Vagrant.configure(2) do |config| + + config.vm.synced_folder ".", "/pgcli" + + pgcli_version = ENV['version'] + pgcli_description = "Postgres CLI with autocompletion and syntax highlighting" + + config.vm.define "debian" do |debian| + debian.vm.box = "bento/debian-10.8" + debian.vm.provision "shell", inline: <<-SHELL + echo "-> Building DEB on `lsb_release -d`" + sudo apt-get update + sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems + sudo apt install -y python3-pip + sudo pip3 install --no-cache-dir virtualenv virtualenv-tools3 + sudo apt-get install -y ruby-dev + sudo apt-get install -y git + sudo apt-get install -y rpm librpmbuild8 + + sudo gem install fpm + + echo "-> Cleaning up old workspace" + sudo rm -rf build + mkdir -p build/usr/share + virtualenv build/usr/share/pgcli + build/usr/share/pgcli/bin/pip install /pgcli + + echo "-> Cleaning Virtualenv" + cd build/usr/share/pgcli + virtualenv-tools --update-path /usr/share/pgcli > /dev/null + cd /home/vagrant/ + + echo "-> Removing compiled files" + find build -iname '*.pyc' -delete + find build -iname '*.pyo' -delete + + echo "-> Creating PgCLI deb" + sudo fpm -t deb -s dir -C build -n pgcli -v #{pgcli_version} \ + -a all \ + -d libpq-dev \ + -d python-dev \ + -p /pgcli/ \ + --after-install /pgcli/post-install \ + --after-remove /pgcli/post-remove \ + --url https://github.com/dbcli/pgcli \ + --description "#{pgcli_description}" \ + --license 'BSD' + + SHELL + end + + +# This is considerably more messy than the debian section. I had to go off-standard to update +# some packages to get this to work. + + config.vm.define "centos" do |centos| + + centos.vm.box = "bento/centos-7.9" + centos.vm.box_version = "202012.21.0" + centos.vm.provision "shell", inline: <<-SHELL + #!/bin/bash + echo "-> Building RPM on `hostnamectl | grep "Operating System"`" + export PATH=/usr/local/rvm/gems/ruby-2.6.3/bin:/usr/local/rvm/gems/ruby-2.6.3@global/bin:/usr/local/rvm/rubies/ruby-2.6.3/bin:/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/local/rvm/bin:/root/bin + echo "PATH -> " $PATH + +##### +### get base updates + + sudo yum install -y rpm-build gcc postgresql-devel python-devel python3-pip git python3-devel + +###### +### install FPM, which we need to install to get an up-to-date version of ruby, which we need for git + + echo "-> Get FPM installed" + # import the necessary GPG keys + gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB + sudo gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB + # install RVM + sudo curl -sSL https://get.rvm.io | sudo bash -s stable + sudo usermod -aG rvm vagrant + sudo usermod -aG rvm root + sudo /usr/local/rvm/bin/rvm alias create default 2.6.3 + source /etc/profile.d/rvm.sh + + # install a newer version of ruby. centos7 only comes with ruby2.0.0, which isn't good enough for git. + sudo yum install -y ruby-devel + sudo /usr/local/rvm/bin/rvm install 2.6.3 + + # + # yes,this gives an error about generating doc but we don't need the doc. + + /usr/local/rvm/gems/ruby-2.6.3/wrappers/gem install fpm + +###### + + sudo pip3 install virtualenv virtualenv-tools3 + echo "-> Cleaning up old workspace" + rm -rf build + mkdir -p build/usr/share + virtualenv build/usr/share/pgcli + build/usr/share/pgcli/bin/pip install /pgcli + + echo "-> Cleaning Virtualenv" + cd build/usr/share/pgcli + virtualenv-tools --update-path /usr/share/pgcli > /dev/null + cd /home/vagrant/ + + echo "-> Removing compiled files" + find build -iname '*.pyc' -delete + find build -iname '*.pyo' -delete + + cd /home/vagrant + echo "-> Creating PgCLI RPM" + /usr/local/rvm/gems/ruby-2.6.3/gems/fpm-1.12.0/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \ + -a all \ + -d postgresql-devel \ + -d python-devel \ + -p /pgcli/ \ + --after-install /pgcli/post-install \ + --after-remove /pgcli/post-remove \ + --url https://github.com/dbcli/pgcli \ + --description "#{pgcli_description}" \ + --license 'BSD' + + + SHELL + + + end + + +end + diff --git a/changelog.rst b/changelog.rst new file mode 100644 index 0000000..7d08839 --- /dev/null +++ b/changelog.rst @@ -0,0 +1,1203 @@ +================== +4.0.1 (2023-11-30) +================== + +Internal: +--------- +* Allow stable version of pendulum. + +================== +4.0.0 (2023-11-27) +================== + +Features: +--------- + +* Ask for confirmation when quitting cli while a transaction is ongoing. +* New `destructive_statements_require_transaction` config option to refuse to execute a + destructive SQL statement if outside a transaction. This option is off by default. +* Changed the `destructive_warning` config to be a list of commands that are considered + destructive. This would allow you to be warned on `create`, `grant`, or `insert` queries. +* Destructive warnings will now include the alias dsn connection string name if provided (-D option). +* pgcli.magic will now work with connection URLs that use TLS client certificates for authentication +* Have config option to retry queries on operational errors like connections being lost. + Also prevents getting stuck in a retry loop. +* Config option to not restart connection when cancelling a `destructive_warning` query. By default, + it will now not restart. +* Config option to always run with a single connection. +* Add comment explaining default LESS environment variable behavior and change example pager setting. +* Added `\echo` & `\qecho` special commands. ([issue 1335](https://github.com/dbcli/pgcli/issues/1335)). + +Bug fixes: +---------- + +* Fix `\ev` not producing a correctly quoted "schema"."view" +* Fix 'invalid connection option "dsn"' ([issue 1373](https://github.com/dbcli/pgcli/issues/1373)). +* Fix explain mode when used with `expand`, `auto_expand`, or `--explain-vertical-output` ([issue 1393](https://github.com/dbcli/pgcli/issues/1393)). +* Fix sql-insert format emits NULL as 'None' ([issue 1408](https://github.com/dbcli/pgcli/issues/1408)). +* Improve check for prompt-toolkit 3.0.6 ([issue 1416](https://github.com/dbcli/pgcli/issues/1416)). +* Allow specifying an `alias_map_file` in the config that will use + predetermined table aliases instead of generating aliases programmatically on + the fly +* Fixed SQL error when there is a comment on the first line: ([issue 1403](https://github.com/dbcli/pgcli/issues/1403)) +* Fix wrong usage of prompt instead of confirm when confirm execution of destructive query + +Internal: +--------- + +* Drop support for Python 3.7 and add 3.12. + +3.5.0 (2022/09/15): +=================== + +Features: +--------- + +* New formatter is added to export query result to sql format (such as sql-insert, sql-update) like mycli. + +Bug fixes: +---------- + +* Fix exception when retrieving password from keyring ([issue 1338](https://github.com/dbcli/pgcli/issues/1338)). +* Fix using comments with special commands ([issue 1362](https://github.com/dbcli/pgcli/issues/1362)). +* Small improvements to the Windows developer experience +* Fix submitting queries in safe multiline mode ([1360](https://github.com/dbcli/pgcli/issues/1360)). + +Internal: +--------- + +* Port to psycopg3 (https://github.com/psycopg/psycopg). +* Fix typos + +3.4.1 (2022/03/19) +================== + +Bug fixes: +---------- + +* Fix the bug with Redshift not displaying word count in status ([related issue](https://github.com/dbcli/pgcli/issues/1320)). +* Show the error status for CSV output format. + + +3.4.0 (2022/02/21) +================== + +Features: +--------- + +* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)). + +3.3.1 (2022/01/18) +================== + +Bug fixes: +---------- + +* Prompt for password when -W is provided even if there is a password in keychain. Fixes #1307. +* Upgrade cli_helpers to 2.2.1 + +3.3.0 (2022/01/11) +================== + +Features: +--------- + +* Add `max_field_width` setting to config, to enable more control over field truncation ([related issue](https://github.com/dbcli/pgcli/issues/1250)). +* Re-run last query via bare `\watch`. (Thanks: `Saif Hakim`_) + +Bug fixes: +---------- + +* Pin the version of pygments to prevent breaking change + +3.2.0 +===== + +Release date: 2021/08/23 + +Features: +--------- + +* Consider `update` queries destructive and issue a warning. Change + `destructive_warning` setting to `all|moderate|off`, vs `true|false`. (#1239) +* Skip initial comment in .pg_session even if it doesn't start with '#' +* Include functions from schemas in search_path. (`Amjith Ramanujam`_) +* Easy way to show explain output under F5 + +Bug fixes: +---------- + +* Fix issue where `syntax_style` config value would not have any effect. (#1212) +* Fix crash because of not found `InputMode.REPLACE_SINGLE` with prompt-toolkit < 3.0.6 +* Fix comments being lost in config when saving a named query. (#1240) +* Fix IPython magic for ipython-sql >= 0.4.0 +* Fix pager not being used when output format is set to csv. (#1238) +* Add function literals random, generate_series, generate_subscripts +* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly +* Fix pgcli crashing with virtual `pgbouncer` database. (#1093) + +3.1.0 +===== + +Features: +--------- + +* Make the output more compact by removing the empty newline. (Thanks: `laixintao`_) +* Add support for using [pspg](https://github.com/okbob/pspg) as a pager (#1102) +* Update python version in Dockerfile +* Support setting color for null, string, number, keyword value +* Support Prompt Toolkit 2 +* Support sqlparse 0.4.x +* Update functions, datatypes literals for auto-suggestion field +* Add suggestion for schema in function auto-complete + +Bug fixes: +---------- + +* Minor typo fixes in `pgclirc`. (Thanks: `anthonydb`_) +* Fix for list index out of range when executing commands from a file (#1193). (Thanks: `Irina Truong`_) +* Move from `humanize` to `pendulum` for displaying query durations (#1015) +* More explicit error message when connecting using DSN alias and it is not found. + +3.0.0 +===== + +Features: +--------- + +* Add `__main__.py` file to execute pgcli as a package directly (#1123). +* Add support for ANSI escape sequences for coloring the prompt (#1122). +* Add support for partitioned tables (relkind "p"). +* Add support for `pg_service.conf` files +* Add config option show_bottom_toolbar. + +Bug fixes: +---------- + +* Fix warning raised for using `is not` to compare string literal +* Close open connection in completion_refresher thread + +Internal: +--------- + +* Drop Python2.7, 3.4, 3.5 support. (Thanks: `laixintao`_) +* Support Python3.8. (Thanks: `laixintao`_) +* Fix dead link in development guide. (Thanks: `BrownShibaDog`_) +* Upgrade python-prompt-toolkit to v3.0. (Thanks: `laixintao`_) + + +2.2.0: +====== + +Features: +--------- + +* Add `\\G` as a terminator to sql statements that will show the results in expanded mode. This feature is copied from mycli. (Thanks: `Amjith Ramanujam`_) +* Removed limit prompt and added automatic row limit on queries with no LIMIT clause (#1079) (Thanks: `Sebastian Janko`_) +* Function argument completions now take account of table aliases (#1048). (Thanks: `Owen Stephens`_) + +Bug fixes: +---------- + +* Error connecting to PostgreSQL 12beta1 (#1058). (Thanks: `Irina Truong`_ and `Amjith Ramanujam`_) +* Empty query caused error message (#1019) (Thanks: `Sebastian Janko`_) +* History navigation bindings in multiline queries (#1004) (Thanks: `Pedro Ferrari`_) +* Can't connect to pgbouncer database (#1093). (Thanks: `Irina Truong`_) +* Fix broken multi-line history search (#1031). (Thanks: `Owen Stephens`_) +* Fix slow typing/movement when multi-line query ends in a semicolon (#994). (Thanks: `Owen Stephens`_) +* Fix for PQconninfo not available in libpq < 9.3 (#1110). (Thanks: `Irina Truong`_) + +Internal: +--------- + +* Add optional but default squash merge request to PULL_REQUEST_TEMPLATE + +2.1.1 +===== + +Bug fixes: +---------- +* Escape switches to VI navigation mode when not canceling completion popup. (Thanks: `Nathan Verzemnieks`_) +* Allow application_name to be overridden. (Thanks: `raylu`_) +* Fix for "no attribute KeyringLocked" (#1040). (Thanks: `Irina Truong`_) +* Pgcli no longer works with password containing spaces (#1043). (Thanks: `Irina Truong`_) +* Load keyring only when keyring is enabled in the config file (#1041). (Thanks: `Zhaolong Zhu`_) +* No longer depend on sqlparse as being less than 0.3.0 with the release of sqlparse 0.3.0. (Thanks: `VVelox`_) +* Fix the broken support for pgservice . (Thanks: `Xavier Francisco`_) +* Connecting using socket is broken in current master. (#1053). (Thanks: `Irina Truong`_) +* Allow usage of newer versions of psycopg2 (Thanks: `Telmo "Trooper"`_) +* Update README in alignment with the usage of newer versions of psycopg2 (Thanks: `Alexander Zawadzki`_) + +Internal: +--------- + +* Add python 3.7 to travis build matrix. (Thanks: `Irina Truong`_) +* Apply `black` to code. (Thanks: `Irina Truong`_) + +2.1.0 +===== + +Features: +--------- + +* Keybindings for closing the autocomplete list. (Thanks: `easteregg`_) +* Reconnect automatically when server closes connection. (Thanks: `Scott Brenstuhl`_) + +Bug fixes: +---------- +* Avoid error message on the server side if hstore extension is not installed in the current database (#991). (Thanks: `Marcin Cieślak`_) +* All pexpect submodules have been moved into the pexpect package as of version 3.0. Use pexpect.TIMEOUT (Thanks: `Marcin Cieślak`_) +* Resizing pgcli terminal kills the connection to postgres in python 2.7 (Thanks: `Amjith Ramanujam`_) +* Fix crash retrieving server version with ``--single-connection``. (Thanks: `Irina Truong`_) +* Cannot quit application without reconnecting to database (#1014). (Thanks: `Irina Truong`_) +* Password authentication failed for user "postgres" when using non-default password (#1020). (Thanks: `Irina Truong`_) + +Internal: +--------- + +* (Fixup) Clean up and add behave logging. (Thanks: `Marcin Cieślak`_, `Dick Marinus`_) +* Override VISUAL environment variable for behave tests. (Thanks: `Marcin Cieślak`_) +* Remove build dir before running sdist, remove stray files from wheel distribution. (Thanks: `Dick Marinus`_) +* Fix unit tests, unhashable formatted text since new python prompttoolkit version. (Thanks: `Dick Marinus`_) + +2.0.2: +====== + +Features: +--------- + +* Allows passing the ``-u`` flag to specify a username. (Thanks: `Ignacio Campabadal`_) +* Fix for lag in v2 (#979). (Thanks: `Irina Truong`_) +* Support for multihost connection string that is convenient if you have postgres cluster. (Thanks: `Mikhail Elovskikh`_) + +Internal: +--------- + +* Added tests for special command completion. (Thanks: `Amjith Ramanujam`_) + +2.0.1: +====== + +Bug fixes: +---------- + +* Tab press on an empty line increases the indentation instead of triggering + the auto-complete pop-up. (Thanks: `Artur Balabanov`_) +* Fix for loading/saving named queries from provided config file (#938). (Thanks: `Daniel Egger`_) +* Set default port in `connect_uri` when none is given. (Thanks: `Daniel Egger`_) +* Fix for error listing databases (#951). (Thanks: `Irina Truong`_) +* Enable Ctrl-Z to suspend the app (Thanks: `Amjith Ramanujam`_). +* Fix StopIteration exception raised at runtime for Python 3.7 (Thanks: `Amjith Ramanujam`_). + +Internal: +--------- + +* Clean up and add behave logging. (Thanks: `Dick Marinus`_) +* Require prompt_toolkit>=2.0.6. (Thanks: `Dick Marinus`_) +* Improve development guide. (Thanks: `Ignacio Campabadal`_) + +2.0.0: +====== + +* Update to ``prompt-toolkit`` 2.0. (Thanks: `Jonathan Slenders`_, `Dick Marinus`_, `Irina Truong`_) + +1.11.0 +====== + +Features: +--------- + +* Respect `\pset pager on` and use pager when output is longer than terminal height (Thanks: `Max Rothman`_) + +1.10.3 +====== + +Bug fixes: +---------- + +* Adapt the query used to get functions metadata to PG11 (#919). (Thanks: `Lele Gaifax`_). +* Fix for error retrieving version in Redshift (#922). (Thanks: `Irina Truong`_) +* Fix for keyring not disabled properly (#920). (Thanks: `Irina Truong`_) + +1.10.2 +====== + +Features: +--------- + +* Make `keyring` optional (Thanks: `Dick Marinus`_) + +1.10.1 +====== + +Bug fixes: +---------- + +* Fix for missing keyring. (Thanks: `Kenny Do`_) +* Fix for "-l" Flag Throws Error (#909). (Thanks: `Irina Truong`_) + +1.10.0 +====== + +Features: +--------- +* Add quit commands to the completion menu. (Thanks: `Jason Ribeiro`_) +* Add table formats to ``\T`` completion. (Thanks: `Jason Ribeiro`_) +* Support `\\ev``, ``\ef`` (#754). (Thanks: `Catherine Devlin`_) +* Add ``application_name`` to help identify pgcli connection to database (issue #868) (Thanks: `François Pietka`_) +* Add `--user` option, duplicate of `--username`, the same cli option like `psql` (Thanks: `Alexandr Korsak`_) + +Internal changes: +----------------- + +* Mark tests requiring a running database server as dbtest (Thanks: `Dick Marinus`_) +* Add an is_special command flag to MetaQuery (Thanks: `Rishi Ramraj`_) +* Ported Destructive Warning from mycli. +* Refactor Destructive Warning behave tests (Thanks: `Dick Marinus`_) + +Bug Fixes: +---------- +* Disable pager when using \watch (#837). (Thanks: `Jason Ribeiro`_) +* Don't offer to reconnect when we can't change a param in realtime (#807). (Thanks: `Amjith Ramanujam`_ and `Saif Hakim`_) +* Make keyring optional. (Thanks: `Dick Marinus`_) +* Fix ipython magic connection (#891). (Thanks: `Irina Truong`_) +* Fix not enough values to unpack. (Thanks: `Matthieu Guilbert`_) +* Fix unbound local error when destructive_warning is false. (Thanks: `Matthieu Guilbert`_) +* Render tab characters as 4 spaces instead of `^I`. (Thanks: `Artur Balabanov`_) + +1.9.1: +====== + +Features: +--------- + +* Change ``\h`` format string in prompt to only return the first part of the hostname, + up to the first '.' character. Add ``\H`` that returns the entire hostname (#858). + (Thanks: `Andrew Kuchling`_) +* Add Color of table by parameter. The color of table is function of syntax style + +Internal changes: +----------------- + +* Add tests, AUTHORS and changelog.rst to release. (Thanks: `Dick Marinus`_) + +Bug Fixes: +---------- +* Fix broken pgcli --list command line option (#850). (Thanks: `Dmitry B`_) + +1.9.0 +===== + +Features: +--------- + +* manage pager by \pset pager and add enable_pager to the config file (Thanks: `Frederic Aoustin`_). +* Add support for `\T` command to change format output. (Thanks: `Frederic Aoustin`_). +* Add option list-dsn (Thanks: `Frederic Aoustin`_). + + +Internal changes: +----------------- + +* Removed support for Python 3.3. (Thanks: `Irina Truong`_) + +1.8.2 +===== + +Features: +--------- + +* Use other prompt (prompt_dsn) when connecting using --dsn parameter. (Thanks: `Marcin Sztolcman`_) +* Include username into password prompt. (Thanks: `Bojan Delić`_) + +Internal changes: +----------------- +* Use temporary dir as config location in tests. (Thanks: `Dmitry B`_) +* Fix errors in the ``tee`` test (#795 and #797). (Thanks: `Irina Truong`_) +* Increase timeout for quitting pgcli. (Thanks: `Dick Marinus`_) + +Bug Fixes: +---------- +* Do NOT quote the database names in the completion menu (Thanks: `Amjith Ramanujam`_) +* Fix error in ``unix_socket_directories`` (#805). (Thanks: `Irina Truong`_) +* Fix the --list command line option tries to connect to 'personal' DB (#816). (Thanks: `Isank`_) + +1.8.1 +===== + +Internal changes: +----------------- +* Remove shebang and git execute permission from pgcli/main.py. (Thanks: `Dick Marinus`_) +* Require cli_helpers 0.2.3 (fix #791). (Thanks: `Dick Marinus`_) + +1.8.0 +===== + +Features: +--------- + +* Add fish-style auto-suggestion from history. (Thanks: `Amjith Ramanujam`_) +* Improved formatting of arrays in output (Thanks: `Joakim Koljonen`_) +* Don't quote identifiers that are non-reserved keywords. (Thanks: `Joakim Koljonen`_) +* Remove the ``...`` in the continuation prompt and use empty space instead. (Thanks: `Amjith Ramanujam`_) +* Add \conninfo and handle more parameters with \c (issue #716) (Thanks: `François Pietka`_) + +Internal changes: +----------------- +* Preliminary work for a future change in outputting results that uses less memory. (Thanks: `Dick Marinus`_) +* Remove import workaround for OrderedDict, required for python < 2.7. (Thanks: `Andrew Speed`_) +* Use less memory when formatting results for display (Thanks: `Dick Marinus`_). +* Port auto_vertical feature test from mycli to pgcli. (Thanks: `Dick Marinus`_) +* Drop wcwidth dependency (Thanks: `Dick Marinus`_) + +Bug Fixes: +---------- + +* Fix the way we get host when using DSN (issue #765) (Thanks: `François Pietka`_) +* Add missing keyword COLUMN after DROP (issue #769) (Thanks: `François Pietka`_) +* Don't include arguments in function suggestions for backslash commands (Thanks: `Joakim Koljonen`_) +* Optionally use POSTGRES_USER, POSTGRES_HOST POSTGRES_PASSWORD from environment (Thanks: `Dick Marinus`_) + +1.7.0 +===== + +* Refresh completions after `COMMIT` or `ROLLBACK`. (Thanks: `Irina Truong`_) +* Fixed DSN aliases not being read from custom pgclirc (issue #717). (Thanks: `Irina Truong`_). +* Use dbcli's Homebrew tap for installing pgcli on macOS (issue #718) (Thanks: `Thomas Roten`_). +* Only set `LESS` environment variable if it's unset. (Thanks: `Irina Truong`_) +* Quote schema in `SET SCHEMA` statement (issue #469) (Thanks: `Irina Truong`_) +* Include arguments in function suggestions (Thanks: `Joakim Koljonen`_) +* Use CLI Helpers for pretty printing query results (Thanks: `Thomas Roten`_). +* Skip serial columns when expanding * for `INSERT INTO foo(*` (Thanks: `Joakim Koljonen`_). +* Command line option to list databases (issue #206) (Thanks: `François Pietka`_) + +1.6.0 +===== + +Features: +--------- +* Add time option for prompt (Thanks: `Gustavo Castro`_) +* Suggest objects from all schemas (not just those in search_path) (Thanks: `Joakim Koljonen`_) +* Casing for column headers (Thanks: `Joakim Koljonen`_) +* Allow configurable character to be used for multi-line query continuations. (Thanks: `Owen Stephens`_) +* Completions after ORDER BY and DISTINCT now take account of table aliases. (Thanks: `Owen Stephens`_) +* Narrow keyword candidates based on previous keyword. (Thanks: `Étienne Bersac`_) +* Opening an external editor will edit the last-run query. (Thanks: `Thomas Roten`_) +* Support query options in postgres URIs such as ?sslcert=foo.pem (Thanks: `Alexander Schmolck`_) + +Bug fixes: +---------- +* Fixed external editor bug (issue #668). (Thanks: `Irina Truong`_). +* Standardize command line option names. (Thanks: `Russell Davies`_) +* Improve handling of ``lock_not_available`` error (issue #700). (Thanks: `Jackson Popkin `_) +* Fixed user option precedence (issue #697). (Thanks: `Irina Truong`_). + +Internal changes: +----------------- +* Run pep8 checks in travis (Thanks: `Irina Truong`_). +* Add pager wrapper for behave tests (Thanks: `Dick Marinus`_). +* Behave quit pgcli nicely (Thanks: `Dick Marinus`_). +* Behave test source command (Thanks: `Dick Marinus`_). +* Behave fix clean up. (Thanks: `Dick Marinus`_). +* Test using behave the tee command (Thanks: `Dick Marinus`_). +* Behave remove boiler plate code (Thanks: `Dick Marinus`_). +* Behave fix pgspecial update (Thanks: `Dick Marinus`_). +* Add behave to tox (Thanks: `Dick Marinus`_). + +1.5.1 +===== + +Features: +--------- +* Better suggestions when editing functions (Thanks: `Joakim Koljonen`_) +* Command line option for ``--less-chatty``. (Thanks: `tk`_) +* Added ``MATERIALIZED VIEW`` keywords. (Thanks: `Joakim Koljonen`_). + +Bug fixes: +---------- + +* Support unicode chars in expanded mode. (Thanks: `Amjith Ramanujam`_) +* Fixed "set_session cannot be used inside a transaction" when using dsn. (Thanks: `Irina Truong`_). + +1.5.0 +===== + +Features: +--------- +* Upgraded pgspecial to 1.7.0. (See `pgspecial changelog `_ for list of fixes) +* Add a new config setting to allow expandable mode (Thanks: `Jonathan Boudreau `_) +* Make pgcli prompt width short when the prompt is too long (Thanks: `Jonathan Virga `_) +* Add additional completion for ``ALTER`` keyword (Thanks: `Darik Gamble`_) +* Make the menu size configurable. (Thanks `Darik Gamble`_) + +Bug Fixes: +---------- +* Handle more connection failure cases. (Thanks: `Amjith Ramanujam`_) +* Fix the connection failure issues with latest psycopg2. (Thanks: `Amjith Ramanujam`_) + +Internal Changes: +----------------- + +* Add testing for Python 3.5 and 3.6. (Thanks: `Amjith Ramanujam`_) + +1.4.0 +===== + +Features: +--------- + +* Search table suggestions using initialisms. (Thanks: `Joakim Koljonen`_). +* Support for table-qualifying column suggestions. (Thanks: `Joakim Koljonen`_). +* Display transaction status in the toolbar. (Thanks: `Joakim Koljonen`_). +* Display vi mode in the toolbar. (Thanks: `Joakim Koljonen`_). +* Added --prompt option. (Thanks: `Irina Truong`_). + +Bug Fixes: +---------- + +* Fix scoping for columns from CTEs. (Thanks: `Joakim Koljonen`_) +* Fix crash after `with`. (Thanks: `Joakim Koljonen`_). +* Fix issue #603 (`\i` raises a TypeError). (Thanks: `Lele Gaifax`_). + + +Internal Changes: +----------------- + +* Set default data_formatting to nothing. (Thanks: `Amjith Ramanujam`_). +* Increased minimum prompt_toolkit requirement to 1.0.9. (Thanks: `Irina Truong`_). + + +1.3.1 +===== + +Bug Fixes: +---------- +* Fix a crashing bug due to sqlparse upgrade. (Thanks: `Darik Gamble`_) + + +1.3.0 +===== + +IMPORTANT: Python 2.6 is not officially supported anymore. + +Features: +--------- +* Add delimiters to displayed numbers. This can be configured via the config file. (Thanks: `Sergii`_). +* Fix broken 'SHOW ALL' in redshift. (Thanks: `Manuel Barkhau`_). +* Support configuring keyword casing preferences. (Thanks: `Darik Gamble`_). +* Add a new multi_line_mode option in config file. The values can be `psql` or `safe`. (Thanks: `Joakim Koljonen`_) + Setting ``multi_line_mode = safe`` will make sure that a query will only be executed when Alt+Enter is pressed. + +Bug Fixes: +---------- +* Fix crash bug with leading parenthesis. (Thanks: `Joakim Koljonen`_). +* Remove cumulative addition of timing data. (Thanks: `Amjith Ramanujam`_). +* Handle unrecognized keywords gracefully. (Thanks: `Darik Gamble`_) +* Use raw strings in regex specifiers. This preemptively fixes a crash in Python 3.6. (Thanks `Lele Gaifax`_) + +Internal Changes: +----------------- +* Set sqlparse version dependency to >0.2.0, <0.3.0. (Thanks: `Amjith Ramanujam`_). +* XDG_CONFIG_HOME support for config file location. (Thanks: `Fabien Meghazi`_). +* Remove Python 2.6 from travis test suite. (Thanks: `Amjith Ramanujam`_) + +1.2.0 +===== + +Features: +--------- + +* Add more specifiers to pgcli prompt. (Thanks: `Julien Rouhaud`_). + ``\p`` for port info ``\#`` for super user and ``\i`` for pid. +* Add `\watch` command to periodically execute a command. (Thanks: `Stuart Quin`_). + ``> SELECT * FROM django_migrations; \watch 1 /* Runs the command every second */`` +* Add command-line option --single-connection to prevent pgcli from using multiple connections. (Thanks: `Joakim Koljonen`_). +* Add priority to the suggestions to sort based on relevance. (Thanks: `Joakim Koljonen`_). +* Configurable null format via the config file. (Thanks: `Adrian Dries`_). +* Add support for CTE aware auto-completion. (Thanks: `Darik Gamble`_). +* Add host and user information to default pgcli prompt. (Thanks: `Lim H`_). +* Better scoping for tables in insert statements to improve suggestions. (Thanks: `Joakim Koljonen`_). + +Bug Fixes: +---------- + +* Do not install setproctitle on cygwin. (Thanks: `Janus Troelsen`_). +* Work around sqlparse crashing after AS keyword. (Thanks: `Joakim Koljonen`_). +* Fix a crashing bug with named queries. (Thanks: `Joakim Koljonen`_). +* Replace timestampz alias since AWS Redshift does not support it. (Thanks: `Tahir Butt`_). +* Prevent pgcli from hanging indefinitely when Postgres instance is not running. (Thanks: `Darik Gamble`_) + +Internal Changes: +----------------- + +* Upgrade to sqlparse-0.2.0. (Thanks: `Tiziano Müller`_). +* Upgrade to pgspecial 1.6.0. (Thanks: `Stuart Quin`_). + + +1.1.0 +===== + +Features: +--------- + +* Add support for ``\db`` command. (Thanks: `Irina Truong`_) + +Bugs: +----- + +* Fix the crash at startup while parsing the postgres url with port number. (Thanks: `Eric Wald`_) +* Fix the crash with Redshift databases. (Thanks: `Darik Gamble`_) + +Internal Changes: +----------------- + +* Upgrade pgspecial to 1.5.0 and above. + +1.0.0 +===== + +Features: +--------- + +* Upgrade to prompt-toolkit 1.0.0. (Thanks: `Jonathan Slenders`_). +* Add support for `\o` command to redirect query output to a file. (Thanks: `Tim Sanders`_). +* Add `\i` path completion. (Thanks: `Anthony Lai`_). +* Connect to a dsn saved in config file. (Thanks: `Rodrigo Ramírez Norambuena`_). +* Upgrade sqlparse requirement to version 0.1.19. (Thanks: `Fernando L. Canizo`_). +* Add timestamptz to DATE custom extension. (Thanks: `Fernando Mora`_). +* Ensure target dir exists when copying config. (Thanks: `David Szotten`_). +* Handle dates that fall in the B.C. range. (Thanks: `Stuart Quin`_). +* Pager is selected from config file or else from environment variable. (Thanks: `Fernando Mora`_). +* Add support for Amazon Redshift. (Thanks: `Timothy Cleaver`_). +* Add support for Postgres 8.x. (Thanks: `Timothy Cleaver`_ and `Darik Gamble`_) +* Don't error when completing parameter-less functions. (Thanks: `David Szotten`_). +* Concat and return all available notices. (Thanks: `Stuart Quin`_). +* Handle unicode in record type. (Thanks: `Amjith Ramanujam`_). +* Added humanized time display. Connect #396. (Thanks: `Irina Truong`_). +* Add EXPLAIN keyword to the completion list. (Thanks: `Amjith Ramanujam`_). +* Added sdist upload to release script. (Thanks: `Irina Truong`_). +* Sort completions based on most recently used. (Thanks: `Darik Gamble`) +* Expand '*' into column list during completion. This can be triggered by hitting `` after the '*' character in the sql while typing. (Thanks: `Joakim Koljonen`_) +* Add a limit to the warning about too many rows. This is controlled by a new config value in ~/.config/pgcli/config. (Thanks: `Anže Pečar`_) +* Improved argument list in function parameter completions. (Thanks: `Joakim Koljonen`_) +* Column suggestions after the COLUMN keyword. (Thanks: `Darik Gamble`_) +* Filter out trigger implemented functions from the suggestion list. (Thanks: `Daniel Rocco`_) +* State of the art JOIN clause completions that suggest entire conditions. (Thanks: `Joakim Koljonen`_) +* Suggest fully formed JOIN clauses based on Foreign Key relations. (Thanks: `Joakim Koljonen`_) +* Add support for `\dx` meta command to list the installed extensions. (Thanks: `Darik Gamble`_) +* Add support for `\copy` command. (Thanks: `Catherine Devlin`_) + +Bugs: +----- + +* Fix bug where config writing would leave a '~' dir. (Thanks: `James Munson`_). +* Fix auto-completion breaking for table names with caps. (Thanks: `Anthony Lai`_). +* Fix lexical ordering bug. (Thanks: `Anthony Lai`_). +* Use lexical order to break ties when fuzzy matching. (Thanks: `Daniel Rocco`_). +* Fix the bug in auto-expand mode when there are no rows to display. (Thanks: `Amjith Ramanujam`_). +* Fix broken `\i` after #395. (Thanks: `David Szotten`_). +* Fix multi-way joins in auto-completion. (Thanks: `Darik Gamble`_) +* Display null values as in expanded output. (Thanks: `Amjith Ramanujam`_). +* Robust support for Postgres version less than 9.x. (Thanks: `Darik Gamble`_) + +Internal Changes: +----------------- + +* Update config file location in README. (Thanks: `Ari Summer`_). +* Explicitly add wcwidth as a dependency. (Thanks: `Amjith Ramanujam`_). +* Add tests for the format_output. (Thanks: `Amjith Ramanujam`_). +* Lots of tests for pgcompleter. (Thanks: `Darik Gamble`_). +* Update pgspecial dependency to 1.4.0. + + +0.20.1 +====== + +Bug Fixes: +---------- +* Fixed logging in Windows by switching the location of log and history file based on OS. (Thanks: Amjith, `Darik Gamble`_, `Irina Truong`_). + +0.20.0 +====== + +Features: +--------- +* Perform auto-completion refresh in background. (Thanks: Amjith, `Darik Gamble`_, `Irina Truong`_). + When the auto-completion entries are refreshed, the update now happens in a + background thread. This means large databases with thousands of tables are + handled without blocking. +* Add ``CONCURRENTLY`` to keyword completion. (Thanks: `Johannes Hoff`_). +* Add support for ``\h`` command. (Thanks: `Stuart Quin`_). + This is a huge deal. Users can now get help on an SQL command by typing: + ``\h COMMAND_NAME`` in the pgcli prompt. +* Add support for ``\x auto``. (Thanks: `Stuart Quin`_). + ``\\x auto`` will automatically switch to expanded mode if the output is wider + than the display window. +* Don't hide functions from pg_catalog. (Thanks: `Darik Gamble`_). +* Suggest set-returning functions as tables. (Thanks: `Darik Gamble`_). + Functions that return table like results will now be suggested in places of tables. +* Suggest fields from functions used as tables. (Thanks: `Darik Gamble`_). +* Using ``pgspecial`` as a separate module. (Thanks: `Irina Truong`_). +* Make "enter" key behave as "tab" key when the completion menu is displayed. (Thanks: `Matheus Rosa`_). +* Support different error-handling options when running multiple queries. (Thanks: `Darik Gamble`_). + When ``on_error = STOP`` in the config file, pgcli will abort execution if one of the queries results in an error. +* Hide the password displayed in the process name in ``ps``. (Thanks: `Stuart Quin`_) + +Bug Fixes: +---------- +* Fix the ordering bug in `\\d+` display, this bug was displaying the wrong table name in the reference. (Thanks: `Tamas Boros`_). +* Only show expanded layout if valid list of headers provided. (Thanks: `Stuart Quin`_). +* Fix suggestions in compound join clauses. (Thanks: `Darik Gamble`_). +* Fix completion refresh in multiple query scenario. (Thanks: `Darik Gamble`_). +* Fix the broken timing information. +* Fix the removal of whitespaces in the output. (Thanks: `Jacek Wielemborek`_) +* Fix PyPI badge. (Thanks: `Artur Dryomov`_). + +Improvements: +------------- +* Move config file to `~/.config/pgcli/config` instead of `~/.pgclirc` (Thanks: `inkn`_). +* Move literal definitions to standalone JSON files. (Thanks: `Darik Gamble`_). + +Internal Changes: +----------------- +* Improvements to integration tests to make it more robust. (Thanks: `Irina Truong`_). + +0.19.2 +====== + +Features: +--------- + +* Autocompletion for database name in \c and \connect. (Thanks: `Darik Gamble`_). +* Improved multiline query support by correctly handling open quotes. (Thanks: `Darik Gamble`_). +* Added \pager command. +* Enhanced \i to run multiple queries and display the results for each of them +* Added keywords to suggestions after WHERE clause. +* Enabled autocompletion in named queries. (Thanks: `Irina Truong`_). +* Path to .pgclirc can be specified in command line. (Thanks: `Irina Truong`_). +* Added support for pg_service_conf file. (Thanks: `Irina Truong`_). +* Added custom styles. (Contributor: `Darik Gamble`_). + +Internal Changes: +----------------- + +* More completer test cases. (Thanks: `Darik Gamble`_). +* Updated sqlparse version from 0.1.14 to 0.1.16. (Thanks: `Darik Gamble`_). +* Upgraded to prompt_toolkit 0.46. (Thanks: `Jonathan Slenders`_). + +BugFixes: +--------- +* Fixed the completer crashing on invalid SQL. (Thanks: `Darik Gamble`_). +* Fixed unicode issues, updated tests and fixed broken tests. + +0.19.1 +====== + +BugFixes: +--------- + +* Fix an autocompletion bug that was crashing the completion engine when unknown keyword is entered. (Thanks: `Darik Gamble`_) + +0.19.0 +====== + +Features: +--------- + +* Wider completion menus can be enabled via the config file. (Thanks: `Jonathan Slenders`_) + + Open the config file (~/.pgclirc) and check if you have + ``wider_completion_menu`` option available. If not add it in and set it to + ``True``. + +* Completion menu now has metadata information such as schema, table, column, view, etc., next to the suggestions. (Thanks: `Darik Gamble`_) +* Customizable history file location via config file. (Thanks: `Çağatay Yüksel`_) + + Add this line to your config file (~/.pgclirc) to customize where to store the history file. + +:: + + history_file = /path/to/history/file + +* Add support for running queries from a file using ``\i`` special command. (Thanks: `Michael Kaminsky`_) + +BugFixes: +--------- + +* Always use utf-8 for database encoding regardless of the default encoding used by the database. +* Fix for None dereference on ``\d schemaname.`` with sequence. (Thanks: `Nathan Jhaveri`_) +* Fix a crashing bug in the autocompletion engine for some ``JOIN`` queries. +* Handle KeyboardInterrupt in pager and not quit pgcli as a consequence. + +Internal Changes: +----------------- + +* Added more behaviorial tests (Thanks: `Irina Truong`_) +* Added code coverage to the tests. (Thanks: `Irina Truong`_) +* Run behaviorial tests as part of TravisCI (Thanks: `Irina Truong`_) +* Upgraded prompt_toolkit version to 0.45 (Thanks: `Jonathan Slenders`_) +* Update the minimum required version of click to 4.1. + +0.18.0 +====== + +Features: +--------- + +* Add fuzzy matching for the table names and column names. + + Matching very long table/column names are now easier with fuzzy matching. The + fuzzy match works like the fuzzy open in SublimeText or Vim's Ctrl-P plugin. + + eg: Typing ``djmv`` will match `django_migration_views` since it is able to + match parts of the input to the full table name. + +* Change the timing information to seconds. + + The ``Command Time`` and ``Format Time`` are now displayed in seconds instead + of a unitless number displayed in scientific notation. + +* Support for named queries (favorite queries). (Thanks: `Brett Atoms`_) + + Frequently typed queries can now be saved and recalled using a name using + newly added special commands (``\n[+]``, ``\ns``, ``\nd``). + + eg: + +:: + + # Save a query + pgcli> \ns simple select * from foo + saved + + # List all saved queries + pgcli> \n+ + + # Execute a saved query + pgcli> \n simple + + # Delete a saved query + pgcli> \nd simple + +* Pasting queries into the pgcli repl is orders of magnitude faster. (Thanks: `Jonathan Slenders`_) + +* Add support for PGPASSWORD environment variable to pass the password for the + postgres database. (Thanks: `Irina Truong`_) + +* Add the ability to manually refresh autocompletions by typing ``\#`` or + ``\refresh``. This is useful if the database was updated by an external means + and you'd like to refresh the auto-completions to pick up the new change. + +Bug Fixes: +---------- + +* Fix an error when running ``\d table_name`` when running on a table with rules. (Thanks: `Ali Kargın`_) +* Fix a pgcli crash when entering non-ascii characters in Windows. (Thanks: `Darik Gamble`_, `Jonathan Slenders`_) +* Faster rendering of expanded mode output by making the horizontal separator a fixed length string. +* Completion suggestions for the ``\c`` command are not auto-escaped by default. + +Internal Changes: +----------------- + +* Complete refactor of handling the back-slash commands. +* Upgrade prompt_toolkit to 0.42. (Thanks: `Jonathan Slenders`_) +* Change the config file management to use ConfigObj.(Thanks: `Brett Atoms`_) +* Add integration tests using ``behave``. (Thanks: `Irina Truong`_) + +0.17.0 +====== + +Features: +--------- + +* Add support for auto-completing view names. (Thanks: `Darik Gamble`_) +* Add support for building RPM and DEB packages. (Thanks: dp_) +* Add subsequence matching for completion. (Thanks: `Daniel Rocco`_) + Previously completions only matched a table name if it started with the + partially typed word. Now completions will match even if the partially typed + word is in the middle of a suggestion. + eg: When you type 'mig', 'django_migrations' will be suggested. +* Completion for built-in tables and temporary tables are suggested after entering a prefix of ``pg_``. (Thanks: `Darik Gamble`_) +* Add place holder doc strings for special commands that are planned for implementation. (Thanks: `Irina Truong`_) +* Updated version of prompt_toolkit, now matching braces are highlighted. (Thanks: `Jonathan Slenders`_) +* Added support of ``\\e`` command. Queries can be edited in an external editor. (Thanks: `Irina Truong`_) + eg: When you type ``SELECT * FROM \e`` it will be opened in an external editor. +* Add special command ``\dT`` to show datatypes. (Thanks: `Darik Gamble`_) +* Add auto-completion support for datatypes in CREATE, SELECT etc. (Thanks: `Darik Gamble`_) +* Improve the auto-completion in WHERE clause with logical operators. (Thanks: `Darik Gamble`_) +* + +Bug Fixes: +---------- + +* Fix the table formatting while printing multi-byte characters (Chinese, Japanese etc). (Thanks: `蔡佳男`_) +* Fix a crash when pg_catalog was present in search path. (Thanks: `Darik Gamble`_) +* Fixed a bug that broke `\\e` when prompt_tookit was updated. (Thanks: `François Pietka`_) +* Fix the display of triggers as shown in the ``\d`` output. (Thanks: `Dimitar Roustchev`_) +* Fix broken auto-completion for INNER JOIN, LEFT JOIN etc. (Thanks: `Darik Gamble`_) +* Fix incorrect super() calls in pgbuffer, pgtoolbar and pgcompleter. No change in functionality but protects against future problems. (Thanks: `Daniel Rocco`_) +* Add missing schema completion for CREATE and DROP statements. (Thanks: `Darik Gamble`_) +* Minor fixes around cursor cleanup. + +0.16.3 +====== + +Bug Fixes: +---------- +* Add more SQL keywords for auto-complete suggestion. +* Messages raised as part of stored procedures are no longer ignored. +* Use postgres flavored syntax highlighting instead of generic ANSI SQL. + +0.16.2 +====== + +Bug Fixes: +---------- +* Fix a bug where the schema qualifier was ignored by the auto-completion. + As a result the suggestions for tables vs functions are cleaner. (Thanks: `Darik Gamble`_) +* Remove scientific notation when formatting large numbers. (Thanks: `Daniel Rocco`_) +* Add the FUNCTION keyword to auto-completion. +* Display NULL values as instead of empty strings. +* Fix the completion refresh when ``\connect`` is executed. + +0.16.1 +====== + +Bug Fixes: +---------- +* Fix unicode issues with hstore. +* Fix a silent error when database is changed using \\c. + +0.16.0 +====== + +Features: +--------- +* Add \ds special command to show sequences. +* Add Vi mode for keybindings. This can be enabled by adding 'vi = True' in ~/.pgclirc. (Thanks: `Jay Zeng`_) +* Add a -v/--version flag to pgcli. +* Add completion for TEMPLATE keyword and smart-completion for + 'CREATE DATABASE blah WITH TEMPLATE '. (Thanks: `Daniel Rocco`_) +* Add custom decoders to json/jsonb to emulate the behavior of psql. This + removes the unicode prefix (eg: u'Éowyn') in the output. (Thanks: `Daniel Rocco`_) +* Add \df special command to show functions. (Thanks: `Darik Gamble`_) +* Make suggestions for special commands smarter. eg: \dn - only suggests schemas. (Thanks: `Darik Gamble`_) +* Print out the version and other meta info about pgcli at startup. + +Bug Fixes: +---------- +* Fix a rare crash caused by adding new schemas to a database. (Thanks: `Darik Gamble`_) +* Make \dt command honor the explicit schema specified in the arg. (Thanks: `Darik Gamble`_) +* Print BIGSERIAL type as Integer instead of Float. +* Show completions for special commands at the beginning of a statement. (Thanks: `Daniel Rocco`_) +* Allow special commands to work in a multi-statement case where multiple sql + statements are separated by semi-colon in the same line. + +0.15.4 +====== +* Dummy version to replace accidental PyPI entry deletion. + +0.15.3 +====== +* Override the LESS options completely instead of appending to it. + +0.15.2 +====== +* Revert back to using psycopg2 as the postgres adapter. psycopg2cffi fails for some tests in Python 3. + +0.15.0 +====== + +Features: +--------- +* Add syntax color styles to config. +* Add auto-completion for COPY statements. +* Change Postgres adapter to psycopg2cffi, to make it PyPy compatible. + Now pgcli can be run by PyPy. + +Bug Fixes: +---------- +* Treat boolean values as strings instead of ints. +* Make \di, \dv and \dt to be schema aware. (Thanks: `Darik Gamble`_) +* Make column name display unicode compatible. + +0.14.0 +====== + +Features: +--------- +* Add alias completion support to ON keyword. (Thanks: `Irina Truong`_) +* Add LIMIT keyword to completion. +* Auto-completion for Postgres schemas. (Thanks: `Darik Gamble`_) +* Better unicode handling for datatypes, dbname and roles. +* Add \timing command to time the sql commands. + This can be set via config file (~/.pgclirc) using `timing = True`. +* Add different table styles for displaying output. + This can be changed via config file (~/.pgclirc) using `table_format = fancy_grid`. +* Add confirmation before printing results that have more than 1000 rows. + +Bug Fixes: +---------- + +* Performance improvements to expanded view display (\x). +* Cast bytea files to text while displaying. (Thanks: `Daniel Rocco`_) +* Added a list of reserved words that should be auto-escaped. +* Auto-completion is now case-insensitive. +* Fix the broken completion for multiple sql statements. (Thanks: `Darik Gamble`_) + +0.13.0 +====== + +Features: +--------- + +* Add -d/--dbname option to the commandline. + eg: pgcli -d database +* Add the username as an argument after the database. + eg: pgcli dbname user + +Bug Fixes: +---------- +* Fix the crash when \c fails. +* Fix the error thrown by \d when triggers are present. +* Fix broken behavior on \?. (Thanks: `Darik Gamble`_) + +0.12.0 +====== + +Features: +--------- + +* Upgrade to prompt_toolkit version 0.26 (Thanks: https://github.com/macobo) + * Adds Ctrl-left/right to move the cursor one word left/right respectively. + * Internal API changes. +* IPython integration through `ipython-sql`_ (Thanks: `Darik Gamble`_) + * Add an ipython magic extension to embed pgcli inside ipython. + * Results from a pgcli query are sent back to ipython. +* Multiple sql statements in the same line separated by semi-colon. (Thanks: https://github.com/macobo) + +.. _`ipython-sql`: https://github.com/catherinedevlin/ipython-sql + +Bug Fixes: +---------- + +* Fix 'message' attribute not found exception in Python 3. (Thanks: https://github.com/GMLudo) +* Use the database username as the database name instead of defaulting to OS username. (Thanks: https://github.com/fpietka) +* Auto-completion for auto-escaped column/table names. +* Fix i-reverse-search to work in prompt_toolkit version 0.26. + +0.11.0 +====== + +Features: +--------- + +* Add \dn command. (Thanks: https://github.com/CyberDem0n) +* Add \x command. (Thanks: https://github.com/stuartquin) +* Auto-escape special column/table names. (Thanks: https://github.com/qwesda) +* Cancel a command using Ctrl+C. (Thanks: https://github.com/macobo) +* Faster startup by reading all columns and tables in a single query. (Thanks: https://github.com/macobo) +* Improved psql compliance with env vars and password prompting. (Thanks: `Darik Gamble`_) +* Pressing Alt-Enter will introduce a line break. This is a way to break up the query into multiple lines without switching to multi-line mode. (Thanks: https://github.com/pabloab). + +Bug Fixes: +---------- +* Fix the broken behavior of \d+. (Thanks: https://github.com/macobo) +* Fix a crash during auto-completion. (Thanks: https://github.com/Erethon) +* Avoid losing pre_run_callables on error in editing. (Thanks: https://github.com/catherinedevlin) + +Improvements: +------------- +* Faster test runs on TravisCI. (Thanks: https://github.com/macobo) +* Integration tests with Postgres!! (Thanks: https://github.com/macobo) + +.. _`Amjith Ramanujam`: https://blog.amjith.com +.. _`Andrew Kuchling`: https://github.com/akuchling +.. _`Darik Gamble`: https://github.com/darikg +.. _`Daniel Rocco`: https://github.com/drocco007 +.. _`Jay Zeng`: https://github.com/jayzeng +.. _`蔡佳男`: https://github.com/xalley +.. _dp: https://github.com/ceocoder +.. _`Jonathan Slenders`: https://github.com/jonathanslenders +.. _`Dimitar Roustchev`: https://github.com/droustchev +.. _`François Pietka`: https://github.com/fpietka +.. _`Ali Kargın`: https://github.com/sancopanco +.. _`Brett Atoms`: https://github.com/brettatoms +.. _`Nathan Jhaveri`: https://github.com/nathanjhaveri +.. _`Çağatay Yüksel`: https://github.com/cagatay +.. _`Michael Kaminsky`: https://github.com/mikekaminsky +.. _`inkn`: inkn +.. _`Johannes Hoff`: Johannes Hoff +.. _`Matheus Rosa`: Matheus Rosa +.. _`Artur Dryomov`: https://github.com/ming13 +.. _`Stuart Quin`: https://github.com/stuartquin +.. _`Tamas Boros`: https://github.com/TamasNo1 +.. _`Jacek Wielemborek`: https://github.com/d33tah +.. _`Rodrigo Ramírez Norambuena`: https://github.com/roramirez +.. _`Anthony Lai`: https://github.com/ajlai +.. _`Ari Summer`: Ari Summer +.. _`David Szotten`: David Szotten +.. _`Fernando L. Canizo`: Fernando L. Canizo +.. _`Tim Sanders`: https://github.com/Gollum999 +.. _`Irina Truong`: https://github.com/j-bennet +.. _`James Munson`: https://github.com/jmunson +.. _`Fernando Mora`: https://github.com/fernandomora +.. _`Timothy Cleaver`: Timothy Cleaver +.. _`gtxx`: gtxx +.. _`Joakim Koljonen`: https://github.com/koljonen +.. _`Anže Pečar`: https://github.com/Smotko +.. _`Catherine Devlin`: https://github.com/catherinedevlin +.. _`Eric Wald`: https://github.com/eswald +.. _`avdd`: https://github.com/avdd +.. _`Adrian Dries`: Adrian Dries +.. _`Julien Rouhaud`: https://github.com/rjuju +.. _`Lim H`: Lim H +.. _`Tahir Butt`: Tahir Butt +.. _`Tiziano Müller`: https://github.com/dev-zero +.. _`Janus Troelsen`: https://github.com/ysangkok +.. _`Fabien Meghazi`: https://github.com/amigrave +.. _`Manuel Barkhau`: https://github.com/mbarkhau +.. _`Sergii`: https://github.com/foxyterkel +.. _`Lele Gaifax`: https://github.com/lelit +.. _`tk`: https://github.com/kanet77 +.. _`Owen Stephens`: https://github.com/owst +.. _`Russell Davies`: https://github.com/russelldavies +.. _`Dick Marinus`: https://github.com/meeuw +.. _`Étienne Bersac`: https://github.com/bersace +.. _`Thomas Roten`: https://github.com/tsroten +.. _`Gustavo Castro`: https://github.com/gustavo-castro +.. _`Alexander Schmolck`: https://github.com/aschmolck +.. _`Andrew Speed`: https://github.com/AndrewSpeed +.. _`Dmitry B`: https://github.com/oxitnik +.. _`Marcin Sztolcman`: https://github.com/msztolcman +.. _`Isank`: https://github.com/isank +.. _`Bojan Delić`: https://github.com/delicb +.. _`Frederic Aoustin`: https://github.com/fraoustin +.. _`Jason Ribeiro`: https://github.com/jrib +.. _`Rishi Ramraj`: https://github.com/RishiRamraj +.. _`Matthieu Guilbert`: https://github.com/gma2th +.. _`Alexandr Korsak`: https://github.com/oivoodoo +.. _`Saif Hakim`: https://github.com/saifelse +.. _`Artur Balabanov`: https://github.com/arturbalabanov +.. _`Kenny Do`: https://github.com/kennydo +.. _`Max Rothman`: https://github.com/maxrothman +.. _`Daniel Egger`: https://github.com/DanEEStar +.. _`Ignacio Campabadal`: https://github.com/igncampa +.. _`Mikhail Elovskikh`: https://github.com/wronglink +.. _`Marcin Cieślak`: https://github.com/saper +.. _`Scott Brenstuhl`: https://github.com/808sAndBR +.. _`easteregg`: https://github.com/verfriemelt-dot-org +.. _`Nathan Verzemnieks`: https://github.com/njvrzm +.. _`raylu`: https://github.com/raylu +.. _`Zhaolong Zhu`: https://github.com/zzl0 +.. _`Xavier Francisco`: https://github.com/Qu4tro +.. _`VVelox`: https://github.com/VVelox +.. _`Telmo "Trooper"`: https://github.com/telmotrooper +.. _`Alexander Zawadzki`: https://github.com/zadacka +.. _`Sebastian Janko`: https://github.com/sebojanko +.. _`Pedro Ferrari`: https://github.com/petobens +.. _`BrownShibaDog`: https://github.com/BrownShibaDog +.. _`thegeorgeous`: https://github.com/thegeorgeous +.. _`laixintao`: https://github.com/laixintao +.. _`anthonydb`: https://github.com/anthonydb +.. _`Daniel Kukula`: https://github.com/dkuku diff --git a/pgcli-completion.bash b/pgcli-completion.bash new file mode 100644 index 0000000..3549b56 --- /dev/null +++ b/pgcli-completion.bash @@ -0,0 +1,61 @@ +_pg_databases() +{ + # -w was introduced in 8.4, https://launchpad.net/bugs/164772 + # "Access privileges" in output may contain linefeeds, hence the NF > 1 + COMPREPLY=( $( compgen -W "$( psql -AtqwlF $'\t' 2>/dev/null | \ + awk 'NF > 1 { print $1 }' )" -- "$cur" ) ) +} + +_pg_users() +{ + # -w was introduced in 8.4, https://launchpad.net/bugs/164772 + COMPREPLY=( $( compgen -W "$( psql -Atqwc 'select usename from pg_user' \ + template1 2>/dev/null )" -- "$cur" ) ) + [[ ${#COMPREPLY[@]} -eq 0 ]] && COMPREPLY=( $( compgen -u -- "$cur" ) ) +} + +_pgcli() +{ + local cur prev words cword + _init_completion -s || return + + case $prev in + -h|--host) + _known_hosts_real "$cur" + return 0 + ;; + -U|--user) + _pg_users + return 0 + ;; + -d|--dbname) + _pg_databases + return 0 + ;; + --help|-v|--version|-p|--port|-R|--row-limit) + # all other arguments are noop with these + return 0 + ;; + esac + + case "$cur" in + --*) + # return list of available options + COMPREPLY=( $( compgen -W '--host --port --user --password --no-password + --single-connection --version --dbname --pgclirc --dsn + --row-limit --help' -- "$cur" ) ) + [[ $COMPREPLY == *= ]] && compopt -o nospace + return 0 + ;; + -) + # only complete long options + compopt -o nospace + COMPREPLY=( -- ) + return 0 + ;; + *) + # return list of available databases + _pg_databases + esac +} && +complete -F _pgcli pgcli diff --git a/pgcli/__init__.py b/pgcli/__init__.py new file mode 100644 index 0000000..76ad18b --- /dev/null +++ b/pgcli/__init__.py @@ -0,0 +1 @@ +__version__ = "4.0.1" diff --git a/pgcli/__main__.py b/pgcli/__main__.py new file mode 100644 index 0000000..ddf1662 --- /dev/null +++ b/pgcli/__main__.py @@ -0,0 +1,9 @@ +""" +pgcli package main entry point +""" + +from .main import cli + + +if __name__ == "__main__": + cli() diff --git a/pgcli/auth.py b/pgcli/auth.py new file mode 100644 index 0000000..2f1e552 --- /dev/null +++ b/pgcli/auth.py @@ -0,0 +1,60 @@ +import click +from textwrap import dedent + + +keyring = None # keyring will be loaded later + + +keyring_error_message = dedent( + """\ + {} + {} + To remove this message do one of the following: + - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/ + - uninstall keyring: pip uninstall keyring + - disable keyring in our configuration: add keyring = False to [main]""" +) + + +def keyring_initialize(keyring_enabled, *, logger): + """Initialize keyring only if explicitly enabled""" + global keyring + + if keyring_enabled: + # Try best to load keyring (issue #1041). + import importlib + + try: + keyring = importlib.import_module("keyring") + except ( + ModuleNotFoundError + ) as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 + logger.warning("import keyring failed: %r.", e) + + +def keyring_get_password(key): + """Attempt to get password from keyring""" + # Find password from store + passwd = "" + try: + passwd = keyring.get_password("pgcli", key) or "" + except Exception as e: + click.secho( + keyring_error_message.format( + "Load your password from keyring returned:", str(e) + ), + err=True, + fg="red", + ) + return passwd + + +def keyring_set_password(key, passwd): + try: + keyring.set_password("pgcli", key, passwd) + except Exception as e: + click.secho( + keyring_error_message.format("Set password in keyring returned:", str(e)), + err=True, + fg="red", + ) diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py new file mode 100644 index 0000000..c887cb6 --- /dev/null +++ b/pgcli/completion_refresher.py @@ -0,0 +1,152 @@ +import threading +import os +from collections import OrderedDict + +from .pgcompleter import PGCompleter + + +class CompletionRefresher: + refreshers = OrderedDict() + + def __init__(self): + self._completer_thread = None + self._restart_refresh = threading.Event() + + def refresh(self, executor, special, callbacks, history=None, settings=None): + """ + Creates a PGCompleter object and populates it with the relevant + completion suggestions in a background thread. + + executor - PGExecute object, used to extract the credentials to connect + to the database. + special - PGSpecial object used for creating a new completion object. + settings - dict of settings for completer object + callbacks - A function or a list of functions to call after the thread + has completed the refresh. The newly created completion + object will be passed in as an argument to each callback. + """ + if executor.is_virtual_database(): + # do nothing + return [(None, None, None, "Auto-completion refresh can't be started.")] + + if self.is_refreshing(): + self._restart_refresh.set() + return [(None, None, None, "Auto-completion refresh restarted.")] + else: + self._completer_thread = threading.Thread( + target=self._bg_refresh, + args=(executor, special, callbacks, history, settings), + name="completion_refresh", + ) + self._completer_thread.daemon = True + self._completer_thread.start() + return [ + (None, None, None, "Auto-completion refresh started in the background.") + ] + + def is_refreshing(self): + return self._completer_thread and self._completer_thread.is_alive() + + def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None): + settings = settings or {} + completer = PGCompleter( + smart_completion=True, pgspecial=special, settings=settings + ) + + if settings.get("single_connection"): + executor = pgexecute + else: + # Create a new pgexecute method to populate the completions. + executor = pgexecute.copy() + # If callbacks is a single function then push it into a list. + if callable(callbacks): + callbacks = [callbacks] + + while 1: + for refresher in self.refreshers.values(): + refresher(completer, executor) + if self._restart_refresh.is_set(): + self._restart_refresh.clear() + break + else: + # Break out of while loop if the for loop finishes natually + # without hitting the break statement. + break + + # Start over the refresh from the beginning if the for loop hit the + # break statement. + continue + + # Load history into pgcompleter so it can learn user preferences + n_recent = 100 + if history: + for recent in history.get_strings()[-n_recent:]: + completer.extend_query_history(recent, is_init=True) + + for callback in callbacks: + callback(completer) + + if not settings.get("single_connection") and executor.conn: + # close connection established with pgexecute.copy() + executor.conn.close() + + +def refresher(name, refreshers=CompletionRefresher.refreshers): + """Decorator to populate the dictionary of refreshers with the current + function. + """ + + def wrapper(wrapped): + refreshers[name] = wrapped + return wrapped + + return wrapper + + +@refresher("schemata") +def refresh_schemata(completer, executor): + completer.set_search_path(executor.search_path()) + completer.extend_schemata(executor.schemata()) + + +@refresher("tables") +def refresh_tables(completer, executor): + completer.extend_relations(executor.tables(), kind="tables") + completer.extend_columns(executor.table_columns(), kind="tables") + completer.extend_foreignkeys(executor.foreignkeys()) + + +@refresher("views") +def refresh_views(completer, executor): + completer.extend_relations(executor.views(), kind="views") + completer.extend_columns(executor.view_columns(), kind="views") + + +@refresher("types") +def refresh_types(completer, executor): + completer.extend_datatypes(executor.datatypes()) + + +@refresher("databases") +def refresh_databases(completer, executor): + completer.extend_database_names(executor.databases()) + + +@refresher("casing") +def refresh_casing(completer, executor): + casing_file = completer.casing_file + if not casing_file: + return + generate_casing_file = completer.generate_casing_file + if generate_casing_file and not os.path.isfile(casing_file): + casing_prefs = "\n".join(executor.casing()) + with open(casing_file, "w") as f: + f.write(casing_prefs) + if os.path.isfile(casing_file): + with open(casing_file) as f: + completer.extend_casing([line.strip() for line in f]) + + +@refresher("functions") +def refresh_functions(completer, executor): + completer.extend_functions(executor.functions()) diff --git a/pgcli/config.py b/pgcli/config.py new file mode 100644 index 0000000..22f08dc --- /dev/null +++ b/pgcli/config.py @@ -0,0 +1,99 @@ +import errno +import shutil +import os +import platform +from os.path import expanduser, exists, dirname +import re +from typing import TextIO +from configobj import ConfigObj + + +def config_location(): + if "XDG_CONFIG_HOME" in os.environ: + return "%s/pgcli/" % expanduser(os.environ["XDG_CONFIG_HOME"]) + elif platform.system() == "Windows": + return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\pgcli\\" + else: + return expanduser("~/.config/pgcli/") + + +def load_config(usr_cfg, def_cfg=None): + # avoid config merges when possible. For writing, we need an umerged config instance. + # see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171 + if def_cfg: + cfg = ConfigObj() + cfg.merge(ConfigObj(def_cfg, interpolation=False)) + cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) + else: + cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8") + cfg.filename = expanduser(usr_cfg) + return cfg + + +def ensure_dir_exists(path): + parent_dir = expanduser(dirname(path)) + os.makedirs(parent_dir, exist_ok=True) + + +def write_default_config(source, destination, overwrite=False): + destination = expanduser(destination) + if not overwrite and exists(destination): + return + + ensure_dir_exists(destination) + + shutil.copyfile(source, destination) + + +def upgrade_config(config, def_config): + cfg = load_config(config, def_config) + cfg.write() + + +def get_config_filename(pgclirc_file=None): + return pgclirc_file or "%sconfig" % config_location() + + +def get_config(pgclirc_file=None): + from pgcli import __file__ as package_root + + package_root = os.path.dirname(package_root) + + pgclirc_file = get_config_filename(pgclirc_file) + + default_config = os.path.join(package_root, "pgclirc") + write_default_config(default_config, pgclirc_file) + + return load_config(pgclirc_file, default_config) + + +def get_casing_file(config): + casing_file = config["main"]["casing_file"] + if casing_file == "default": + casing_file = config_location() + "casing" + return casing_file + + +def skip_initial_comment(f_stream: TextIO) -> int: + """ + Initial comment in ~/.pg_service.conf is not always marked with '#' + which crashes the parser. This function takes a file object and + "rewinds" it to the beginning of the first section, + from where on it can be parsed safely + + :return: number of skipped lines + """ + section_regex = r"\s*\[" + pos = f_stream.tell() + lines_skipped = 0 + while True: + line = f_stream.readline() + if line == "": + break + if re.match(section_regex, line) is not None: + f_stream.seek(pos) + break + else: + pos += len(line) + lines_skipped += 1 + return lines_skipped diff --git a/pgcli/explain_output_formatter.py b/pgcli/explain_output_formatter.py new file mode 100644 index 0000000..ce45b4f --- /dev/null +++ b/pgcli/explain_output_formatter.py @@ -0,0 +1,19 @@ +from pgcli.pyev import Visualizer +import json + + +"""Explain response output adapter""" + + +class ExplainOutputFormatter: + def __init__(self, max_width): + self.max_width = max_width + + def format_output(self, cur, headers, **output_kwargs): + # explain query results should always contain 1 row each + [(data,)] = list(cur) + explain_list = json.loads(data) + visualizer = Visualizer(self.max_width) + for explain in explain_list: + visualizer.load(explain) + yield visualizer.get_list() diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py new file mode 100644 index 0000000..9c016f7 --- /dev/null +++ b/pgcli/key_bindings.py @@ -0,0 +1,133 @@ +import logging +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.filters import ( + completion_is_selected, + is_searching, + has_completions, + has_selection, + vi_mode, +) + +from .pgbuffer import buffer_should_be_handled, safe_multi_line_mode + +_logger = logging.getLogger(__name__) + + +def pgcli_bindings(pgcli): + """Custom key bindings for pgcli.""" + kb = KeyBindings() + + tab_insert_text = " " * 4 + + @kb.add("f2") + def _(event): + """Enable/Disable SmartCompletion Mode.""" + _logger.debug("Detected F2 key.") + pgcli.completer.smart_completion = not pgcli.completer.smart_completion + + @kb.add("f3") + def _(event): + """Enable/Disable Multiline Mode.""" + _logger.debug("Detected F3 key.") + pgcli.multi_line = not pgcli.multi_line + + @kb.add("f4") + def _(event): + """Toggle between Vi and Emacs mode.""" + _logger.debug("Detected F4 key.") + pgcli.vi_mode = not pgcli.vi_mode + event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS + + @kb.add("f5") + def _(event): + """Toggle between Vi and Emacs mode.""" + _logger.debug("Detected F5 key.") + pgcli.explain_mode = not pgcli.explain_mode + + @kb.add("tab") + def _(event): + """Force autocompletion at cursor on non-empty lines.""" + + _logger.debug("Detected key.") + + buff = event.app.current_buffer + doc = buff.document + + if doc.on_first_line or doc.current_line.strip(): + if buff.complete_state: + buff.complete_next() + else: + buff.start_completion(select_first=True) + else: + buff.insert_text(tab_insert_text, fire_event=False) + + @kb.add("escape", filter=has_completions) + def _(event): + """Force closing of autocompletion.""" + _logger.debug("Detected key.") + + event.current_buffer.complete_state = None + event.app.current_buffer.complete_state = None + + @kb.add("c-space") + def _(event): + """ + Initialize autocompletion at cursor. + + If the autocompletion menu is not showing, display it with the + appropriate completions for the context. + + If the menu is showing, select the next completion. + """ + _logger.debug("Detected key.") + + b = event.app.current_buffer + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=False) + + @kb.add("enter", filter=completion_is_selected) + def _(event): + """Makes the enter key work as the tab key only when showing the menu. + + In other words, don't execute query when enter is pressed in + the completion dropdown menu, instead close the dropdown menu + (accept current selection). + + """ + _logger.debug("Detected enter key during completion selection.") + + event.current_buffer.complete_state = None + event.app.current_buffer.complete_state = None + + # When using multi_line input mode the buffer is not handled on Enter (a new line is + # inserted instead), so we force the handling if we're not in a completion or + # history search, and one of several conditions are True + @kb.add( + "enter", + filter=~(completion_is_selected | is_searching) + & buffer_should_be_handled(pgcli), + ) + def _(event): + _logger.debug("Detected enter key.") + event.current_buffer.validate_and_handle() + + @kb.add("escape", "enter", filter=~vi_mode & ~safe_multi_line_mode(pgcli)) + def _(event): + """Introduces a line break regardless of multi-line mode or not.""" + _logger.debug("Detected alt-enter key.") + event.app.current_buffer.insert_text("\n") + + @kb.add("c-p", filter=~has_selection) + def _(event): + """Move up in history.""" + event.current_buffer.history_backward(count=event.arg) + + @kb.add("c-n", filter=~has_selection) + def _(event): + """Move down in history.""" + event.current_buffer.history_forward(count=event.arg) + + return kb diff --git a/pgcli/magic.py b/pgcli/magic.py new file mode 100644 index 0000000..09902a2 --- /dev/null +++ b/pgcli/magic.py @@ -0,0 +1,71 @@ +from .main import PGCli +import sql.parse +import sql.connection +import logging + +_logger = logging.getLogger(__name__) + + +def load_ipython_extension(ipython): + """This is called via the ipython command '%load_ext pgcli.magic'""" + + # first, load the sql magic if it isn't already loaded + if not ipython.find_line_magic("sql"): + ipython.run_line_magic("load_ext", "sql") + + # register our own magic + ipython.register_magic_function(pgcli_line_magic, "line", "pgcli") + + +def pgcli_line_magic(line): + _logger.debug("pgcli magic called: %r", line) + parsed = sql.parse.parse(line, {}) + # "get" was renamed to "set" in ipython-sql: + # https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43 + if hasattr(sql.connection.Connection, "get"): + conn = sql.connection.Connection.get(parsed["connection"]) + else: + try: + conn = sql.connection.Connection.set(parsed["connection"]) + # a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql + except TypeError: + conn = sql.connection.Connection.set(parsed["connection"], False) + + try: + # A corresponding pgcli object already exists + pgcli = conn._pgcli + _logger.debug("Reusing existing pgcli") + except AttributeError: + # I can't figure out how to get the underylying psycopg2 connection + # from the sqlalchemy connection, so just grab the url and make a + # new connection + pgcli = PGCli() + u = conn.session.engine.url + _logger.debug("New pgcli: %r", str(u)) + + pgcli.connect_uri(str(u._replace(drivername="postgres"))) + conn._pgcli = pgcli + + # For convenience, print the connection alias + print(f"Connected: {conn.name}") + + try: + pgcli.run_cli() + except SystemExit: + pass + + if not pgcli.query_history: + return + + q = pgcli.query_history[-1] + + if not q.successful: + _logger.debug("Unsuccessful query - ignoring") + return + + if q.meta_changed or q.db_changed or q.path_changed: + _logger.debug("Dangerous query detected -- ignoring") + return + + ipython = get_ipython() + return ipython.run_cell_magic("sql", line, q.query) diff --git a/pgcli/main.py b/pgcli/main.py new file mode 100644 index 0000000..f95c800 --- /dev/null +++ b/pgcli/main.py @@ -0,0 +1,1728 @@ +from configobj import ConfigObj, ParseError +from pgspecial.namedqueries import NamedQueries +from .config import skip_initial_comment + +import atexit +import os +import re +import sys +import traceback +import logging +import threading +import shutil +import functools +import pendulum +import datetime as dt +import itertools +import platform +from time import time, sleep +from typing import Optional + +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers +from cli_helpers.utils import strip_ansi +from .explain_output_formatter import ExplainOutputFormatter +import click + +try: + import setproctitle +except ImportError: + setproctitle = None +from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.shortcuts import PromptSession, CompleteStyle +from prompt_toolkit.document import Document +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ANSI +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.layout.processors import ( + ConditionalProcessor, + HighlightMatchingBracketProcessor, + TabsProcessor, +) +from prompt_toolkit.history import FileHistory +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from pygments.lexers.sql import PostgresLexer + +from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT +import pgspecial as special + +from . import auth +from .pgcompleter import PGCompleter +from .pgtoolbar import create_toolbar_tokens_func +from .pgstyle import style_factory, style_factory_output +from .pgexecute import PGExecute +from .completion_refresher import CompletionRefresher +from .config import ( + get_casing_file, + load_config, + config_location, + ensure_dir_exists, + get_config, + get_config_filename, +) +from .key_bindings import pgcli_bindings +from .packages.formatter.sqlformatter import register_new_formatter +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.parseutils import is_destructive +from .packages.parseutils import parse_destructive_warning +from .__init__ import __version__ + +click.disable_unicode_literals_warning = True + +from urllib.parse import urlparse + +from getpass import getuser + +from psycopg import OperationalError, InterfaceError +from psycopg.conninfo import make_conninfo, conninfo_to_dict + +from collections import namedtuple + +try: + import sshtunnel + + SSH_TUNNEL_SUPPORT = True +except ImportError: + SSH_TUNNEL_SUPPORT = False + + +# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output +COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") +DEFAULT_MAX_FIELD_WIDTH = 500 + +# Query tuples are used for maintaining history +MetaQuery = namedtuple( + "Query", + [ + "query", # The entire text of the command + "successful", # True If all subqueries were successful + "total_time", # Time elapsed executing the query and formatting results + "execution_time", # Time elapsed executing the query + "meta_changed", # True if any subquery executed create/alter/drop + "db_changed", # True if any subquery changed the database + "path_changed", # True if any subquery changed the search path + "mutated", # True if any subquery executed insert/update/delete + "is_special", # True if the query is a special command + ], +) +MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False) + +OutputSettings = namedtuple( + "OutputSettings", + "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output max_field_width", +) +OutputSettings.__new__.__defaults__ = ( + None, + None, + None, + "", + False, + None, + lambda x: x, + None, + DEFAULT_MAX_FIELD_WIDTH, +) + + +class PgCliQuitError(Exception): + pass + + +class PGCli: + default_prompt = "\\u@\\h:\\d> " + max_len_prompt = 30 + + def set_default_pager(self, config): + configured_pager = config["main"].get("pager") + os_environ_pager = os.environ.get("PAGER") + + if configured_pager: + self.logger.info( + 'Default pager found in config file: "%s"', configured_pager + ) + os.environ["PAGER"] = configured_pager + elif os_environ_pager: + self.logger.info( + 'Default pager found in PAGER environment variable: "%s"', + os_environ_pager, + ) + os.environ["PAGER"] = os_environ_pager + else: + self.logger.info( + "No default pager found in environment. Using os default pager" + ) + + # Set default set of less recommended options, if they are not already set. + # They are ignored if pager is different than less. + if not os.environ.get("LESS"): + os.environ["LESS"] = "-SRXF" + + def __init__( + self, + force_passwd_prompt=False, + never_passwd_prompt=False, + pgexecute=None, + pgclirc_file=None, + row_limit=None, + single_connection=False, + less_chatty=None, + prompt=None, + prompt_dsn=None, + auto_vertical_output=False, + warn=None, + ssh_tunnel_url: Optional[str] = None, + ): + self.force_passwd_prompt = force_passwd_prompt + self.never_passwd_prompt = never_passwd_prompt + self.pgexecute = pgexecute + self.dsn_alias = None + self.watch_command = None + + # Load config. + c = self.config = get_config(pgclirc_file) + + # at this point, config should be written to pgclirc_file if it did not exist. Read it. + self.config_writer = load_config(get_config_filename(pgclirc_file)) + + # make sure to use self.config_writer, not self.config + NamedQueries.instance = NamedQueries.from_config(self.config_writer) + + self.logger = logging.getLogger(__name__) + self.initialize_logging() + + self.set_default_pager(c) + self.output_file = None + self.pgspecial = PGSpecial() + + self.explain_mode = False + self.multi_line = c["main"].as_bool("multi_line") + self.multiline_mode = c["main"].get("multi_line_mode", "psql") + self.vi_mode = c["main"].as_bool("vi") + self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand") + self.auto_retry_closed_connection = c["main"].as_bool( + "auto_retry_closed_connection" + ) + self.expanded_output = c["main"].as_bool("expand") + self.pgspecial.timing_enabled = c["main"].as_bool("timing") + if row_limit is not None: + self.row_limit = row_limit + else: + self.row_limit = c["main"].as_int("row_limit") + + # if not specified, set to DEFAULT_MAX_FIELD_WIDTH + # if specified but empty, set to None to disable truncation + # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0 + max_field_width = c["main"].get("max_field_width", DEFAULT_MAX_FIELD_WIDTH) + if max_field_width and max_field_width.lower() != "none": + max_field_width = max(3, abs(int(max_field_width))) + else: + max_field_width = None + self.max_field_width = max_field_width + + self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines") + self.multiline_continuation_char = c["main"]["multiline_continuation_char"] + self.table_format = c["main"]["table_format"] + self.syntax_style = c["main"]["syntax_style"] + self.cli_style = c["colors"] + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + self.destructive_warning = parse_destructive_warning( + warn or c["main"].as_list("destructive_warning") + ) + self.destructive_warning_restarts_connection = c["main"].as_bool( + "destructive_warning_restarts_connection" + ) + self.destructive_statements_require_transaction = c["main"].as_bool( + "destructive_statements_require_transaction" + ) + + self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") + self.null_string = c["main"].get("null_string", "") + self.prompt_format = ( + prompt + if prompt is not None + else c["main"].get("prompt", self.default_prompt) + ) + self.prompt_dsn_format = prompt_dsn + self.on_error = c["main"]["on_error"].upper() + self.decimal_format = c["data_formats"]["decimal"] + self.float_format = c["data_formats"]["float"] + auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger) + self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar") + + self.pgspecial.pset_pager( + self.config["main"].as_bool("enable_pager") and "on" or "off" + ) + + self.style_output = style_factory_output(self.syntax_style, c["colors"]) + + self.now = dt.datetime.today() + + self.completion_refresher = CompletionRefresher() + + self.query_history = [] + + # Initialize completer + smart_completion = c["main"].as_bool("smart_completion") + keyword_casing = c["main"]["keyword_casing"] + single_connection = single_connection or c["main"].as_bool( + "always_use_single_connection" + ) + self.settings = { + "casing_file": get_casing_file(c), + "generate_casing_file": c["main"].as_bool("generate_casing_file"), + "generate_aliases": c["main"].as_bool("generate_aliases"), + "asterisk_column_order": c["main"]["asterisk_column_order"], + "qualify_columns": c["main"]["qualify_columns"], + "case_column_headers": c["main"].as_bool("case_column_headers"), + "search_path_filter": c["main"].as_bool("search_path_filter"), + "single_connection": single_connection, + "less_chatty": less_chatty, + "keyword_casing": keyword_casing, + "alias_map_file": c["main"]["alias_map_file"] or None, + } + + completer = PGCompleter( + smart_completion, pgspecial=self.pgspecial, settings=self.settings + ) + self.completer = completer + self._completer_lock = threading.Lock() + self.register_special_commands() + + self.prompt_app = None + + self.ssh_tunnel_config = c.get("ssh tunnels") + self.ssh_tunnel_url = ssh_tunnel_url + self.ssh_tunnel = None + + # formatter setup + self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) + register_new_formatter(self.formatter) + + def quit(self): + raise PgCliQuitError + + def register_special_commands(self): + self.pgspecial.register( + self.change_db, + "\\c", + "\\c[onnect] database_name", + "Change to a new database.", + aliases=("use", "\\connect", "USE"), + ) + + refresh_callback = lambda: self.refresh_completions(persist_priorities="all") + + self.pgspecial.register( + self.quit, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) + self.pgspecial.register( + self.quit, + "quit", + "quit", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=("exit",), + ) + self.pgspecial.register( + refresh_callback, + "\\#", + "\\#", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + refresh_callback, + "\\refresh", + "\\refresh", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + self.execute_from_file, "\\i", "\\i filename", "Execute commands from file." + ) + self.pgspecial.register( + self.write_to_file, + "\\o", + "\\o [filename]", + "Send all query results to file.", + ) + self.pgspecial.register( + self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" + ) + self.pgspecial.register( + self.change_table_format, + "\\T", + "\\T [format]", + "Change the table format used to output results", + ) + + self.pgspecial.register( + self.echo, + "\\echo", + "\\echo [string]", + "Echo a string to stdout", + ) + + self.pgspecial.register( + self.echo, + "\\qecho", + "\\qecho [string]", + "Echo a string to the query output channel.", + ) + + def echo(self, pattern, **_): + return [(None, None, None, pattern)] + + def change_table_format(self, pattern, **_): + try: + if pattern not in TabularOutputFormatter().supported_formats: + raise ValueError() + self.table_format = pattern + yield (None, None, None, f"Changed table format to {pattern}") + except ValueError: + msg = f"Table format {pattern} not recognized. Allowed formats:" + for table_type in TabularOutputFormatter().supported_formats: + msg += f"\n\t{table_type}" + msg += "\nCurrently set to: %s" % self.table_format + yield (None, None, None, msg) + + def info_connection(self, **_): + if self.pgexecute.host.startswith("/"): + host = 'socket "%s"' % self.pgexecute.host + else: + host = 'host "%s"' % self.pgexecute.host + + yield ( + None, + None, + None, + 'You are connected to database "%s" as user ' + '"%s" on %s at port "%s".' + % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port), + ) + + def change_db(self, pattern, **_): + if pattern: + # Get all the parameters in pattern, handling double quotes if any. + infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern) + # Now removing quotes. + list(map(lambda s: s.strip('"'), infos)) + + infos.extend([None] * (4 - len(infos))) + db, user, host, port = infos + try: + self.pgexecute.connect( + database=db, + user=user, + host=host, + port=port, + **self.pgexecute.extra_args, + ) + except OperationalError as e: + click.secho(str(e), err=True, fg="red") + click.echo("Previous connection kept") + else: + self.pgexecute.connect() + + yield ( + None, + None, + None, + 'You are now connected to database "%s" as ' + 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user), + ) + + def execute_from_file(self, pattern, **_): + if not pattern: + message = "\\i: missing required argument" + return [(None, None, None, message, "", False, True)] + try: + with open(os.path.expanduser(pattern), encoding="utf-8") as f: + query = f.read() + except OSError as e: + return [(None, None, None, str(e), "", False, True)] + + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(query, self.destructive_warning) + ): + message = "Destructive statements must be run within a transaction. Command execution stopped." + return [(None, None, None, message)] + destroy = confirm_destructive_query( + query, self.destructive_warning, self.dsn_alias + ) + if destroy is False: + message = "Wise choice. Command execution stopped." + return [(None, None, None, message)] + + on_error_resume = self.on_error == "RESUME" + return self.pgexecute.run( + query, + self.pgspecial, + on_error_resume=on_error_resume, + explain_mode=self.explain_mode, + ) + + def write_to_file(self, pattern, **_): + if not pattern: + self.output_file = None + message = "File output disabled" + return [(None, None, None, message, "", True, True)] + filename = os.path.abspath(os.path.expanduser(pattern)) + if not os.path.isfile(filename): + try: + open(filename, "w").close() + except OSError as e: + self.output_file = None + message = str(e) + "\nFile output disabled" + return [(None, None, None, message, "", False, True)] + self.output_file = filename + message = 'Writing to file "%s"' % self.output_file + return [(None, None, None, message, "", True, True)] + + def initialize_logging(self): + log_file = self.config["main"]["log_file"] + if log_file == "default": + log_file = config_location() + "log" + ensure_dir_exists(log_file) + log_level = self.config["main"]["log_level"] + + # Disable logging if value is NONE by switching to a no-op handler. + # Set log level to a high value so it doesn't even waste cycles getting called. + if log_level.upper() == "NONE": + handler = logging.NullHandler() + else: + handler = logging.FileHandler(os.path.expanduser(log_file)) + + level_map = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + "NONE": logging.CRITICAL, + } + + log_level = level_map[log_level.upper()] + + formatter = logging.Formatter( + "%(asctime)s (%(process)d/%(threadName)s) " + "%(name)s %(levelname)s - %(message)s" + ) + + handler.setFormatter(formatter) + + root_logger = logging.getLogger("pgcli") + root_logger.addHandler(handler) + root_logger.setLevel(log_level) + + root_logger.debug("Initializing pgcli logging.") + root_logger.debug("Log file %r.", log_file) + + pgspecial_logger = logging.getLogger("pgspecial") + pgspecial_logger.addHandler(handler) + pgspecial_logger.setLevel(log_level) + + def connect_dsn(self, dsn, **kwargs): + self.connect(dsn=dsn, **kwargs) + + def connect_service(self, service, user): + service_config, file = parse_service_info(service) + if service_config is None: + click.secho( + f"service '{service}' was not found in {file}", err=True, fg="red" + ) + exit(1) + self.connect( + database=service_config.get("dbname"), + host=service_config.get("host"), + user=user or service_config.get("user"), + port=service_config.get("port"), + passwd=service_config.get("password"), + ) + + def connect_uri(self, uri): + kwargs = conninfo_to_dict(uri) + remap = {"dbname": "database", "password": "passwd"} + kwargs = {remap.get(k, k): v for k, v in kwargs.items()} + self.connect(**kwargs) + + def connect( + self, database="", host="", user="", port="", passwd="", dsn="", **kwargs + ): + # Connect to the database. + + if not user: + user = getuser() + + if not database: + database = user + + kwargs.setdefault("application_name", "pgcli") + + # If password prompt is not forced but no password is provided, try + # getting it from environment variable. + if not self.force_passwd_prompt and not passwd: + passwd = os.environ.get("PGPASSWORD", "") + + # Prompt for a password immediately if requested via the -W flag. This + # avoids wasting time trying to connect to the database and catching a + # no-password exception. + # If we successfully parsed a password from a URI, there's no need to + # prompt for it, even with the -W flag + if self.force_passwd_prompt and not passwd: + passwd = click.prompt( + "Password for %s" % user, hide_input=True, show_default=False, type=str + ) + + key = f"{user}@{host}" + + if not passwd and auth.keyring: + passwd = auth.keyring_get_password(key) + + def should_ask_for_password(exc): + # Prompt for a password after 1st attempt to connect + # fails. Don't prompt if the -w flag is supplied + if self.never_passwd_prompt: + return False + error_msg = exc.args[0] + if "no password supplied" in error_msg: + return True + if "password authentication failed" in error_msg: + return True + return False + + if dsn: + parsed_dsn = conninfo_to_dict(dsn) + if "host" in parsed_dsn: + host = parsed_dsn["host"] + if "port" in parsed_dsn: + port = parsed_dsn["port"] + + if self.ssh_tunnel_config and not self.ssh_tunnel_url: + for db_host_regex, tunnel_url in self.ssh_tunnel_config.items(): + if re.search(db_host_regex, host): + self.ssh_tunnel_url = tunnel_url + break + + if self.ssh_tunnel_url: + # We add the protocol as urlparse doesn't find it by itself + if "://" not in self.ssh_tunnel_url: + self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}" + + tunnel_info = urlparse(self.ssh_tunnel_url) + params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (host, int(port or 5432)), + "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22), + "logger": self.logger, + } + if tunnel_info.username: + params["ssh_username"] = tunnel_info.username + if tunnel_info.password: + params["ssh_password"] = tunnel_info.password + + # Hack: sshtunnel adds a console handler to the logger, so we revert handlers. + # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged. + logger_handlers = self.logger.handlers.copy() + try: + self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params) + self.ssh_tunnel.start() + except Exception as e: + self.logger.handlers = logger_handlers + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + self.logger.handlers = logger_handlers + + atexit.register(self.ssh_tunnel.stop) + host = "127.0.0.1" + port = self.ssh_tunnel.local_bind_ports[0] + + if dsn: + dsn = make_conninfo(dsn, host=host, port=port) + + # Attempt to connect to the database. + # Note that passwd may be empty on the first attempt. If connection + # fails because of a missing or incorrect password, but we're allowed to + # prompt for a password (no -w flag), prompt for a passwd and try again. + try: + try: + pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) + except (OperationalError, InterfaceError) as e: + if should_ask_for_password(e): + passwd = click.prompt( + "Password for %s" % user, + hide_input=True, + show_default=False, + type=str, + ) + pgexecute = PGExecute( + database, user, passwd, host, port, dsn, **kwargs + ) + else: + raise e + if passwd and auth.keyring: + auth.keyring_set_password(key, passwd) + + except Exception as e: # Connecting to a database could fail. + self.logger.debug("Database connection failed: %r.", e) + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + + self.pgexecute = pgexecute + + def handle_editor_command(self, text): + r""" + Editor command is any query that is prefixed or suffixed + by a '\e'. The reason for a while loop is because a user + might edit a query multiple times. + For eg: + "select * from \e" to edit it in vim, then come + back to the prompt with the edited query "select * from + blah where q = 'abc'\e" to edit it again. + :param text: Document + :return: Document + """ + editor_command = special.editor_command(text) + while editor_command: + if editor_command == "\\e": + filename = special.get_filename(text) + query = special.get_editor_query(text) or self.get_last_query() + else: # \ev or \ef + filename = None + spec = text.split()[1] + if editor_command == "\\ev": + query = self.pgexecute.view_definition(spec) + elif editor_command == "\\ef": + query = self.pgexecute.function_definition(spec) + sql, message = special.open_external_editor(filename, sql=query) + if message: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(message) + while True: + try: + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + + editor_command = special.editor_command(text) + return text + + def execute_command(self, text, handle_closed_connection=True): + logger = self.logger + + query = MetaQuery(query=text, successful=False) + + try: + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(text, self.destructive_warning) + ): + click.secho( + "Destructive statements must be run within a transaction." + ) + raise KeyboardInterrupt + destroy = confirm_destructive_query( + text, self.destructive_warning, self.dsn_alias + ) + if destroy is False: + click.secho("Wise choice!") + raise KeyboardInterrupt + elif destroy: + click.secho("Your call!") + + output, query = self._evaluate_command(text) + except KeyboardInterrupt: + if self.destructive_warning_restarts_connection: + # Restart connection to the database + self.pgexecute.connect() + logger.debug("cancelled query and restarted connection, sql: %r", text) + click.secho( + "cancelled query and restarted connection", err=True, fg="red" + ) + else: + logger.debug("cancelled query, sql: %r", text) + click.secho("cancelled query", err=True, fg="red") + except NotImplementedError: + click.secho("Not Yet Implemented.", fg="yellow") + except OperationalError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + if handle_closed_connection: + self._handle_server_closed_connection(text) + except (PgCliQuitError, EOFError): + raise + except Exception as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + else: + try: + if self.output_file and not text.startswith( + ("\\o ", "\\? ", "\\echo ") + ): + try: + with open(self.output_file, "a", encoding="utf-8") as f: + click.echo(text, file=f) + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline + except OSError as e: + click.secho(str(e), err=True, fg="red") + else: + if output: + self.echo_via_pager("\n".join(output)) + except KeyboardInterrupt: + pass + + if self.pgspecial.timing_enabled: + # Only add humanized time display if > 1 second + if query.total_time > 1: + print( + "Time: %0.03fs (%s), executed in: %0.03fs (%s)" + % ( + query.total_time, + pendulum.Duration(seconds=query.total_time).in_words(), + query.execution_time, + pendulum.Duration(seconds=query.execution_time).in_words(), + ) + ) + else: + print("Time: %0.03fs" % query.total_time) + + # Check if we need to update completions, in order of most + # to least drastic changes + if query.db_changed: + with self._completer_lock: + self.completer.reset_completions() + self.refresh_completions(persist_priorities="keywords") + elif query.meta_changed: + self.refresh_completions(persist_priorities="all") + elif query.path_changed: + logger.debug("Refreshing search path") + with self._completer_lock: + self.completer.set_search_path(self.pgexecute.search_path()) + logger.debug("Search path: %r", self.completer.search_path) + return query + + def _check_ongoing_transaction_and_allow_quitting(self): + """Return whether we can really quit, possibly by asking the + user to confirm so if there is an ongoing transaction. + """ + if not self.pgexecute.valid_transaction(): + return True + while 1: + try: + choice = click.prompt( + "A transaction is ongoing. Choose `c` to COMMIT, `r` to ROLLBACK, `a` to abort exit.", + default="a", + ) + except click.Abort: + # Print newline if user aborts with `^C`, otherwise + # pgcli's prompt will be printed on the same line + # (just after the confirmation prompt). + click.echo(None, err=False) + choice = "a" + choice = choice.lower() + if choice == "a": + return False # do not quit + if choice == "c": + query = self.execute_command("commit") + return query.successful # quit only if query is successful + if choice == "r": + query = self.execute_command("rollback") + return query.successful # quit only if query is successful + + def run_cli(self): + logger = self.logger + + history_file = self.config["main"]["history_file"] + if history_file == "default": + history_file = config_location() + "history" + history = FileHistory(os.path.expanduser(history_file)) + self.refresh_completions(history=history, persist_priorities="none") + + self.prompt_app = self._build_cli(history) + + if not self.less_chatty: + print("Server: PostgreSQL", self.pgexecute.server_version) + print("Version:", __version__) + print("Home: http://pgcli.com") + + try: + while True: + try: + text = self.prompt_app.prompt() + except KeyboardInterrupt: + continue + except EOFError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise + + try: + text = self.handle_editor_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + continue + + try: + self.handle_watch_command(text) + except PgCliQuitError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise + + self.now = dt.datetime.today() + + # Allow PGCompleter to learn user's preferred keywords, etc. + with self._completer_lock: + self.completer.extend_query_history(text) + + except (PgCliQuitError, EOFError): + if not self.less_chatty: + print("Goodbye!") + + def handle_watch_command(self, text): + # Initialize default metaquery in case execution fails + self.watch_command, timing = special.get_watch_command(text) + + # If we run \watch without a command, apply it to the last query run. + if self.watch_command is not None and not self.watch_command.strip(): + try: + self.watch_command = self.query_history[-1].query + except IndexError: + click.secho( + "\\watch cannot be used with an empty query", err=True, fg="red" + ) + self.watch_command = None + + # If there's a command to \watch, run it in a loop. + if self.watch_command: + while self.watch_command: + try: + query = self.execute_command(self.watch_command) + click.echo(f"Waiting for {timing} seconds before repeating") + sleep(timing) + except KeyboardInterrupt: + self.watch_command = None + + # Otherwise, execute it as a regular command. + else: + query = self.execute_command(text) + + self.query_history.append(query) + + def _build_cli(self, history): + key_bindings = pgcli_bindings(self) + + def get_message(): + if self.dsn_alias and self.prompt_dsn_format is not None: + prompt_format = self.prompt_dsn_format + else: + prompt_format = self.prompt_format + + prompt = self.get_prompt(prompt_format) + + if ( + prompt_format == self.default_prompt + and len(prompt) > self.max_len_prompt + ): + prompt = self.get_prompt("\\d> ") + + prompt = prompt.replace("\\x1b", "\x1b") + return ANSI(prompt) + + def get_continuation(width, line_number, is_soft_wrap): + continuation = self.multiline_continuation_char * (width - 1) + " " + return [("class:continuation", continuation)] + + get_toolbar_tokens = create_toolbar_tokens_func(self) + + if self.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with self._completer_lock: + prompt_app = PromptSession( + lexer=PygmentsLexer(PostgresLexer), + reserve_space_for_menu=self.min_num_menu_lines, + message=get_message, + prompt_continuation=get_continuation, + bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None, + complete_style=complete_style, + input_processors=[ + # Highlight matching brackets while editing. + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars="[](){}"), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ), + # Render \t as 4 spaces instead of "^I" + TabsProcessor(char1=" ", char2=" "), + ], + auto_suggest=AutoSuggestFromHistory(), + tempfile_suffix=".sql", + # N.b. pgcli's multi-line mode controls submit-on-Enter (which + # overrides the default behaviour of prompt_toolkit) and is + # distinct from prompt_toolkit's multiline mode here, which + # controls layout/display of the prompt/buffer + multiline=True, + history=history, + completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)), + complete_while_typing=True, + style=style_factory(self.syntax_style, self.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS, + search_ignore_case=True, + ) + + return prompt_app + + def _should_limit_output(self, sql, cur): + """returns True if the output should be truncated, False otherwise.""" + if self.explain_mode: + return False + if not is_select(sql): + return False + + return ( + not self._has_limit(sql) + and self.row_limit != 0 + and cur + and cur.rowcount > self.row_limit + ) + + def _has_limit(self, sql): + if not sql: + return False + return "limit " in sql.lower() + + def _limit_output(self, cur): + limit = min(self.row_limit, cur.rowcount) + new_cur = itertools.islice(cur, limit) + new_status = "SELECT " + str(limit) + click.secho("The result was limited to %s rows" % limit, fg="red") + + return new_cur, new_status + + def _evaluate_command(self, text): + """Used to run a command entered by the user during CLI operation + (Puts the E in REPL) + + returns (results, MetaQuery) + """ + logger = self.logger + logger.debug("sql: %r", text) + + # set query to formatter in order to parse table name + self.formatter.query = text + all_success = True + meta_changed = False # CREATE, ALTER, DROP, etc + mutated = False # INSERT, DELETE, etc + db_changed = False + path_changed = False + output = [] + total = 0 + execution = 0 + + # Run the query. + start = time() + on_error_resume = self.on_error == "RESUME" + res = self.pgexecute.run( + text, + self.pgspecial, + exception_formatter, + on_error_resume, + explain_mode=self.explain_mode, + ) + + is_special = None + + for title, cur, headers, status, sql, success, is_special in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + + if self._should_limit_output(sql, cur): + cur, status = self._limit_output(cur) + + if self.pgspecial.auto_expand or self.auto_expand: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + expanded = self.pgspecial.expanded_output or self.expanded_output + settings = OutputSettings( + table_format=self.table_format, + dcmlfmt=self.decimal_format, + floatfmt=self.float_format, + missingval=self.null_string, + expanded=expanded, + max_width=max_width, + case_function=( + self.completer.case + if self.settings["case_column_headers"] + else lambda x: x + ), + style_output=self.style_output, + max_field_width=self.max_field_width, + ) + execution = time() - start + formatted = format_output( + title, cur, headers, status, settings, self.explain_mode + ) + + output.extend(formatted) + total = time() - start + + # Keep track of whether any of the queries are mutating or changing + # the database + if success: + mutated = mutated or is_mutating(status) + db_changed = db_changed or has_change_db_cmd(sql) + meta_changed = meta_changed or has_meta_cmd(sql) + path_changed = path_changed or has_change_path_cmd(sql) + else: + all_success = False + + meta_query = MetaQuery( + text, + all_success, + total, + execution, + meta_changed, + db_changed, + path_changed, + mutated, + is_special, + ) + + return output, meta_query + + def _handle_server_closed_connection(self, text): + """Used during CLI execution.""" + try: + click.secho("Reconnecting...", fg="green") + self.pgexecute.connect() + click.secho("Reconnected!", fg="green") + except OperationalError as e: + click.secho("Reconnect Failed", fg="red") + click.secho(str(e), err=True, fg="red") + else: + retry = self.auto_retry_closed_connection or confirm( + "Run the query from before reconnecting?" + ) + if retry: + click.secho("Running query...", fg="green") + # Don't get stuck in a retry loop + self.execute_command(text, handle_closed_connection=False) + + def refresh_completions(self, history=None, persist_priorities="all"): + """Refresh outdated completions + + :param history: A prompt_toolkit.history.FileHistory object. Used to + load keyword and identifier preferences + + :param persist_priorities: 'all' or 'keywords' + """ + + callback = functools.partial( + self._on_completions_refreshed, persist_priorities=persist_priorities + ) + return self.completion_refresher.refresh( + self.pgexecute, + self.pgspecial, + callback, + history=history, + settings=self.settings, + ) + + def _on_completions_refreshed(self, new_completer, persist_priorities): + self._swap_completer_objects(new_completer, persist_priorities) + + if self.prompt_app: + # After refreshing, redraw the CLI to clear the statusbar + # "Refreshing completions..." indicator + self.prompt_app.app.invalidate() + + def _swap_completer_objects(self, new_completer, persist_priorities): + """Swap the completer object with the newly created completer. + + persist_priorities is a string specifying how the old completer's + learned prioritizer should be transferred to the new completer. + + 'none' - The new prioritizer is left in a new/clean state + + 'all' - The new prioritizer is updated to exactly reflect + the old one + + 'keywords' - The new prioritizer is updated with old keyword + priorities, but not any other. + + """ + with self._completer_lock: + old_completer = self.completer + self.completer = new_completer + + if persist_priorities == "all": + # Just swap over the entire prioritizer + new_completer.prioritizer = old_completer.prioritizer + elif persist_priorities == "keywords": + # Swap over the entire prioritizer, but clear name priorities, + # leaving learned keyword priorities alone + new_completer.prioritizer = old_completer.prioritizer + new_completer.prioritizer.clear_names() + elif persist_priorities == "none": + # Leave the new prioritizer as is + pass + self.completer = new_completer + + def get_completions(self, text, cursor_positition): + with self._completer_lock: + return self.completer.get_completions( + Document(text=text, cursor_position=cursor_positition), None + ) + + def get_prompt(self, string): + # should be before replacing \\d + string = string.replace("\\dsn_alias", self.dsn_alias or "") + string = string.replace("\\t", self.now.strftime("%x %X")) + string = string.replace("\\u", self.pgexecute.user or "(none)") + string = string.replace("\\H", self.pgexecute.host or "(none)") + string = string.replace("\\h", self.pgexecute.short_host or "(none)") + string = string.replace("\\d", self.pgexecute.dbname or "(none)") + string = string.replace( + "\\p", + str(self.pgexecute.port) if self.pgexecute.port is not None else "5432", + ) + string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") + string = string.replace("\\#", "#" if self.pgexecute.superuser else ">") + string = string.replace("\\n", "\n") + return string + + def get_last_query(self): + """Get the last query executed or None.""" + return self.query_history[-1][0] if self.query_history else None + + def is_too_wide(self, line): + """Will this line be too wide to fit into terminal?""" + if not self.prompt_app: + return False + return ( + len(COLOR_CODE_REGEX.sub("", line)) + > self.prompt_app.output.get_size().columns + ) + + def is_too_tall(self, lines): + """Are there too many lines to fit into terminal?""" + if not self.prompt_app: + return False + return len(lines) >= (self.prompt_app.output.get_size().rows - 4) + + def echo_via_pager(self, text, color=None): + if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: + click.echo(text, color=color) + elif ( + self.pgspecial.pager_config == PAGER_LONG_OUTPUT + and self.table_format != "csv" + ): + lines = text.split("\n") + + # The last 4 lines are reserved for the pgcli menu and padding + if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines): + click.echo_via_pager(text, color=color) + else: + click.echo(text, color=color) + else: + click.echo_via_pager(text, color) + + +@click.command() +# Default host is '' so psycopg can default to either localhost or unix socket +@click.option( + "-h", + "--host", + default="", + envvar="PGHOST", + help="Host address of the postgres database.", +) +@click.option( + "-p", + "--port", + default=5432, + help="Port number at which the " "postgres instance is listening.", + envvar="PGPORT", + type=click.INT, +) +@click.option( + "-U", + "--username", + "username_opt", + help="Username to connect to the postgres database.", +) +@click.option( + "-u", "--user", "username_opt", help="Username to connect to the postgres database." +) +@click.option( + "-W", + "--password", + "prompt_passwd", + is_flag=True, + default=False, + help="Force password prompt.", +) +@click.option( + "-w", + "--no-password", + "never_prompt", + is_flag=True, + default=False, + help="Never prompt for password.", +) +@click.option( + "--single-connection", + "single_connection", + is_flag=True, + default=False, + help="Do not use a separate connection for completions.", +) +@click.option("-v", "--version", is_flag=True, help="Version of pgcli.") +@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.") +@click.option( + "--pgclirc", + default=config_location() + "config", + envvar="PGCLIRC", + help="Location of pgclirc file.", + type=click.Path(dir_okay=False), +) +@click.option( + "-D", + "--dsn", + default="", + envvar="DSN", + help="Use DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--list-dsn", + "list_dsn", + is_flag=True, + help="list of DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--row-limit", + default=None, + envvar="PGROWLIMIT", + type=click.INT, + help="Set threshold for row limit prompt. Use 0 to disable prompt.", +) +@click.option( + "--less-chatty", + "less_chatty", + is_flag=True, + default=False, + help="Skip intro on startup and goodbye on exit.", +) +@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").') +@click.option( + "--prompt-dsn", + help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").', +) +@click.option( + "-l", + "--list", + "list_databases", + is_flag=True, + help="list available databases, then exit.", +) +@click.option( + "--auto-vertical-output", + is_flag=True, + help="Automatically switch to vertical output mode if the result is wider than the terminal width.", +) +@click.option( + "--warn", + default=None, + help="Warn before running a destructive query.", +) +@click.option( + "--ssh-tunnel", + default=None, + help="Open an SSH tunnel to the given address and connect to the database from it.", +) +@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) +@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) +def cli( + dbname, + username_opt, + host, + port, + prompt_passwd, + never_prompt, + single_connection, + dbname_opt, + username, + version, + pgclirc, + dsn, + row_limit, + less_chatty, + prompt, + prompt_dsn, + list_databases, + auto_vertical_output, + list_dsn, + warn, + ssh_tunnel: str, +): + if version: + print("Version:", __version__) + sys.exit(0) + + config_dir = os.path.dirname(config_location()) + if not os.path.exists(config_dir): + os.makedirs(config_dir) + + # Migrate the config file from old location. + config_full_path = config_location() + "config" + if os.path.exists(os.path.expanduser("~/.pgclirc")): + if not os.path.exists(config_full_path): + shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path) + print("Config file (~/.pgclirc) moved to new location", config_full_path) + else: + print("Config file is now located at", config_full_path) + print( + "Please move the existing config file ~/.pgclirc to", + config_full_path, + ) + if list_dsn: + try: + cfg = load_config(pgclirc, config_full_path) + for alias in cfg["alias_dsn"]: + click.secho(alias + " : " + cfg["alias_dsn"][alias]) + sys.exit(0) + except Exception as err: + click.secho( + "Invalid DSNs found in the config file. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + + if ssh_tunnel and not SSH_TUNNEL_SUPPORT: + click.secho( + 'Cannot open SSH tunnel, "sshtunnel" package was not found. ' + "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.", + err=True, + fg="red", + ) + exit(1) + + pgcli = PGCli( + prompt_passwd, + never_prompt, + pgclirc_file=pgclirc, + row_limit=row_limit, + single_connection=single_connection, + less_chatty=less_chatty, + prompt=prompt, + prompt_dsn=prompt_dsn, + auto_vertical_output=auto_vertical_output, + warn=warn, + ssh_tunnel_url=ssh_tunnel, + ) + + # Choose which ever one has a valid value. + if dbname_opt and dbname: + # work as psql: when database is given as option and argument use the argument as user + username = dbname + database = dbname_opt or dbname or "" + user = username_opt or username + service = None + if database.startswith("service="): + service = database[8:] + elif os.getenv("PGSERVICE") is not None: + service = os.getenv("PGSERVICE") + # because option --list or -l are not supposed to have a db name + if list_databases: + database = "postgres" + + if dsn != "": + try: + cfg = load_config(pgclirc, config_full_path) + dsn_config = cfg["alias_dsn"][dsn] + except KeyError: + click.secho( + f"Could not find a DSN with alias {dsn}. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + except Exception: + click.secho( + "Invalid DSNs found in the config file. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + pgcli.connect_uri(dsn_config) + pgcli.dsn_alias = dsn + elif "://" in database: + pgcli.connect_uri(database) + elif "=" in database and service is None: + pgcli.connect_dsn(database, user=user) + elif service is not None: + pgcli.connect_service(service, user) + else: + pgcli.connect(database, host, user, port) + + if list_databases: + cur, headers, status = pgcli.pgexecute.full_databases() + + title = "List of databases" + settings = OutputSettings(table_format="ascii", missingval="") + formatted = format_output(title, cur, headers, status, settings) + pgcli.echo_via_pager("\n".join(formatted)) + + sys.exit(0) + + pgcli.logger.debug( + "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r", + database, + user, + host, + port, + ) + + if setproctitle: + obfuscate_process_password() + + pgcli.run_cli() + + +def obfuscate_process_password(): + process_title = setproctitle.getproctitle() + if "://" in process_title: + process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title) + elif "=" in process_title: + process_title = re.sub( + r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title + ) + + setproctitle.setproctitle(process_title) + + +def has_meta_cmd(query): + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop, commit or rollback.""" + try: + first_token = query.split()[0] + if first_token.lower() in ("alter", "create", "drop", "commit", "rollback"): + return True + except Exception: + return False + + return False + + +def has_change_db_cmd(query): + """Determines if the statement is a database switch such as 'use' or '\\c'""" + try: + first_token = query.split()[0] + if first_token.lower() in ("use", "\\c", "\\connect"): + return True + except Exception: + return False + + return False + + +def has_change_path_cmd(sql): + """Determines if the search_path should be refreshed by checking if the + sql has 'set search_path'.""" + return "set search_path" in sql.lower() + + +def is_mutating(status): + """Determines if the statement is mutating based on the status.""" + if not status: + return False + + mutating = {"insert", "update", "delete"} + return status.split(None, 1)[0].lower() in mutating + + +def is_select(status): + """Returns true if the first word in status is 'select'.""" + if not status: + return False + return status.split(None, 1)[0].lower() == "select" + + +def exception_formatter(e): + return click.style(str(e), fg="red") + + +def format_output(title, cur, headers, status, settings, explain_mode=False): + output = [] + expanded = settings.expanded or settings.table_format == "vertical" + table_format = "vertical" if settings.expanded else settings.table_format + max_width = settings.max_width + case_function = settings.case_function + if explain_mode: + formatter = ExplainOutputFormatter(max_width or 100) + else: + formatter = TabularOutputFormatter(format_name=table_format) + + def format_array(val): + if val is None: + return settings.missingval + if not isinstance(val, list): + return val + return "{" + ",".join(str(format_array(e)) for e in val) + "}" + + def format_arrays(data, headers, **_): + data = list(data) + for row in data: + row[:] = [ + format_array(val) if isinstance(val, list) else val for val in row + ] + + return data, headers + + def format_status(cur, status): + # redshift does not return rowcount as part of status. + # See https://github.com/dbcli/pgcli/issues/1320 + if cur and hasattr(cur, "rowcount") and cur.rowcount is not None: + if status and not status.endswith(str(cur.rowcount)): + status += " %s" % cur.rowcount + return status + + output_kwargs = { + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": settings.missingval, + "integer_format": settings.dcmlfmt, + "float_format": settings.floatfmt, + "preprocessors": (format_numbers, format_arrays), + "disable_numparse": True, + "preserve_whitespace": True, + "style": settings.style_output, + "max_field_width": settings.max_field_width, + } + if not settings.floatfmt: + output_kwargs["preprocessors"] = (align_decimals,) + + if table_format == "csv": + # The default CSV dialect is "excel" which is not handling newline values correctly + # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n' + # as the line terminator + # https://github.com/dbcli/pgcli/issues/1102 + dialect = "excel" if platform.system() == "Windows" else "unix" + output_kwargs["dialect"] = dialect + + if title: # Only print the title if it's not None. + output.append(title) + + if cur: + headers = [case_function(x) for x in headers] + if max_width is not None: + cur = list(cur) + column_types = None + if hasattr(cur, "description"): + column_types = [] + for d in cur.description: + col_type = cur.adapters.types.get(d.type_code) + type_name = col_type.name if col_type else None + if type_name in ("numeric", "float4", "float8"): + column_types.append(float) + if type_name in ("int2", "int4", "int8"): + column_types.append(int) + else: + column_types.append(str) + + formatted = formatter.format_output(cur, headers, **output_kwargs) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + first_line = next(formatted) + formatted = itertools.chain([first_line], formatted) + if ( + not explain_mode + and not expanded + and max_width + and len(strip_ansi(first_line)) > max_width + and headers + ): + formatted = formatter.format_output( + cur, + headers, + format_name="vertical", + column_types=column_types, + **output_kwargs, + ) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + + output = itertools.chain(output, formatted) + + # Only print the status if it's not None + if status: + output = itertools.chain(output, [format_status(cur, status)]) + + return output + + +def parse_service_info(service): + service = service or os.getenv("PGSERVICE") + service_file = os.getenv("PGSERVICEFILE") + if not service_file: + # try ~/.pg_service.conf (if that exists) + if platform.system() == "Windows": + service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf" + elif os.getenv("PGSYSCONFDIR"): + service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") + else: + service_file = os.path.expanduser("~/.pg_service.conf") + if not service or not os.path.exists(service_file): + # nothing to do + return None, service_file + with open(service_file, newline="") as f: + skipped_lines = skip_initial_comment(f) + try: + service_file_config = ConfigObj(f) + except ParseError as err: + err.line_number += skipped_lines + raise err + if service not in service_file_config: + return None, service_file + service_conf = service_file_config.get(service) + return service_conf, service_file + + +if __name__ == "__main__": + cli() diff --git a/pgcli/packages/__init__.py b/pgcli/packages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pgcli/packages/formatter/__init__.py b/pgcli/packages/formatter/__init__.py new file mode 100644 index 0000000..9bad579 --- /dev/null +++ b/pgcli/packages/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/pgcli/packages/formatter/sqlformatter.py b/pgcli/packages/formatter/sqlformatter.py new file mode 100644 index 0000000..5224eff --- /dev/null +++ b/pgcli/packages/formatter/sqlformatter.py @@ -0,0 +1,74 @@ +# coding=utf-8 + +from pgcli.packages.parseutils.tables import extract_tables + + +supported_formats = ( + "sql-insert", + "sql-update", + "sql-update-1", + "sql-update-2", +) + +preprocessors = () + + +def escape_for_sql_statement(value): + if value is None: + return "NULL" + + if isinstance(value, bytes): + return f"X'{value.hex()}'" + + return "'{}'".format(value) + + +def adapter(data, headers, table_format=None, **kwargs): + tables = extract_tables(formatter.query) + if len(tables) > 0: + table = tables[0] + if table[0]: + table_name = "{}.{}".format(*table[:2]) + else: + table_name = table[1] + else: + table_name = "DUAL" + if table_format == "sql-insert": + h = '", "'.join(headers) + yield 'INSERT INTO "{}" ("{}") VALUES'.format(table_name, h) + prefix = " " + for d in data: + values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d)) + yield "{}({})".format(prefix, values) + if prefix == " ": + prefix = ", " + yield ";" + if table_format.startswith("sql-update"): + s = table_format.split("-") + keys = 1 + if len(s) > 2: + keys = int(s[-1]) + for d in data: + yield 'UPDATE "{}" SET'.format(table_name) + prefix = " " + for i, v in enumerate(d[keys:], keys): + yield '{}"{}" = {}'.format( + prefix, headers[i], escape_for_sql_statement(v) + ) + if prefix == " ": + prefix = ", " + f = '"{}" = {}' + where = ( + f.format(headers[i], escape_for_sql_statement(d[i])) + for i in range(keys) + ) + yield "WHERE {};".format(" AND ".join(where)) + + +def register_new_formatter(TabularOutputFormatter): + global formatter + formatter = TabularOutputFormatter + for sql_format in supported_formats: + TabularOutputFormatter.register_new_formatter( + sql_format, adapter, preprocessors, {"table_format": sql_format} + ) diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py new file mode 100644 index 0000000..023e13b --- /dev/null +++ b/pgcli/packages/parseutils/__init__.py @@ -0,0 +1,58 @@ +import sqlparse + + +BASE_KEYWORDS = [ + "drop", + "shutdown", + "delete", + "truncate", + "alter", + "unconditional_update", +] +ALL_KEYWORDS = BASE_KEYWORDS + ["update"] + + +def query_starts_with(formatted_sql, prefixes): + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def query_is_unconditional_update(formatted_sql): + """Check if the query starts with UPDATE and contains no WHERE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" and "where" not in tokens + + +def is_destructive(queries, keywords): + """Returns if any of the queries in *queries* is destructive.""" + for query in sqlparse.split(queries): + if query: + formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() + if "unconditional_update" in keywords and query_is_unconditional_update( + formatted_sql + ): + return True + if query_starts_with(formatted_sql, keywords): + return True + return False + + +def parse_destructive_warning(warning_level): + """Converts a deprecated destructive warning option to a list of command keywords.""" + if not warning_level: + return [] + + if not isinstance(warning_level, list): + if "," in warning_level: + return warning_level.split(",") + warning_level = [warning_level] + + return { + "true": ALL_KEYWORDS, + "false": [], + "all": ALL_KEYWORDS, + "moderate": BASE_KEYWORDS, + "off": [], + "": [], + }.get(warning_level[0], warning_level) diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py new file mode 100644 index 0000000..e1f9088 --- /dev/null +++ b/pgcli/packages/parseutils/ctes.py @@ -0,0 +1,141 @@ +from sqlparse import parse +from sqlparse.tokens import Keyword, CTE, DML +from sqlparse.sql import Identifier, IdentifierList, Parenthesis +from collections import namedtuple +from .meta import TableMetadata, ColumnMetadata + + +# TableExpression is a namedtuple representing a CTE, used internally +# name: cte alias assigned in the query +# columns: list of column names +# start: index into the original string of the left parens starting the CTE +# stop: index into the original string of the right parens ending the CTE +TableExpression = namedtuple("TableExpression", "name columns start stop") + + +def isolate_query_ctes(full_text, text_before_cursor): + """Simplify a query by converting CTEs into table metadata objects""" + + if not full_text or not full_text.strip(): + return full_text, text_before_cursor, tuple() + + ctes, remainder = extract_ctes(full_text) + if not ctes: + return full_text, text_before_cursor, () + + current_position = len(text_before_cursor) + meta = [] + + for cte in ctes: + if cte.start < current_position < cte.stop: + # Currently editing a cte - treat its body as the current full_text + text_before_cursor = full_text[cte.start : current_position] + full_text = full_text[cte.start : cte.stop] + return full_text, text_before_cursor, meta + + # Append this cte to the list of available table metadata + cols = (ColumnMetadata(name, None, ()) for name in cte.columns) + meta.append(TableMetadata(cte.name, cols)) + + # Editing past the last cte (ie the main body of the query) + full_text = full_text[ctes[-1].stop :] + text_before_cursor = text_before_cursor[ctes[-1].stop : current_position] + + return full_text, text_before_cursor, tuple(meta) + + +def extract_ctes(sql): + """Extract constant table expresseions from a query + + Returns tuple (ctes, remainder_sql) + + ctes is a list of TableExpression namedtuples + remainder_sql is the text from the original query after the CTEs have + been stripped. + """ + + p = parse(sql)[0] + + # Make sure the first meaningful token is "WITH" which is necessary to + # define CTEs + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return [], sql + + # Get the next (meaningful) token, which should be the first CTE + idx, tok = p.token_next(idx) + if not tok: + return ([], "") + start_pos = token_start_pos(p.tokens, idx) + ctes = [] + + if isinstance(tok, IdentifierList): + # Multiple ctes + for t in tok.get_identifiers(): + cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t)) + cte = get_cte_from_token(t, start_pos + cte_start_offset) + if not cte: + continue + ctes.append(cte) + elif isinstance(tok, Identifier): + # A single CTE + cte = get_cte_from_token(tok, start_pos) + if cte: + ctes.append(cte) + + idx = p.token_index(tok) + 1 + + # Collapse everything after the ctes into a remainder query + remainder = "".join(str(tok) for tok in p.tokens[idx:]) + + return ctes, remainder + + +def get_cte_from_token(tok, pos0): + cte_name = tok.get_real_name() + if not cte_name: + return None + + # Find the start position of the opening parens enclosing the cte body + idx, parens = tok.token_next_by(Parenthesis) + if not parens: + return None + + start_pos = pos0 + token_start_pos(tok.tokens, idx) + cte_len = len(str(parens)) # includes parens + stop_pos = start_pos + cte_len + + column_names = extract_column_names(parens) + + return TableExpression(cte_name, column_names, start_pos, stop_pos) + + +def extract_column_names(parsed): + # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE + idx, tok = parsed.token_next_by(t=DML) + tok_val = tok and tok.value.lower() + + if tok_val in ("insert", "update", "delete"): + # Jump ahead to the RETURNING clause where the list of column names is + idx, tok = parsed.token_next_by(idx, (Keyword, "returning")) + elif not tok_val == "select": + # Must be invalid CTE + return () + + # The next token should be either a column name, or a list of column names + idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) + return tuple(t.get_name() for t in _identifiers(tok)) + + +def token_start_pos(tokens, idx): + return sum(len(str(t)) for t in tokens[:idx]) + + +def _identifiers(tok): + if isinstance(tok, IdentifierList): + for t in tok.get_identifiers(): + # NB: IdentifierList.get_identifiers() can return non-identifiers! + if isinstance(t, Identifier): + yield t + elif isinstance(tok, Identifier): + yield tok diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py new file mode 100644 index 0000000..333cab5 --- /dev/null +++ b/pgcli/packages/parseutils/meta.py @@ -0,0 +1,170 @@ +from collections import namedtuple + +_ColumnMetadata = namedtuple( + "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"] +) + + +def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False): + return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default) + + +ForeignKey = namedtuple( + "ForeignKey", + [ + "parentschema", + "parenttable", + "parentcolumn", + "childschema", + "childtable", + "childcolumn", + ], +) +TableMetadata = namedtuple("TableMetadata", "name columns") + + +def parse_defaults(defaults_string): + """Yields default values for a function, given the string provided by + pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)""" + if not defaults_string: + return + current = "" + in_quote = None + for char in defaults_string: + if current == "" and char == " ": + # Skip space after comma separating default expressions + continue + if char == '"' or char == "'": + if in_quote and char == in_quote: + # End quote + in_quote = None + elif not in_quote: + # Begin quote + in_quote = char + elif char == "," and not in_quote: + # End of expression + yield current + current = "" + continue + current += char + yield current + + +class FunctionMetadata: + def __init__( + self, + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, + ): + """Class for describing a postgresql function""" + + self.schema_name = schema_name + self.func_name = func_name + + self.arg_modes = tuple(arg_modes) if arg_modes else None + self.arg_names = tuple(arg_names) if arg_names else None + + # Be flexible in not requiring arg_types -- use None as a placeholder + # for each arg. (Used for compatibility with old versions of postgresql + # where such info is hard to get. + if arg_types: + self.arg_types = tuple(arg_types) + elif arg_modes: + self.arg_types = tuple([None] * len(arg_modes)) + elif arg_names: + self.arg_types = tuple([None] * len(arg_names)) + else: + self.arg_types = None + + self.arg_defaults = tuple(parse_defaults(arg_defaults)) + + self.return_type = return_type.strip() + self.is_aggregate = is_aggregate + self.is_window = is_window + self.is_set_returning = is_set_returning + self.is_extension = bool(is_extension) + self.is_public = self.schema_name and self.schema_name == "public" + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + def _signature(self): + return ( + self.schema_name, + self.func_name, + self.arg_names, + self.arg_types, + self.arg_modes, + self.return_type, + self.is_aggregate, + self.is_window, + self.is_set_returning, + self.is_extension, + self.arg_defaults, + ) + + def __hash__(self): + return hash(self._signature()) + + def __repr__(self): + return ( + "%s(schema_name=%r, func_name=%r, arg_names=%r, " + "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, " + "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)" + ) % ((self.__class__.__name__,) + self._signature()) + + def has_variadic(self): + return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes) + + def args(self): + """Returns a list of input-parameter ColumnMetadata namedtuples.""" + if not self.arg_names: + return [] + modes = self.arg_modes or ["i"] * len(self.arg_names) + args = [ + (name, typ) + for name, typ, mode in zip(self.arg_names, self.arg_types, modes) + if mode in ("i", "b", "v") # IN, INOUT, VARIADIC + ] + + def arg(name, typ, num): + num_args = len(args) + num_defaults = len(self.arg_defaults) + has_default = num + num_defaults >= num_args + default = ( + self.arg_defaults[num - num_args + num_defaults] + if has_default + else None + ) + return ColumnMetadata(name, typ, [], default, has_default) + + return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] + + def fields(self): + """Returns a list of output-field ColumnMetadata namedtuples""" + + if self.return_type.lower() == "void": + return [] + elif not self.arg_modes: + # For functions without output parameters, the function name + # is used as the name of the output column. + # E.g. 'SELECT unnest FROM unnest(...);' + return [ColumnMetadata(self.func_name, self.return_type, [])] + + return [ + ColumnMetadata(name, typ, []) + for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes) + if mode in ("o", "b", "t") + ] # OUT, INOUT, TABLE diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py new file mode 100644 index 0000000..9098115 --- /dev/null +++ b/pgcli/packages/parseutils/tables.py @@ -0,0 +1,165 @@ +import sqlparse +from collections import namedtuple +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +TableReference = namedtuple( + "TableReference", ["schema", "name", "alias", "is_function"] +) +TableReference.ref = property( + lambda self: self.alias + or ( + self.name + if self.name.islower() or self.name[0] == '"' + else '"' + self.name + '"' + ) +) + + +# This code is borrowed from sqlparse example script. +# +def is_subselect(parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ( + "SELECT", + "INSERT", + "UPDATE", + "CREATE", + "DELETE", + ): + return True + return False + + +def _identifier_is_function(identifier): + return any(isinstance(t, Function) for t in identifier.tokens) + + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + yield from extract_from_part(item, stop_at_punctuation) + elif stop_at_punctuation and item.ttype is Punctuation: + return + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # `return`. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif ( + item.ttype is Keyword + and (not item.value.upper() == "FROM") + and (not item.value.upper().endswith("JOIN")) + ): + tbl_prefix_seen = False + else: + yield item + elif item.ttype is Keyword or item.ttype is Keyword.DML: + item_val = item.value.upper() + if item_val in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + ) or item_val.endswith("JOIN"): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if identifier.ttype is Keyword and identifier.value.upper() == "FROM": + tbl_prefix_seen = True + break + + +def extract_table_identifiers(token_stream, allow_functions=True): + """yields tuples of TableReference namedtuples""" + + # We need to do some massaging of the names because postgres is case- + # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is) + def parse_identifier(item): + name = item.get_real_name() + schema_name = item.get_parent_name() + alias = item.get_alias() + if not name: + schema_name = None + name = item.get_name() + alias = alias or name + schema_quoted = schema_name and item.value[0] == '"' + if schema_name and not schema_quoted: + schema_name = schema_name.lower() + quote_count = item.value.count('"') + name_quoted = quote_count > 2 or (quote_count and not schema_quoted) + alias_quoted = alias and item.value[-1] == '"' + if alias_quoted or name_quoted and not alias and name.islower(): + alias = '"' + (alias or name) + '"' + if name and not name_quoted and not name.islower(): + if not alias: + alias = name + name = name.lower() + return schema_name, name, alias + + try: + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + is_function = allow_functions and _identifier_is_function( + identifier + ) + except AttributeError: + continue + if real_name: + yield TableReference( + schema_name, real_name, identifier.get_alias(), is_function + ) + elif isinstance(item, Identifier): + schema_name, real_name, alias = parse_identifier(item) + is_function = allow_functions and _identifier_is_function(item) + + yield TableReference(schema_name, real_name, alias, is_function) + elif isinstance(item, Function): + schema_name, real_name, alias = parse_identifier(item) + yield TableReference(None, real_name, alias, allow_functions) + except StopIteration: + return + + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql): + """Extract the table names from an SQL statement. + + Returns a list of TableReference namedtuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return () + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == "insert" + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + + # Kludge: sqlparse mistakenly identifies insert statements as + # function calls due to the parenthesized column list, e.g. interprets + # "insert into foo (bar, baz)" as a function call to foo with arguments + # (bar, baz). So don't allow any identifiers in insert statements + # to have is_function=True + identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt) + # In the case 'sche.', we get an empty TableReference; remove that + return tuple(i for i in identifiers if i.name) diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py new file mode 100644 index 0000000..034c96e --- /dev/null +++ b/pgcli/packages/parseutils/utils.py @@ -0,0 +1,140 @@ +import re +import sqlparse +from sqlparse.sql import Identifier +from sqlparse.tokens import Token, Error + +cleanup_regex = { + # This matches only alphanumerics and underscores. + "alphanum_underscore": re.compile(r"(\w+)$"), + # This matches everything except spaces, parens, colon, and comma + "many_punctuations": re.compile(r"([^():,\s]+)$"), + # This matches everything except spaces, parens, colon, comma, and period + "most_punctuations": re.compile(r"([^\.():,\s]+)$"), + # This matches everything except a space. + "all_punctuations": re.compile(r"([^\s]+)$"), +} + + +def last_word(text, include="alphanum_underscore"): + r""" + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + >>> last_word('"foo*bar', include='most_punctuations') + '"foo*bar' + """ + + if not text: # Empty string + return "" + + if text[-1].isspace(): + return "" + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return "" + + +def find_prev_keyword(sql, n_skip=0): + """Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, "" + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + flattened = flattened[: len(flattened) - n_skip] + + logical_operators = ("AND", "OR", "NOT", "BETWEEN") + + for t in reversed(flattened): + if t.value == "(" or ( + t.is_keyword and (t.value.upper() not in logical_operators) + ): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index throws an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = "".join(tok.value for tok in flattened[: idx + 1]) + return t, text + + return None, "" + + +# Postgresql dollar quote signs look like `$$` or `$tag$` +dollar_quote_regex = re.compile(r"^\$[^$]*\$$") + + +def is_open_quote(sql): + """Returns true if the query contains an unclosed quote""" + + # parsed can contain one or more semi-colon separated commands + parsed = sqlparse.parse(sql) + return any(_parsed_is_open_quote(p) for p in parsed) + + +def _parsed_is_open_quote(parsed): + # Look for unmatched single quotes, or unmatched dollar sign quotes + return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten()) + + +def parse_partial_identifier(word): + """Attempt to parse a (partially typed) word as an identifier + + word may include a schema qualification, like `schema_name.partial_name` + or `schema_name.` There may also be unclosed quotation marks, like + `"schema`, or `schema."partial_name` + + :param word: string representing a (partially complete) identifier + :return: sqlparse.sql.Identifier, or None + """ + + p = sqlparse.parse(word)[0] + n_tok = len(p.tokens) + if n_tok == 1 and isinstance(p.tokens[0], Identifier): + return p.tokens[0] + elif p.token_next_by(m=(Error, '"'))[1]: + # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar' + # Close the double quote, then reparse + return parse_partial_identifier(word + '"') + else: + return None diff --git a/pgcli/packages/pgliterals/__init__.py b/pgcli/packages/pgliterals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pgcli/packages/pgliterals/main.py b/pgcli/packages/pgliterals/main.py new file mode 100644 index 0000000..5c39296 --- /dev/null +++ b/pgcli/packages/pgliterals/main.py @@ -0,0 +1,15 @@ +import os +import json + +root = os.path.dirname(__file__) +literal_file = os.path.join(root, "pgliterals.json") + +with open(literal_file) as f: + literals = json.load(f) + + +def get_literals(literal_type, type_=tuple): + # Where `literal_type` is one of 'keywords', 'functions', 'datatypes', + # returns a tuple of literal values of that type. + + return type_(literals[literal_type]) diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json new file mode 100644 index 0000000..df00817 --- /dev/null +++ b/pgcli/packages/pgliterals/pgliterals.json @@ -0,0 +1,630 @@ +{ + "keywords": { + "ACCESS": [], + "ADD": [], + "ALL": [], + "ALTER": [ + "AGGREGATE", + "COLLATION", + "COLUMN", + "CONVERSION", + "DATABASE", + "DEFAULT", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN", + "FUNCTION", + "GROUP", + "INDEX", + "LANGUAGE", + "LARGE OBJECT", + "MATERIALIZED VIEW", + "OPERATOR", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "SYSTEM", + "TABLE", + "TABLESPACE", + "TEXT SEARCH", + "TRIGGER", + "TYPE", + "USER", + "VIEW" + ], + "AND": [], + "ANY": [], + "AS": [], + "ASC": [], + "AUDIT": [], + "BEGIN": [], + "BETWEEN": [], + "BY": [], + "CASE": [], + "CHAR": [], + "CHECK": [], + "CLUSTER": [], + "COLUMN": [], + "COMMENT": [], + "COMMIT": [], + "COMPRESS": [], + "CONCURRENTLY": [], + "CONNECT": [], + "COPY": [], + "CREATE": [ + "ACCESS METHOD", + "AGGREGATE", + "CAST", + "COLLATION", + "CONVERSION", + "DATABASE", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN DATA WRAPPER", + "FOREIGN EXTENSION", + "FUNCTION", + "GLOBAL", + "GROUP", + "IF NOT EXISTS", + "INDEX", + "LANGUAGE", + "LOCAL", + "MATERIALIZED VIEW", + "OPERATOR", + "OR REPLACE", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "TABLE", + "TABLESPACE", + "TEMPORARY", + "TEXT SEARCH", + "TRIGGER", + "TYPE", + "UNIQUE", + "UNLOGGED", + "USER", + "USER MAPPING", + "VIEW" + ], + "CURRENT": [], + "DATABASE": [], + "DATE": [], + "DECIMAL": [], + "DEFAULT": [], + "DELETE FROM": [], + "DELIMITER": [], + "DESC": [], + "DESCRIBE": [], + "DISTINCT": [], + "DROP": [ + "ACCESS METHOD", + "AGGREGATE", + "CAST", + "COLLATION", + "COLUMN", + "CONVERSION", + "DATABASE", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN DATA WRAPPER", + "FOREIGN TABLE", + "FUNCTION", + "GROUP", + "INDEX", + "LANGUAGE", + "MATERIALIZED VIEW", + "OPERATOR", + "OWNED", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "TABLE", + "TABLESPACE", + "TEXT SEARCH", + "TRANSFORM", + "TRIGGER", + "TYPE", + "USER", + "USER MAPPING", + "VIEW" + ], + "EXPLAIN": [], + "ELSE": [], + "ENCODING": [], + "ESCAPE": [], + "EXCLUSIVE": [], + "EXISTS": [], + "EXTENSION": [], + "FILE": [], + "FLOAT": [], + "FOR": [], + "FORMAT": [], + "FORCE_QUOTE": [], + "FORCE_NOT_NULL": [], + "FREEZE": [], + "FROM": [], + "FULL": [], + "FUNCTION": [], + "GRANT": [], + "GROUP BY": [], + "HAVING": [], + "HEADER": [], + "IDENTIFIED": [], + "IMMEDIATE": [], + "IN": [], + "INCREMENT": [], + "INDEX": [], + "INITIAL": [], + "INSERT INTO": [], + "INTEGER": [], + "INTERSECT": [], + "INTERVAL": [], + "INTO": [], + "IS": [], + "JOIN": [], + "LANGUAGE": [], + "LEFT": [], + "LEVEL": [], + "LIKE": [], + "LIMIT": [], + "LOCK": [], + "LONG": [], + "MATERIALIZED VIEW": [], + "MAXEXTENTS": [], + "MINUS": [], + "MLSLABEL": [], + "MODE": [], + "MODIFY": [], + "NOT": [], + "NOAUDIT": [], + "NOTICE": [], + "NOCOMPRESS": [], + "NOWAIT": [], + "NULL": [], + "NUMBER": [], + "OIDS": [], + "OF": [], + "OFFLINE": [], + "ON": [], + "ONLINE": [], + "OPTION": [], + "OR": [], + "ORDER BY": [], + "OUTER": [], + "OWNER": [], + "PCTFREE": [], + "PRIMARY": [], + "PRIOR": [], + "PRIVILEGES": [], + "QUOTE": [], + "RAISE": [], + "RENAME": [], + "REPLACE": [], + "RESET": ["ALL"], + "RAW": [], + "REFRESH MATERIALIZED VIEW": [], + "RESOURCE": [], + "RETURNS": [], + "REVOKE": [], + "RIGHT": [], + "ROLLBACK": [], + "ROW": [], + "ROWID": [], + "ROWNUM": [], + "ROWS": [], + "SELECT": [], + "SESSION": [], + "SET": [], + "SHARE": [], + "SHOW": [], + "SIZE": [], + "SMALLINT": [], + "START": [], + "SUCCESSFUL": [], + "SYNONYM": [], + "SYSDATE": [], + "TABLE": [], + "TEMPLATE": [], + "THEN": [], + "TO": [], + "TRIGGER": [], + "TRUNCATE": [], + "UID": [], + "UNION": [], + "UNIQUE": [], + "UPDATE": [], + "USE": [], + "USER": [], + "USING": [], + "VALIDATE": [], + "VALUES": [], + "VARCHAR": [], + "VARCHAR2": [], + "VIEW": [], + "WHEN": [], + "WHENEVER": [], + "WHERE": [], + "WITH": [] + }, + "functions": [ + "ABBREV", + "ABS", + "AGE", + "AREA", + "ARRAY_AGG", + "ARRAY_APPEND", + "ARRAY_CAT", + "ARRAY_DIMS", + "ARRAY_FILL", + "ARRAY_LENGTH", + "ARRAY_LOWER", + "ARRAY_NDIMS", + "ARRAY_POSITION", + "ARRAY_POSITIONS", + "ARRAY_PREPEND", + "ARRAY_REMOVE", + "ARRAY_REPLACE", + "ARRAY_TO_STRING", + "ARRAY_UPPER", + "ASCII", + "AVG", + "BIT_AND", + "BIT_LENGTH", + "BIT_OR", + "BOOL_AND", + "BOOL_OR", + "BOUND_BOX", + "BOX", + "BROADCAST", + "BTRIM", + "CARDINALITY", + "CBRT", + "CEIL", + "CEILING", + "CENTER", + "CHAR_LENGTH", + "CHR", + "CIRCLE", + "CLOCK_TIMESTAMP", + "CONCAT", + "CONCAT_WS", + "CONVERT", + "CONVERT_FROM", + "CONVERT_TO", + "COUNT", + "CUME_DIST", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATE_PART", + "DATE_TRUNC", + "DECODE", + "DEGREES", + "DENSE_RANK", + "DIAMETER", + "DIV", + "ENCODE", + "ENUM_FIRST", + "ENUM_LAST", + "ENUM_RANGE", + "EVERY", + "EXP", + "EXTRACT", + "FAMILY", + "FIRST_VALUE", + "FLOOR", + "FORMAT", + "GET_BIT", + "GET_BYTE", + "HEIGHT", + "HOST", + "HOSTMASK", + "INET_MERGE", + "INET_SAME_FAMILY", + "INITCAP", + "ISCLOSED", + "ISFINITE", + "ISOPEN", + "JUSTIFY_DAYS", + "JUSTIFY_HOURS", + "JUSTIFY_INTERVAL", + "LAG", + "LAST_VALUE", + "LEAD", + "LEFT", + "LENGTH", + "LINE", + "LN", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOG", + "LOG10", + "LOWER", + "LPAD", + "LSEG", + "LTRIM", + "MAKE_DATE", + "MAKE_INTERVAL", + "MAKE_TIME", + "MAKE_TIMESTAMP", + "MAKE_TIMESTAMPTZ", + "MASKLEN", + "MAX", + "MD5", + "MIN", + "MOD", + "NETMASK", + "NETWORK", + "NOW", + "NPOINTS", + "NTH_VALUE", + "NTILE", + "NUM_NONNULLS", + "NUM_NULLS", + "OCTET_LENGTH", + "OVERLAY", + "PARSE_IDENT", + "PATH", + "PCLOSE", + "PERCENT_RANK", + "PG_CLIENT_ENCODING", + "PI", + "POINT", + "POLYGON", + "POPEN", + "POSITION", + "POWER", + "QUOTE_IDENT", + "QUOTE_LITERAL", + "QUOTE_NULLABLE", + "RADIANS", + "RADIUS", + "RANDOM", + "RANK", + "REGEXP_MATCH", + "REGEXP_MATCHES", + "REGEXP_REPLACE", + "REGEXP_SPLIT_TO_ARRAY", + "REGEXP_SPLIT_TO_TABLE", + "REPEAT", + "REPLACE", + "REVERSE", + "RIGHT", + "ROUND", + "ROW_NUMBER", + "RPAD", + "RTRIM", + "SCALE", + "SET_BIT", + "SET_BYTE", + "SET_MASKLEN", + "SHA224", + "SHA256", + "SHA384", + "SHA512", + "SIGN", + "SPLIT_PART", + "SQRT", + "STARTS_WITH", + "STATEMENT_TIMESTAMP", + "STRING_TO_ARRAY", + "STRPOS", + "SUBSTR", + "SUBSTRING", + "SUM", + "TEXT", + "TIMEOFDAY", + "TO_ASCII", + "TO_CHAR", + "TO_DATE", + "TO_HEX", + "TO_NUMBER", + "TO_TIMESTAMP", + "TRANSACTION_TIMESTAMP", + "TRANSLATE", + "TRIM", + "TRUNC", + "UNNEST", + "UPPER", + "WIDTH", + "WIDTH_BUCKET", + "XMLAGG" + ], + "datatypes": [ + "ANY", + "ANYARRAY", + "ANYELEMENT", + "ANYENUM", + "ANYNONARRAY", + "ANYRANGE", + "BIGINT", + "BIGSERIAL", + "BIT", + "BIT VARYING", + "BOOL", + "BOOLEAN", + "BOX", + "BYTEA", + "CHAR", + "CHARACTER", + "CHARACTER VARYING", + "CIDR", + "CIRCLE", + "CSTRING", + "DATE", + "DECIMAL", + "DOUBLE PRECISION", + "EVENT_TRIGGER", + "FDW_HANDLER", + "FLOAT4", + "FLOAT8", + "INET", + "INT", + "INT2", + "INT4", + "INT8", + "INTEGER", + "INTERNAL", + "INTERVAL", + "JSON", + "JSONB", + "LANGUAGE_HANDLER", + "LINE", + "LSEG", + "MACADDR", + "MACADDR8", + "MONEY", + "NUMERIC", + "OID", + "OPAQUE", + "PATH", + "PG_LSN", + "POINT", + "POLYGON", + "REAL", + "RECORD", + "REGCLASS", + "REGCONFIG", + "REGDICTIONARY", + "REGNAMESPACE", + "REGOPER", + "REGOPERATOR", + "REGPROC", + "REGPROCEDURE", + "REGROLE", + "REGTYPE", + "SERIAL", + "SERIAL2", + "SERIAL4", + "SERIAL8", + "SMALLINT", + "SMALLSERIAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TRIGGER", + "TSQUERY", + "TSVECTOR", + "TXID_SNAPSHOT", + "UUID", + "VARBIT", + "VARCHAR", + "VOID", + "XML" + ], + "reserved": [ + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASYMMETRIC", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONSTRAINT", + "CREATE", + "CURRENT_CATALOG", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + "END", + "EXCEPT", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FROM", + "GRANT", + "GROUP", + "HAVING", + "IN", + "INITIALLY", + "INTERSECT", + "INTO", + "LATERAL", + "LEADING", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NOT", + "NULL", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "SELECT", + "SESSION_USER", + "SOME", + "SYMMETRIC", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VARIADIC", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "AUTHORIZATION", + "BINARY", + "COLLATION", + "CONCURRENTLY", + "CROSS", + "CURRENT_SCHEMA", + "FREEZE", + "FULL", + "ILIKE", + "INNER", + "IS", + "ISNULL", + "JOIN", + "LEFT", + "LIKE", + "NATURAL", + "NOTNULL", + "OUTER", + "OVERLAPS", + "RIGHT", + "SIMILAR", + "TABLESAMPLE", + "VERBOSE" + ] +} diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py new file mode 100644 index 0000000..f5a9cb5 --- /dev/null +++ b/pgcli/packages/prioritization.py @@ -0,0 +1,51 @@ +import re +import sqlparse +from sqlparse.tokens import Name +from collections import defaultdict +from .pgliterals.main import get_literals + + +white_space_regex = re.compile("\\s+", re.MULTILINE) + + +def _compile_regex(keyword): + # Surround the keyword with word boundaries and replace interior whitespace + # with whitespace wildcards + pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b" + return re.compile(pattern, re.MULTILINE | re.IGNORECASE) + + +keywords = get_literals("keywords") +keyword_regexs = {kw: _compile_regex(kw) for kw in keywords} + + +class PrevalenceCounter: + def __init__(self): + self.keyword_counts = defaultdict(int) + self.name_counts = defaultdict(int) + + def update(self, text): + self.update_keywords(text) + self.update_names(text) + + def update_names(self, text): + for parsed in sqlparse.parse(text): + for token in parsed.flatten(): + if token.ttype in Name: + self.name_counts[token.value] += 1 + + def clear_names(self): + self.name_counts = defaultdict(int) + + def update_keywords(self, text): + # Count keywords. Can't rely for sqlparse for this, because it's + # database agnostic + for keyword, regex in keyword_regexs.items(): + for _ in regex.finditer(text): + self.keyword_counts[keyword] += 1 + + def keyword_count(self, keyword): + return self.keyword_counts[keyword] + + def name_count(self, name): + return self.name_counts[name] diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py new file mode 100644 index 0000000..997b86e --- /dev/null +++ b/pgcli/packages/prompt_utils.py @@ -0,0 +1,37 @@ +import sys +import click +from .parseutils import is_destructive + + +def confirm_destructive_query(queries, keywords, alias): + """Check if the query is destructive and prompts the user to confirm. + + Returns: + * None if the query is non-destructive or we can't prompt the user. + * True if the query is destructive and the user wants to proceed. + * False if the query is destructive and the user doesn't want to proceed. + + """ + info = "You're about to run a destructive command" + if alias: + info += f" in {click.style(alias, fg='red')}" + + prompt_text = f"{info}.\nDo you want to proceed?" + if is_destructive(queries, keywords) and sys.stdin.isatty(): + return confirm(prompt_text) + + +def confirm(*args, **kwargs): + """Prompt for confirmation (yes/no) and handle any abort exceptions.""" + try: + return click.confirm(*args, **kwargs) + except click.Abort: + return False + + +def prompt(*args, **kwargs): + """Prompt the user for input and handle any abort exceptions.""" + try: + return click.prompt(*args, **kwargs) + except click.Abort: + return False diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py new file mode 100644 index 0000000..b78edd6 --- /dev/null +++ b/pgcli/packages/sqlcompletion.py @@ -0,0 +1,605 @@ +import sys +import re +import sqlparse +from collections import namedtuple +from sqlparse.sql import Comparison, Identifier, Where +from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier +from .parseutils.tables import extract_tables +from .parseutils.ctes import isolate_query_ctes +from pgspecial.main import parse_special_command + + +Special = namedtuple("Special", []) +Database = namedtuple("Database", []) +Schema = namedtuple("Schema", ["quoted"]) +Schema.__new__.__defaults__ = (False,) +# FromClauseItem is a table/view/function used in the FROM clause +# `table_refs` contains the list of tables/... already in the statement, +# used to ensure that the alias we suggest is unique +FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables") +Table = namedtuple("Table", ["schema", "table_refs", "local_tables"]) +TableFormat = namedtuple("TableFormat", []) +View = namedtuple("View", ["schema", "table_refs"]) +# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid' +JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"]) +# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid' +Join = namedtuple("Join", ["table_refs", "schema"]) + +Function = namedtuple("Function", ["schema", "table_refs", "usage"]) +# For convenience, don't require the `usage` argument in Function constructor +Function.__new__.__defaults__ = (None, tuple(), None) +Table.__new__.__defaults__ = (None, tuple(), tuple()) +View.__new__.__defaults__ = (None, tuple()) +FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) + +Column = namedtuple( + "Column", + ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"], +) +Column.__new__.__defaults__ = (None, None, tuple(), False, None) + +Keyword = namedtuple("Keyword", ["last_token"]) +Keyword.__new__.__defaults__ = (None,) +NamedQuery = namedtuple("NamedQuery", []) +Datatype = namedtuple("Datatype", ["schema"]) +Alias = namedtuple("Alias", ["aliases"]) + +Path = namedtuple("Path", []) + + +class SqlStatement: + def __init__(self, full_text, text_before_cursor): + self.identifier = None + self.word_before_cursor = word_before_cursor = last_word( + text_before_cursor, include="many_punctuations" + ) + full_text = _strip_named_query(full_text) + text_before_cursor = _strip_named_query(text_before_cursor) + + full_text, text_before_cursor, self.local_tables = isolate_query_ctes( + full_text, text_before_cursor + ) + + self.text_before_cursor_including_last_word = text_before_cursor + + # If we've partially typed a word then word_before_cursor won't be an + # empty string. In that case we want to remove the partially typed + # string before sending it to the sqlparser. Otherwise the last token + # will always be the partially typed string which renders the smart + # completion useless because it will always return the list of + # keywords as completion. + if self.word_before_cursor: + if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\": + parsed = sqlparse.parse(text_before_cursor) + else: + text_before_cursor = text_before_cursor[: -len(word_before_cursor)] + parsed = sqlparse.parse(text_before_cursor) + self.identifier = parse_partial_identifier(word_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor) + + full_text, text_before_cursor, parsed = _split_multiple_statements( + full_text, text_before_cursor, parsed + ) + + self.full_text = full_text + self.text_before_cursor = text_before_cursor + self.parsed = parsed + + self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or "" + + def is_insert(self): + return self.parsed.token_first().value.lower() == "insert" + + def get_tables(self, scope="full"): + """Gets the tables available in the statement. + param `scope:` possible values: 'full', 'insert', 'before' + If 'insert', only the first table is returned. + If 'before', only tables before the cursor are returned. + If not 'insert' and the stmt is an insert, the first table is skipped. + """ + tables = extract_tables( + self.full_text if scope == "full" else self.text_before_cursor + ) + if scope == "insert": + tables = tables[:1] + elif self.is_insert(): + tables = tables[1:] + return tables + + def get_previous_token(self, token): + return self.parsed.token_prev(self.parsed.token_index(token))[1] + + def get_identifier_schema(self): + schema = (self.identifier and self.identifier.get_parent_name()) or None + # If schema name is unquoted, lower-case it + if schema and self.identifier.value[0] != '"': + schema = schema.lower() + + return schema + + def reduce_to_prev_keyword(self, n_skip=0): + prev_keyword, self.text_before_cursor = find_prev_keyword( + self.text_before_cursor, n_skip=n_skip + ) + return prev_keyword + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + if full_text.startswith("\\i "): + return (Path(),) + + # This is a temporary hack; the exception handling + # here should be removed once sqlparse has been fixed + try: + stmt = SqlStatement(full_text, text_before_cursor) + except (TypeError, AttributeError): + return [] + + # Check for special commands and handle those separately + if stmt.parsed: + # Be careful here because trivial whitespace is parsed as a + # statement, but the statement won't have a first token + tok1 = stmt.parsed.token_first() + if tok1 and tok1.value.startswith("\\"): + text = stmt.text_before_cursor + stmt.word_before_cursor + return suggest_special(text) + + return suggest_based_on_last_token(stmt.last_token, stmt) + + +named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+") + + +def _strip_named_query(txt): + """ + This will strip "save named query" command in the beginning of the line: + '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + """ + + if named_query_regex.match(txt): + txt = named_query_regex.sub("", txt) + return txt + + +function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M) + + +def _find_function_body(text): + split = function_body_pattern.search(text) + return (split.start(2), split.end(2)) if split else (None, None) + + +def _statement_from_function(full_text, text_before_cursor, statement): + current_pos = len(text_before_cursor) + body_start, body_end = _find_function_body(full_text) + if body_start is None: + return full_text, text_before_cursor, statement + if not body_start <= current_pos < body_end: + return full_text, text_before_cursor, statement + full_text = full_text[body_start:body_end] + text_before_cursor = text_before_cursor[body_start:] + parsed = sqlparse.parse(text_before_cursor) + return _split_multiple_statements(full_text, text_before_cursor, parsed) + + +def _split_multiple_statements(full_text, text_before_cursor, parsed): + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds + # the current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(str(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + return full_text, text_before_cursor, None + + token2 = None + if statement.get_type() in ("CREATE", "CREATE OR REPLACE"): + token1 = statement.token_first() + if token1: + token1_idx = statement.token_index(token1) + token2 = statement.token_next(token1_idx)[1] + if token2 and token2.value.upper() == "FUNCTION": + full_text, text_before_cursor, statement = _statement_from_function( + full_text, text_before_cursor, statement + ) + return full_text, text_before_cursor, statement + + +SPECIALS_SUGGESTION = { + "dT": Datatype, + "df": Function, + "dt": Table, + "dv": View, + "sf": Function, +} + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return (Special(),) + + if cmd in ("\\c", "\\connect"): + return (Database(),) + + if cmd == "\\T": + return (TableFormat(),) + + if cmd == "\\dn": + return (Schema(),) + + if arg: + # Try to distinguish "\d name" from "\d schema.name" + # Note that this will fail to obtain a schema name if wildcards are + # used, e.g. "\d schema???.name" + parsed = sqlparse.parse(arg)[0].tokens[0] + try: + schema = parsed.get_parent_name() + except AttributeError: + schema = None + else: + schema = None + + if cmd[1:] == "d": + # \d can describe tables or views + if schema: + return (Table(schema=schema), View(schema=schema)) + else: + return (Schema(), Table(schema=None), View(schema=None)) + elif cmd[1:] in SPECIALS_SUGGESTION: + rel_type = SPECIALS_SUGGESTION[cmd[1:]] + if schema: + if rel_type == Function: + return (Function(schema=schema, usage="special"),) + return (rel_type(schema=schema),) + else: + if rel_type == Function: + return (Schema(), Function(schema=None, usage="special")) + return (Schema(), rel_type(schema=None)) + + if cmd in ["\\n", "\\ns", "\\nd"]: + return (NamedQuery(),) + + return (Keyword(), Special()) + + +def suggest_based_on_last_token(token, stmt): + if isinstance(token, str): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword = stmt.reduce_to_prev_keyword() + return suggest_based_on_last_token(prev_keyword, stmt) + elif isinstance(token, Identifier): + # If the previous token is an identifier, we can suggest datatypes if + # we're in a parenthesized column/field list, e.g.: + # CREATE TABLE foo (Identifier + # CREATE FUNCTION foo (Identifier + # If we're not in a parenthesized list, the most likely scenario is the + # user is about to specify an alias, e.g.: + # SELECT Identifier + # SELECT foo FROM Identifier + prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor) + if prev_keyword and prev_keyword.value == "(": + # Suggest datatypes + return suggest_based_on_last_token("type", stmt) + else: + return (Keyword(),) + else: + token_v = token.value.lower() + + if not token: + return (Keyword(), Special()) + elif token_v.endswith("("): + p = sqlparse.parse(stmt.text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = suggest_based_on_last_token("where", stmt) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + prev_tok = where.token_prev(len(where.tokens) - 1)[1] + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == "exists": + return (Keyword(),) + else: + return column_suggestions + + # Get the token before the parens + prev_tok = p.token_prev(len(p.tokens) - 1)[1] + + if ( + prev_tok + and prev_tok.value + and prev_tok.value.lower().split(" ")[-1] == "using" + ): + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = stmt.get_tables("before") + + # suggest columns that are present in more than one table + return ( + Column( + table_refs=tables, + require_last_table=True, + local_tables=stmt.local_tables, + ), + ) + + elif p.token_first().value.lower() == "select": + # If the lparen is preceded by a space chances are we're about to + # do a sub-select. + if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("): + return (Keyword(),) + prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1] + if prev_prev_tok and prev_prev_tok.normalized == "INTO": + return (Column(table_refs=stmt.get_tables("insert"), context="insert"),) + # We're probably in a function argument list + return _suggest_expression(token_v, stmt) + elif token_v == "set": + return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),) + elif token_v in ("select", "where", "having", "order by", "distinct"): + return _suggest_expression(token_v, stmt) + elif token_v == "as": + # Don't suggest anything for aliases + return () + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v in ("copy", "from", "update", "into", "describe", "truncate") + ): + schema = stmt.get_identifier_schema() + tables = extract_tables(stmt.text_before_cursor) + is_join = token_v.endswith("join") and token.is_keyword + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + if token_v == "from" or is_join: + suggest.append( + FromClauseItem( + schema=schema, table_refs=tables, local_tables=stmt.local_tables + ) + ) + elif token_v == "truncate": + suggest.append(Table(schema)) + else: + suggest.extend((Table(schema), View(schema))) + + if is_join and _allow_join(stmt.parsed): + tables = stmt.get_tables("before") + suggest.append(Join(table_refs=tables, schema=schema)) + + return tuple(suggest) + + elif token_v == "function": + schema = stmt.get_identifier_schema() + + # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:` + try: + prev = stmt.get_previous_token(token).value.lower() + if prev in ("drop", "alter", "create", "create or replace"): + # Suggest functions from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + suggest.append(Function(schema=schema, usage="signature")) + return tuple(suggest) + + except ValueError: + pass + return tuple() + + elif token_v in ("table", "view"): + # E.g. 'ALTER TABLE ' + rel_type = {"table": Table, "view": View, "function": Function}[token_v] + schema = stmt.get_identifier_schema() + if schema: + return (rel_type(schema=schema),) + else: + return (Schema(), rel_type(schema=schema)) + + elif token_v == "column": + # E.g. 'ALTER TABLE foo ALTER COLUMN bar + return (Column(table_refs=stmt.get_tables()),) + + elif token_v == "on": + tables = stmt.get_tables("before") + parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None + if parent: + # "ON parent." + # parent can be either a schema name or table alias + filteredtables = tuple(t for t in tables if identifies(parent, t)) + sugs = [ + Column(table_refs=filteredtables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ] + if filteredtables and _allow_join_condition(stmt.parsed): + sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1])) + return tuple(sugs) + else: + # ON + # Use table alias if there is one, otherwise the table name + aliases = tuple(t.ref for t in tables) + if _allow_join_condition(stmt.parsed): + return ( + Alias(aliases=aliases), + JoinCondition(table_refs=tables, parent=None), + ) + else: + return (Alias(aliases=aliases),) + + elif token_v in ("c", "use", "database", "template"): + # "\c ", "DROP DATABASE ", + # "CREATE DATABASE WITH TEMPLATE " + return (Database(),) + elif token_v == "schema": + # DROP SCHEMA schema_name, SET SCHEMA schema name + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) + quoted = prev_keyword and prev_keyword.value.lower() == "set" + return (Schema(quoted),) + elif token_v.endswith(",") or token_v in ("=", "and", "or"): + prev_keyword = stmt.reduce_to_prev_keyword() + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return () + elif token_v in ("type", "::"): + # ALTER TABLE foo SET DATA TYPE bar + # SELECT foo::bar + # Note that tables are a form of composite type in postgresql, so + # they're suggested here as well + schema = stmt.get_identifier_schema() + suggestions = [Datatype(schema=schema), Table(schema=schema)] + if not schema: + suggestions.append(Schema()) + return tuple(suggestions) + elif token_v in {"alter", "create", "drop"}: + return (Keyword(token_v.upper()),) + elif token.is_keyword: + # token is a keyword we haven't implemented any special handling for + # go backwards in the query until we find one we do recognize + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1) + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return (Keyword(token_v.upper()),) + else: + return (Keyword(),) + + +def _suggest_expression(token_v, stmt): + """ + Return suggestions for an expression, taking account of any partially-typed + identifier's parent, which may be a table alias or schema name. + """ + parent = stmt.identifier.get_parent_name() if stmt.identifier else [] + tables = stmt.get_tables() + + if parent: + tables = tuple(t for t in tables if identifies(parent, t)) + return ( + Column(table_refs=tables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ) + + return ( + Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True), + Function(schema=None), + Keyword(token_v.upper()), + ) + + +def identifies(id, ref): + """Returns true if string `id` matches TableReference `ref`""" + + return ( + id == ref.alias + or id == ref.name + or (ref.schema and (id == ref.schema + "." + ref.name)) + ) + + +def _allow_join_condition(statement): + """ + Tests if a join condition should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b on a.id = + So check that the preceding token is a ON, AND, or OR keyword, instead of + e.g. an equals sign. + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower() in ("on", "and", "or") + + +def _allow_join(statement): + """ + Tests if a join should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b + So check that the preceding token is a JOIN keyword + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in ( + "cross join", + "natural join", + ) diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py new file mode 100644 index 0000000..c236c13 --- /dev/null +++ b/pgcli/pgbuffer.py @@ -0,0 +1,61 @@ +import logging + +from prompt_toolkit.enums import DEFAULT_BUFFER +from prompt_toolkit.filters import Condition +from prompt_toolkit.application import get_app +from .packages.parseutils.utils import is_open_quote + +_logger = logging.getLogger(__name__) + + +def _is_complete(sql): + # A complete command is an sql statement that ends with a semicolon, unless + # there's an open quote surrounding it, as is common when writing a + # CREATE FUNCTION command + return sql.endswith(";") and not is_open_quote(sql) + + +""" +Returns True if the buffer contents should be handled (i.e. the query/command +executed) immediately. This is necessary as we use prompt_toolkit in multiline +mode, which by default will insert new lines on Enter. +""" + + +def safe_multi_line_mode(pgcli): + @Condition + def cond(): + _logger.debug( + 'Multi-line mode state: "%s" / "%s"', pgcli.multi_line, pgcli.multiline_mode + ) + return pgcli.multi_line and (pgcli.multiline_mode == "safe") + + return cond + + +def buffer_should_be_handled(pgcli): + @Condition + def cond(): + if not pgcli.multi_line: + _logger.debug("Not in multi-line mode. Handle the buffer.") + return True + + if pgcli.multiline_mode == "safe": + _logger.debug("Multi-line mode is set to 'safe'. Do NOT handle the buffer.") + return False + + doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document + text = doc.text.strip() + + return ( + text.startswith("\\") # Special Command + or text.endswith(r"\e") # Special Command + or text.endswith(r"\G") # Ended with \e which should launch the editor + or _is_complete(text) # A complete SQL command + or (text == "exit") # Exit doesn't need semi-colon + or (text == "quit") # Quit doesn't need semi-colon + or (text == ":q") # To all the vim fans out there + or (text == "") # Just a plain enter without any text + ) + + return cond diff --git a/pgcli/pgclirc b/pgcli/pgclirc new file mode 100644 index 0000000..51f7eae --- /dev/null +++ b/pgcli/pgclirc @@ -0,0 +1,236 @@ +# vi: ft=dosini +[main] + +# Enables context sensitive auto-completion. If this is disabled, all +# possible completions will be listed. +smart_completion = True + +# Display the completions in several columns. (More completions will be +# visible.) +wider_completion_menu = False + +# Do not create new connections for refreshing completions; Equivalent to +# always running with the --single-connection flag. +always_use_single_connection = False + +# Multi-line mode allows breaking up the sql statements into multiple lines. If +# this is set to True, then the end of the statements must have a semi-colon. +# If this is set to False then sql statements can't be split into multiple +# lines. End of line (return) is considered as the end of the statement. +multi_line = False + +# If multi_line_mode is set to "psql", in multi-line mode, [Enter] will execute +# the current input if the input ends in a semicolon. +# If multi_line_mode is set to "safe", in multi-line mode, [Enter] will always +# insert a newline, and [Esc] [Enter] or [Alt]-[Enter] must be used to execute +# a command. +multi_line_mode = psql + +# Destructive warning will alert you before executing a sql statement +# that may cause harm to the database such as "drop table", "drop database", +# "shutdown", "delete", or "update". +# You can pass a list of destructive commands or leave it empty if you want to skip all warnings. +# "unconditional_update" will warn you of update statements that don't have a where clause +destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update + +# Destructive warning can restart the connection if this is enabled and the +# user declines. This means that any current uncommitted transaction can be +# aborted if the user doesn't want to proceed with a destructive_warning +# statement. +destructive_warning_restarts_connection = False + +# When this option is on (and if `destructive_warning` is not empty), +# destructive statements are not executed when outside of a transaction. +destructive_statements_require_transaction = False + +# Enables expand mode, which is similar to `\x` in psql. +expand = False + +# Enables auto expand mode, which is similar to `\x auto` in psql. +auto_expand = False + +# Auto-retry queries on connection failures and other operational errors. If +# False, will prompt to rerun the failed query instead of auto-retrying. +auto_retry_closed_connection = True + +# If set to True, table suggestions will include a table alias +generate_aliases = False + +# Path to a json file that specifies specific table aliases to use when generate_aliases is set to True +# the format for this file should be: +# { +# "some_table_name": "desired_alias", +# "some_other_table_name": "another_alias" +# } +alias_map_file = + +# log_file location. +# In Unix/Linux: ~/.config/pgcli/log +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log +# %USERPROFILE% is typically C:\Users\{username} +log_file = default + +# keyword casing preference. Possible values: "lower", "upper", "auto" +keyword_casing = auto + +# casing_file location. +# In Unix/Linux: ~/.config/pgcli/casing +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\casing +# %USERPROFILE% is typically C:\Users\{username} +casing_file = default + +# If generate_casing_file is set to True and there is no file in the above +# location, one will be generated based on usage in SQL/PLPGSQL functions. +generate_casing_file = False + +# Casing of column headers based on the casing_file described above +case_column_headers = True + +# history_file location. +# In Unix/Linux: ~/.config/pgcli/history +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\history +# %USERPROFILE% is typically C:\Users\{username} +history_file = default + +# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" +# and "DEBUG". "NONE" disables logging. +log_level = INFO + +# Order of columns when expanding * to column list +# Possible values: "table_order" and "alphabetic" +asterisk_column_order = table_order + +# Whether to qualify with table alias/name when suggesting columns +# Possible values: "always", "never" and "if_more_than_one_table" +qualify_columns = if_more_than_one_table + +# When no schema is entered, only suggest objects in search_path +search_path_filter = False + +# Default pager. See https://www.pgcli.com/pager for more information on settings. +# By default 'PAGER' environment variable is used. If the pager is less, and the 'LESS' +# environment variable is not set, then LESS='-SRXF' will be automatically set. +# pager = less + +# Timing of sql statements and table rendering. +timing = True + +# Show/hide the informational toolbar with function keymap at the footer. +show_bottom_toolbar = True + +# Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe, +# ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs, +# textile, moinmoin, jira, vertical, tsv, csv, sql-insert, sql-update, +# sql-update-1, sql-update-2 (formatter with sql-* prefix can format query +# output to executable insertion or updating sql). +# Recommended: psql, fancy_grid and grid. +table_format = psql + +# Syntax Style. Possible values: manni, igor, xcode, vim, autumn, vs, rrt, +# native, perldoc, borland, tango, emacs, friendly, monokai, paraiso-dark, +# colorful, murphy, bw, pastie, paraiso-light, trac, default, fruity +syntax_style = default + +# Keybindings: +# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL. +# When Vi mode is disabled emacs keybindings such as Ctrl-A for home and Ctrl-E +# for end are available in the REPL. +vi = False + +# Error handling +# When one of multiple SQL statements causes an error, choose to either +# continue executing the remaining statements, or stopping +# Possible values "STOP" or "RESUME" +on_error = STOP + +# Set threshold for row limit. Use 0 to disable limiting. +row_limit = 1000 + +# Truncate long text fields to this value for tabular display (does not apply to csv). +# Leave unset to disable truncation. Example: "max_field_width = " +# Be aware that formatting might get slow with values larger than 500 and tables with +# lots of records. +max_field_width = 500 + +# Skip intro on startup and goodbye on exit +less_chatty = False + +# Postgres prompt +# \t - Current date and time +# \u - Username +# \h - Short hostname of the server (up to first '.') +# \H - Hostname of the server +# \d - Database name +# \p - Database port +# \i - Postgres PID +# \# - "@" sign if logged in as superuser, '>' in other case +# \n - Newline +# \dsn_alias - name of dsn connection string alias if -D option is used (empty otherwise) +# \x1b[...m - insert ANSI escape sequence +# eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>' +prompt = '\u@\h:\d> ' + +# Number of lines to reserve for the suggestion menu +min_num_menu_lines = 4 + +# Character used to left pad multi-line queries to match the prompt size. +multiline_continuation_char = '' + +# The string used in place of a null value. +null_string = '' + +# manage pager on startup +enable_pager = True + +# Use keyring to automatically save and load password in a secure manner +keyring = True + +# Custom colors for the completion menu, toolbar, etc. +[colors] +completion-menu.completion.current = 'bg:#ffffff #000000' +completion-menu.completion = 'bg:#008888 #ffffff' +completion-menu.meta.completion.current = 'bg:#44aaaa #000000' +completion-menu.meta.completion = 'bg:#448888 #ffffff' +completion-menu.multi-column-meta = 'bg:#aaffff #000000' +scrollbar.arrow = 'bg:#003333' +scrollbar = 'bg:#00aaaa' +selected = '#ffffff bg:#6666aa' +search = '#ffffff bg:#4444aa' +search.current = '#ffffff bg:#44aa44' +bottom-toolbar = 'bg:#222222 #aaaaaa' +bottom-toolbar.off = 'bg:#222222 #888888' +bottom-toolbar.on = 'bg:#222222 #ffffff' +search-toolbar = 'noinherit bold' +search-toolbar.text = 'nobold' +system-toolbar = 'noinherit bold' +arg-toolbar = 'noinherit bold' +arg-toolbar.text = 'nobold' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' +# These three values can be used to further refine the syntax highlighting. +# They are commented out by default, since they have priority over the theme set +# with the `syntax_style` setting and overriding its behavior can be confusing. +# literal.string = '#ba2121' +# literal.number = '#666666' +# keyword = 'bold #008000' + +# style classes for colored table output +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" +output.null = "#808080" + +# Named queries are queries you can execute by name. +[named queries] + +# Here's where you can provide a list of connection string aliases. +# You can use it by passing the -D option. `pgcli -D example_dsn` +[alias_dsn] +# example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname] + +# Format for number representation +# for decimal "d" - 12345678, ",d" - 12,345,678 +# for float "g" - 123456.78, ",g" - 123,456.78 +[data_formats] +decimal = "" +float = "" diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py new file mode 100644 index 0000000..17fc540 --- /dev/null +++ b/pgcli/pgcompleter.py @@ -0,0 +1,1072 @@ +import json +import logging +import re +from itertools import count, repeat, chain +import operator +from collections import namedtuple, defaultdict, OrderedDict +from cli_helpers.tabular_output import TabularOutputFormatter +from pgspecial.namedqueries import NamedQueries +from prompt_toolkit.completion import Completer, Completion, PathCompleter +from prompt_toolkit.document import Document +from .packages.sqlcompletion import ( + FromClauseItem, + suggest_type, + Special, + Database, + Schema, + Table, + TableFormat, + Function, + Column, + View, + Keyword, + NamedQuery, + Datatype, + Alias, + Path, + JoinCondition, + Join, +) +from .packages.parseutils.meta import ColumnMetadata, ForeignKey +from .packages.parseutils.utils import last_word +from .packages.parseutils.tables import TableReference +from .packages.pgliterals.main import get_literals +from .packages.prioritization import PrevalenceCounter +from .config import load_config, config_location + +_logger = logging.getLogger(__name__) + +Match = namedtuple("Match", ["completion", "priority"]) + +_SchemaObject = namedtuple("SchemaObject", "name schema meta") + + +def SchemaObject(name, schema=None, meta=None): + return _SchemaObject(name, schema, meta) + + +_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display") + + +def Candidate( + completion, prio=None, meta=None, synonyms=None, prio2=None, display=None +): + return _Candidate( + completion, prio, meta, synonyms or [completion], prio2, display or completion + ) + + +# Used to strip trailing '::some_type' from default-value expressions +arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$") + +normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' + + +def generate_alias(tbl, alias_map=None): + """Generate a table alias, consisting of all upper-case letters in + the table name, or, if there are no upper-case letters, the first letter + + all letters preceded by _ + param tbl - unescaped name of the table to alias + """ + if alias_map and tbl in alias_map: + return alias_map[tbl] + return "".join( + [l for l in tbl if l.isupper()] + or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"] + ) + + +class InvalidMapFile(ValueError): + pass + + +def load_alias_map_file(path): + try: + with open(path) as fo: + alias_map = json.load(fo) + except FileNotFoundError as err: + raise InvalidMapFile( + f"Cannot read alias_map_file - {err.filename} does not exist" + ) + except json.JSONDecodeError: + raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json") + else: + return alias_map + + +class PGCompleter(Completer): + # keywords_tree: A dict mapping keywords to well known following keywords. + # e.g. 'CREATE': ['TABLE', 'USER', ...], + keywords_tree = get_literals("keywords", type_=dict) + keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values()))) + functions = get_literals("functions") + datatypes = get_literals("datatypes") + reserved_words = set(get_literals("reserved")) + + def __init__(self, smart_completion=True, pgspecial=None, settings=None): + super().__init__() + self.smart_completion = smart_completion + self.pgspecial = pgspecial + self.prioritizer = PrevalenceCounter() + settings = settings or {} + self.signature_arg_style = settings.get( + "signature_arg_style", "{arg_name} {arg_type}" + ) + self.call_arg_style = settings.get( + "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}" + ) + self.call_arg_display_style = settings.get( + "call_arg_display_style", "{arg_name}" + ) + self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2) + self.search_path_filter = settings.get("search_path_filter") + self.generate_aliases = settings.get("generate_aliases") + alias_map_file = settings.get("alias_map_file") + if alias_map_file is not None: + self.alias_map = load_alias_map_file(alias_map_file) + else: + self.alias_map = None + self.casing_file = settings.get("casing_file") + self.insert_col_skip_patterns = [ + re.compile(pattern) + for pattern in settings.get( + "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("] + ) + ] + self.generate_casing_file = settings.get("generate_casing_file") + self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table") + self.asterisk_column_order = settings.get( + "asterisk_column_order", "table_order" + ) + + keyword_casing = settings.get("keyword_casing", "upper").lower() + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "upper" + self.keyword_casing = keyword_casing + self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$") + + self.databases = [] + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} + self.search_path = [] + self.casing = {} + + self.all_completions = set(self.keywords + self.functions) + + def escape_name(self, name): + if name and ( + (not self.name_pattern.match(name)) + or (name.upper() in self.reserved_words) + or (name.upper() in self.functions) + ): + name = '"%s"' % name + + return name + + def escape_schema(self, name): + return "'{}'".format(self.unescape_name(name)) + + def unescape_name(self, name): + """Unquote a string.""" + if name and name[0] == '"' and name[-1] == '"': + name = name[1:-1] + + return name + + def escaped_names(self, names): + return [self.escape_name(name) for name in names] + + def extend_database_names(self, databases): + self.databases.extend(databases) + + def extend_keywords(self, additional_keywords): + self.keywords.extend(additional_keywords) + self.all_completions.update(additional_keywords) + + def extend_schemata(self, schemata): + # schemata is a list of schema names + schemata = self.escaped_names(schemata) + metadata = self.dbmetadata["tables"] + for schema in schemata: + metadata[schema] = {} + + # dbmetadata.values() are the 'tables' and 'functions' dicts + for metadata in self.dbmetadata.values(): + for schema in schemata: + metadata[schema] = {} + + self.all_completions.update(schemata) + + def extend_casing(self, words): + """extend casing data + + :return: + """ + # casing should be a dict {lowercasename:PreferredCasingName} + self.casing = {word.lower(): word for word in words} + + def extend_relations(self, data, kind): + """extend metadata for tables or views. + + :param data: list of (schema_name, rel_name) tuples + :param kind: either 'tables' or 'views' + + :return: + + """ + + data = [self.escaped_names(d) for d in data] + + # dbmetadata['tables']['schema_name']['table_name'] should be an + # OrderedDict {column_name:ColumnMetaData}. + metadata = self.dbmetadata[kind] + for schema, relname in data: + try: + metadata[schema][relname] = OrderedDict() + except KeyError: + _logger.error( + "%r %r listed in unrecognized schema %r", kind, relname, schema + ) + self.all_completions.add(relname) + + def extend_columns(self, column_data, kind): + """extend column metadata. + + :param column_data: list of (schema_name, rel_name, column_name, + column_type, has_default, default) tuples + :param kind: either 'tables' or 'views' + + :return: + + """ + metadata = self.dbmetadata[kind] + for schema, relname, colname, datatype, has_default, default in column_data: + (schema, relname, colname) = self.escaped_names([schema, relname, colname]) + column = ColumnMetadata( + name=colname, + datatype=datatype, + has_default=has_default, + default=default, + ) + metadata[schema][relname][colname] = column + self.all_completions.add(colname) + + def extend_functions(self, func_data): + # func_data is a list of function metadata namedtuples + + # dbmetadata['schema_name']['functions']['function_name'] should return + # the function metadata namedtuple for the corresponding function + metadata = self.dbmetadata["functions"] + + for f in func_data: + schema, func = self.escaped_names([f.schema_name, f.func_name]) + + if func in metadata[schema]: + metadata[schema][func].append(f) + else: + metadata[schema][func] = [f] + + self.all_completions.add(func) + + self._refresh_arg_list_cache() + + def _refresh_arg_list_cache(self): + # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}} + # This is used when suggesting functions, to avoid the latency that would result + # if we'd recalculate the arg lists each time we suggest functions (in large DBs) + self._arg_list_cache = { + usage: { + meta: self._arg_list(meta, usage) + for sch, funcs in self.dbmetadata["functions"].items() + for func, metas in funcs.items() + for meta in metas + } + for usage in ("call", "call_display", "signature") + } + + def extend_foreignkeys(self, fk_data): + # fk_data is a list of ForeignKey namedtuples, with fields + # parentschema, childschema, parenttable, childtable, + # parentcolumns, childcolumns + + # These are added as a list of ForeignKey namedtuples to the + # ColumnMetadata namedtuple for both the child and parent + meta = self.dbmetadata["tables"] + + for fk in fk_data: + e = self.escaped_names + parentschema, childschema = e([fk.parentschema, fk.childschema]) + parenttable, childtable = e([fk.parenttable, fk.childtable]) + childcol, parcol = e([fk.childcolumn, fk.parentcolumn]) + childcolmeta = meta[childschema][childtable][childcol] + parcolmeta = meta[parentschema][parenttable][parcol] + fk = ForeignKey( + parentschema, parenttable, parcol, childschema, childtable, childcol + ) + childcolmeta.foreignkeys.append(fk) + parcolmeta.foreignkeys.append(fk) + + def extend_datatypes(self, type_data): + # dbmetadata['datatypes'][schema_name][type_name] should store type + # metadata, such as composite type field names. Currently, we're not + # storing any metadata beyond typename, so just store None + meta = self.dbmetadata["datatypes"] + + for t in type_data: + schema, type_name = self.escaped_names(t) + meta[schema][type_name] = None + self.all_completions.add(type_name) + + def extend_query_history(self, text, is_init=False): + if is_init: + # During completer initialization, only load keyword preferences, + # not names + self.prioritizer.update_keywords(text) + else: + self.prioritizer.update(text) + + def set_search_path(self, search_path): + self.search_path = self.escaped_names(search_path) + + def reset_completions(self): + self.databases = [] + self.special_commands = [] + self.search_path = [] + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} + self.all_completions = set(self.keywords + self.functions) + + def find_matches(self, text, collection, mode="fuzzy", meta=None): + """Find completion matches for the given text. + + Given the user's input text and a collection of available + completions, find completions matching the last word of the + text. + + `collection` can be either a list of strings or a list of Candidate + namedtuples. + `mode` can be either 'fuzzy', or 'strict' + 'fuzzy': fuzzy matching, ties broken by name prevalance + `keyword`: start only matching, ties broken by keyword prevalance + + yields prompt_toolkit Completion instances for any matches found + in the collection of available completions. + + """ + if not collection: + return [] + prio_order = [ + "keyword", + "function", + "view", + "table", + "datatype", + "database", + "schema", + "column", + "table alias", + "join", + "name join", + "fk join", + "table format", + ] + type_priority = prio_order.index(meta) if meta in prio_order else -1 + text = last_word(text, include="most_punctuations").lower() + text_len = len(text) + + if text and text[0] == '"': + # text starts with double quote; user is manually escaping a name + # Match on everything that follows the double-quote. Note that + # text_len is calculated before removing the quote, so the + # Completion.position value is correct + text = text[1:] + + if mode == "fuzzy": + fuzzy = True + priority_func = self.prioritizer.name_count + else: + fuzzy = False + priority_func = self.prioritizer.keyword_count + + # Construct a `_match` function for either fuzzy or non-fuzzy matching + # The match function returns a 2-tuple used for sorting the matches, + # or None if the item doesn't match + # Note: higher priority values mean more important, so use negative + # signs to flip the direction of the tuple + if fuzzy: + regex = ".*?".join(map(re.escape, text)) + pat = re.compile("(%s)" % regex) + + def _match(item): + if item.lower()[: len(text) + 1] in (text, text + " "): + # Exact match of first word in suggestion + # This is to get exact alias matches to the top + # E.g. for input `e`, 'Entries E' should be on top + # (before e.g. `EndUsers EU`) + return float("Infinity"), -1 + r = pat.search(self.unescape_name(item.lower())) + if r: + return -len(r.group()), -r.start() + + else: + match_end_limit = len(text) + + def _match(item): + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + # Use negative infinity to force keywords to sort after all + # fuzzy matches + return -float("Infinity"), -match_point + + matches = [] + for cand in collection: + if isinstance(cand, _Candidate): + item, prio, display_meta, synonyms, prio2, display = cand + if display_meta is None: + display_meta = meta + syn_matches = (_match(x) for x in synonyms) + # Nones need to be removed to avoid max() crashing in Python 3 + syn_matches = [m for m in syn_matches if m] + sort_key = max(syn_matches) if syn_matches else None + else: + item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand + sort_key = _match(cand) + + if sort_key: + if display_meta and len(display_meta) > 50: + # Truncate meta-text to 50 characters, if necessary + display_meta = display_meta[:47] + "..." + + # Lexical order of items in the collection, used for + # tiebreaking items with the same match group length and start + # position. Since we use *higher* priority to mean "more + # important," we use -ord(c) to prioritize "aa" > "ab" and end + # with 1 to prioritize shorter strings (ie "user" > "users"). + # We first do a case-insensitive sort and then a + # case-sensitive one as a tie breaker. + # We also use the unescape_name to make sure quoted names have + # the same priority as unquoted names. + lexical_priority = ( + tuple( + 0 if c in " _" else -ord(c) + for c in self.unescape_name(item.lower()) + ) + + (1,) + + tuple(c for c in item) + ) + + item = self.case(item) + display = self.case(display) + priority = ( + sort_key, + type_priority, + prio, + priority_func(item), + prio2, + lexical_priority, + ) + matches.append( + Match( + completion=Completion( + text=item, + start_position=-text_len, + display_meta=display_meta, + display=display, + ), + priority=priority, + ) + ) + return matches + + def case(self, word): + return self.casing.get(word, word) + + def get_completions(self, document, complete_event, smart_completion=None): + word_before_cursor = document.get_word_before_cursor(WORD=True) + + if smart_completion is None: + smart_completion = self.smart_completion + + # If smart_completion is off then match any word that starts with + # 'word_before_cursor'. + if not smart_completion: + matches = self.find_matches( + word_before_cursor, self.all_completions, mode="strict" + ) + completions = [m.completion for m in matches] + return sorted(completions, key=operator.attrgetter("text")) + + matches = [] + suggestions = suggest_type(document.text, document.text_before_cursor) + + for suggestion in suggestions: + suggestion_type = type(suggestion) + _logger.debug("Suggestion type: %r", suggestion_type) + + # Map suggestion type to method + # e.g. 'table' -> self.get_table_matches + matcher = self.suggestion_matchers[suggestion_type] + matches.extend(matcher(self, suggestion, word_before_cursor)) + + # Sort matches so highest priorities are first + matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True) + + return [m.completion for m in matches] + + def get_column_matches(self, suggestion, word_before_cursor): + tables = suggestion.table_refs + do_qualify = ( + suggestion.qualifiable + and { + "always": True, + "never": False, + "if_more_than_one_table": len(tables) > 1, + }[self.qualify_columns] + ) + qualify = lambda col, tbl: ( + (tbl + "." + self.case(col)) if do_qualify else self.case(col) + ) + _logger.debug("Completion column scope: %r", tables) + scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) + + def make_cand(name, ref): + synonyms = (name, generate_alias(self.case(name))) + return Candidate(qualify(name, ref), 0, "column", synonyms) + + def flat_cols(): + return [ + make_cand(c.name, t.ref) + for t, cols in scoped_cols.items() + for c in cols + ] + + if suggestion.require_last_table: + # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should + # suggest only columns that appear in the last table and one more + ltbl = tables[-1].ref + other_tbl_cols = { + c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs + } + scoped_cols = { + t: [col for col in cols if col.name in other_tbl_cols] + for t, cols in scoped_cols.items() + if t.ref == ltbl + } + lastword = last_word(word_before_cursor, include="most_punctuations") + if lastword == "*": + if suggestion.context == "insert": + + def filter(col): + if not col.has_default: + return True + return not any( + p.match(col.default) for p in self.insert_col_skip_patterns + ) + + scoped_cols = { + t: [col for col in cols if filter(col)] + for t, cols in scoped_cols.items() + } + if self.asterisk_column_order == "alphabetic": + for cols in scoped_cols.values(): + cols.sort(key=operator.attrgetter("name")) + if ( + lastword != word_before_cursor + and len(tables) == 1 + and word_before_cursor[-len(lastword) - 1] == "." + ): + # User typed x.*; replicate "x." for all columns except the + # first, which gets the original (as we only replace the "*"") + sep = ", " + word_before_cursor[:-1] + collist = sep.join(self.case(c.completion) for c in flat_cols()) + else: + collist = ", ".join( + qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs + ) + + return [ + Match( + completion=Completion( + collist, -1, display_meta="columns", display="*" + ), + priority=(1, 1, 1), + ) + ] + + return self.find_matches(word_before_cursor, flat_cols(), meta="column") + + def alias(self, tbl, tbls): + """Generate a unique table alias + tbl - name of the table to alias, quoted if it needs to be + tbls - TableReference iterable of tables already in query + """ + tbl = self.case(tbl) + tbls = {normalize_ref(t.ref) for t in tbls} + if self.generate_aliases: + tbl = generate_alias(self.unescape_name(tbl)) + if normalize_ref(tbl) not in tbls: + return tbl + elif tbl[0] == '"': + aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2)) + else: + aliases = (tbl + str(i) for i in count(2)) + return next(a for a in aliases if normalize_ref(a) not in tbls) + + def get_join_matches(self, suggestion, word_before_cursor): + tbls = suggestion.table_refs + cols = self.populate_scoped_cols(tbls) + # Set up some data structures for efficient access + qualified = {normalize_ref(t.ref): t.schema for t in tbls} + ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)} + refs = {normalize_ref(t.ref) for t in tbls} + other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]} + joins = [] + # Iterate over FKs in existing tables to find potential joins + fks = ( + (fk, rtbl, rcol) + for rtbl, rcols in cols.items() + for rcol in rcols + for fk in rcol.foreignkeys + ) + col = namedtuple("col", "schema tbl col") + for fk, rtbl, rcol in fks: + right = col(rtbl.schema, rtbl.name, rcol.name) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left = child if parent == right else parent + if suggestion.schema and left.schema != suggestion.schema: + continue + c = self.case + if self.generate_aliases or normalize_ref(left.tbl) in refs: + lref = self.alias(left.tbl, suggestion.table_refs) + join = "{0} {4} ON {4}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref + ) + else: + join = "{0} ON {0}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col) + ) + alias = generate_alias(self.case(left.tbl)) + synonyms = [ + join, + "{0} ON {0}.{1} = {2}.{3}".format( + alias, c(left.col), rtbl.ref, c(right.col) + ), + ] + # Schema-qualify if (1) new table in same schema as old, and old + # is schema-qualified, or (2) new in other schema, except public + if not suggestion.schema and ( + qualified[normalize_ref(rtbl.ref)] + and left.schema == right.schema + or left.schema not in (right.schema, "public") + ): + join = left.schema + "." + join + prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( + 0 if (left.schema, left.tbl) in other_tbls else 1 + ) + joins.append(Candidate(join, prio, "join", synonyms=synonyms)) + + return self.find_matches(word_before_cursor, joins, meta="join") + + def get_join_condition_matches(self, suggestion, word_before_cursor): + col = namedtuple("col", "schema tbl col") + tbls = self.populate_scoped_cols(suggestion.table_refs).items + cols = [(t, c) for t, cs in tbls() for c in cs] + try: + lref = (suggestion.parent or suggestion.table_refs[-1]).ref + ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1] + except IndexError: # The user typed an incorrect table qualifier + return [] + conds, found_conds = [], set() + + def add_cond(lcol, rcol, rref, prio, meta): + prefix = "" if suggestion.parent else ltbl.ref + "." + case = self.case + cond = prefix + case(lcol) + " = " + rref + "." + case(rcol) + if cond not in found_conds: + found_conds.add(cond) + conds.append(Candidate(cond, prio + ref_prio[rref], meta)) + + def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} + d = defaultdict(list) + for pair in pairs: + d[pair[0]].append(pair[1]) + return d + + # Tables that are closer to the cursor get higher prio + ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)} + # Map (schema, table, col) to tables + coldict = list_dict( + ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref + ) + # For each fk from the left table, generate a join condition if + # the other table is also in the scope + fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys) + for fk, lcol in fks: + left = col(ltbl.schema, ltbl.name, lcol) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + par = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left, right = (child, par) if left == child else (par, child) + for rtbl in coldict[right]: + add_cond(left.col, right.col, rtbl.ref, 2000, "fk join") + # For name matching, use a {(colname, coltype): TableReference} dict + coltyp = namedtuple("coltyp", "name datatype") + col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols) + # Find all name-match join conditions + for c in (coltyp(c.name, c.datatype) for c in lcols): + for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref): + prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0 + add_cond(c.name, c.name, rtbl.ref, prio, "name join") + + return self.find_matches(word_before_cursor, conds, meta="join") + + def get_function_matches(self, suggestion, word_before_cursor, alias=False): + if suggestion.usage == "from": + # Only suggest functions allowed in FROM clause + + def filt(f): + return ( + not f.is_aggregate + and not f.is_window + and not f.is_extension + and ( + f.is_public + or f.schema_name in self.search_path + or f.schema_name == suggestion.schema + ) + ) + + else: + alias = False + + def filt(f): + return not f.is_extension and ( + f.is_public or f.schema_name == suggestion.schema + ) + + arg_mode = {"signature": "signature", "special": None}.get( + suggestion.usage, "call" + ) + + # Function overloading means we way have multiple functions of the same + # name at this point, so keep unique names only + all_functions = self.populate_functions(suggestion.schema, filt) + funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions} + + matches = self.find_matches(word_before_cursor, funcs, meta="function") + + if not suggestion.schema and not suggestion.usage: + # also suggest hardcoded functions using startswith matching + predefined_funcs = self.find_matches( + word_before_cursor, self.functions, mode="strict", meta="function" + ) + matches.extend(predefined_funcs) + + return matches + + def get_schema_matches(self, suggestion, word_before_cursor): + schema_names = self.dbmetadata["tables"].keys() + + # Unless we're sure the user really wants them, hide schema names + # starting with pg_, which are mostly temporary schemas + if not word_before_cursor.startswith("pg_"): + schema_names = [s for s in schema_names if not s.startswith("pg_")] + + if suggestion.quoted: + schema_names = [self.escape_schema(s) for s in schema_names] + + return self.find_matches(word_before_cursor, schema_names, meta="schema") + + def get_from_clause_item_matches(self, suggestion, word_before_cursor): + alias = self.generate_aliases + s = suggestion + t_sug = Table(s.schema, s.table_refs, s.local_tables) + v_sug = View(s.schema, s.table_refs) + f_sug = Function(s.schema, s.table_refs, usage="from") + return ( + self.get_table_matches(t_sug, word_before_cursor, alias) + + self.get_view_matches(v_sug, word_before_cursor, alias) + + self.get_function_matches(f_sug, word_before_cursor, alias) + ) + + def _arg_list(self, func, usage): + """Returns a an arg list string, e.g. `(_foo:=23)` for a func. + + :param func is a FunctionMetadata object + :param usage is 'call', 'call_display' or 'signature' + + """ + template = { + "call": self.call_arg_style, + "call_display": self.call_arg_display_style, + "signature": self.signature_arg_style, + }[usage] + args = func.args() + if not template: + return "()" + elif usage == "call" and len(args) < 2: + return "()" + elif usage == "call" and func.has_variadic(): + return "()" + multiline = usage == "call" and len(args) > self.call_arg_oneliner_max + max_arg_len = max(len(a.name) for a in args) if multiline else 0 + args = ( + self._format_arg(template, arg, arg_num + 1, max_arg_len) + for arg_num, arg in enumerate(args) + ) + if multiline: + return "(" + ",".join("\n " + a for a in args if a) + "\n)" + else: + return "(" + ", ".join(a for a in args if a) + ")" + + def _format_arg(self, template, arg, arg_num, max_arg_len): + if not template: + return None + if arg.has_default: + arg_default = "NULL" if arg.default is None else arg.default + # Remove trailing ::(schema.)type + arg_default = arg_default_type_strip_regex.sub("", arg_default) + else: + arg_default = "" + return template.format( + max_arg_len=max_arg_len, + arg_name=self.case(arg.name), + arg_num=arg_num, + arg_type=arg.datatype, + arg_default=arg_default, + ) + + def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None): + """Returns a Candidate namedtuple. + + :param tbl is a SchemaObject + :param arg_mode determines what type of arg list to suffix for functions. + Possible values: call, signature + + """ + cased_tbl = self.case(tbl.name) + if do_alias: + alias = self.alias(cased_tbl, suggestion.table_refs) + synonyms = (cased_tbl, generate_alias(cased_tbl)) + maybe_alias = (" " + alias) if do_alias else "" + maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else "" + suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else "" + if arg_mode == "call": + display_suffix = self._arg_list_cache["call_display"][tbl.meta] + elif arg_mode == "signature": + display_suffix = self._arg_list_cache["signature"][tbl.meta] + else: + display_suffix = "" + item = maybe_schema + cased_tbl + suffix + maybe_alias + display = maybe_schema + cased_tbl + display_suffix + maybe_alias + prio2 = 0 if tbl.schema else 1 + return Candidate(item, synonyms=synonyms, prio2=prio2, display=display) + + def get_table_matches(self, suggestion, word_before_cursor, alias=False): + tables = self.populate_schema_objects(suggestion.schema, "tables") + tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables) + + # Unless we're sure the user really wants them, don't suggest the + # pg_catalog tables that are implicitly on the search path + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + tables = [t for t in tables if not t.name.startswith("pg_")] + tables = [self._make_cand(t, alias, suggestion) for t in tables] + return self.find_matches(word_before_cursor, tables, meta="table") + + def get_table_formats(self, _, word_before_cursor): + formats = TabularOutputFormatter().supported_formats + return self.find_matches(word_before_cursor, formats, meta="table format") + + def get_view_matches(self, suggestion, word_before_cursor, alias=False): + views = self.populate_schema_objects(suggestion.schema, "views") + + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + views = [v for v in views if not v.name.startswith("pg_")] + views = [self._make_cand(v, alias, suggestion) for v in views] + return self.find_matches(word_before_cursor, views, meta="view") + + def get_alias_matches(self, suggestion, word_before_cursor): + aliases = suggestion.aliases + return self.find_matches(word_before_cursor, aliases, meta="table alias") + + def get_database_matches(self, _, word_before_cursor): + return self.find_matches(word_before_cursor, self.databases, meta="database") + + def get_keyword_matches(self, suggestion, word_before_cursor): + keywords = self.keywords_tree.keys() + # Get well known following keywords for the last token. If any, narrow + # candidates to this list. + next_keywords = self.keywords_tree.get(suggestion.last_token, []) + if next_keywords: + keywords = next_keywords + + casing = self.keyword_casing + if casing == "auto": + if word_before_cursor and word_before_cursor[-1].islower(): + casing = "lower" + else: + casing = "upper" + + if casing == "upper": + keywords = [k.upper() for k in keywords] + else: + keywords = [k.lower() for k in keywords] + + return self.find_matches( + word_before_cursor, keywords, mode="strict", meta="keyword" + ) + + def get_path_matches(self, _, word_before_cursor): + completer = PathCompleter(expanduser=True) + document = Document( + text=word_before_cursor, cursor_position=len(word_before_cursor) + ) + for c in completer.get_completions(document, None): + yield Match(completion=c, priority=(0,)) + + def get_special_matches(self, _, word_before_cursor): + if not self.pgspecial: + return [] + + commands = self.pgspecial.commands + cmds = commands.keys() + cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds] + return self.find_matches(word_before_cursor, cmds, mode="strict") + + def get_datatype_matches(self, suggestion, word_before_cursor): + # suggest custom datatypes + types = self.populate_schema_objects(suggestion.schema, "datatypes") + types = [self._make_cand(t, False, suggestion) for t in types] + matches = self.find_matches(word_before_cursor, types, meta="datatype") + + if not suggestion.schema: + # Also suggest hardcoded types + matches.extend( + self.find_matches( + word_before_cursor, self.datatypes, mode="strict", meta="datatype" + ) + ) + + return matches + + def get_namedquery_matches(self, _, word_before_cursor): + return self.find_matches( + word_before_cursor, NamedQueries.instance.list(), meta="named query" + ) + + suggestion_matchers = { + FromClauseItem: get_from_clause_item_matches, + JoinCondition: get_join_condition_matches, + Join: get_join_matches, + Column: get_column_matches, + Function: get_function_matches, + Schema: get_schema_matches, + Table: get_table_matches, + TableFormat: get_table_formats, + View: get_view_matches, + Alias: get_alias_matches, + Database: get_database_matches, + Keyword: get_keyword_matches, + Special: get_special_matches, + Datatype: get_datatype_matches, + NamedQuery: get_namedquery_matches, + Path: get_path_matches, + } + + def populate_scoped_cols(self, scoped_tbls, local_tbls=()): + """Find all columns in a set of scoped_tables. + + :param scoped_tbls: list of TableReference namedtuples + :param local_tbls: tuple(TableMetadata) + :return: {TableReference:{colname:ColumnMetaData}} + + """ + ctes = {normalize_ref(t.name): t.columns for t in local_tbls} + columns = OrderedDict() + meta = self.dbmetadata + + def addcols(schema, rel, alias, reltype, cols): + tbl = TableReference(schema, rel, alias, reltype == "functions") + if tbl not in columns: + columns[tbl] = [] + columns[tbl].extend(cols) + + for tbl in scoped_tbls: + # Local tables should shadow database tables + if tbl.schema is None and normalize_ref(tbl.name) in ctes: + cols = ctes[normalize_ref(tbl.name)] + addcols(None, tbl.name, "CTE", tbl.alias, cols) + continue + schemas = [tbl.schema] if tbl.schema else self.search_path + for schema in schemas: + relname = self.escape_name(tbl.name) + schema = self.escape_name(schema) + if tbl.is_function: + # Return column names from a set-returning function + # Get an array of FunctionMetadata objects + functions = meta["functions"].get(schema, {}).get(relname) + for func in functions or []: + # func is a FunctionMetadata object + cols = func.fields() + addcols(schema, relname, tbl.alias, "functions", cols) + else: + for reltype in ("tables", "views"): + cols = meta[reltype].get(schema, {}).get(relname) + if cols: + cols = cols.values() + addcols(schema, relname, tbl.alias, reltype, cols) + break + + return columns + + def _get_schemas(self, obj_typ, schema): + """Returns a list of schemas from which to suggest objects. + + :param schema is the schema qualification input by the user (if any) + + """ + metadata = self.dbmetadata[obj_typ] + if schema: + schema = self.escape_name(schema) + return [schema] if schema in metadata else [] + return self.search_path if self.search_path_filter else metadata.keys() + + def _maybe_schema(self, schema, parent): + return None if parent or schema in self.search_path else schema + + def populate_schema_objects(self, schema, obj_type): + """Returns a list of SchemaObjects representing tables or views. + + :param schema is the schema qualification input by the user (if any) + + """ + + return [ + SchemaObject( + name=obj, schema=(self._maybe_schema(schema=sch, parent=schema)) + ) + for sch in self._get_schemas(obj_type, schema) + for obj in self.dbmetadata[obj_type][sch].keys() + ] + + def populate_functions(self, schema, filter_func): + """Returns a list of function SchemaObjects. + + :param filter_func is a function that accepts a FunctionMetadata + namedtuple and returns a boolean indicating whether that + function should be kept or discarded + + """ + + # Because of multiple dispatch, we can have multiple functions + # with the same name, which is why `for meta in metas` is necessary + # in the comprehensions below + return [ + SchemaObject( + name=func, + schema=(self._maybe_schema(schema=sch, parent=schema)), + meta=meta, + ) + for sch in self._get_schemas("functions", schema) + for (func, metas) in self.dbmetadata["functions"][sch].items() + for meta in metas + if filter_func(meta) + ] diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py new file mode 100644 index 0000000..497d681 --- /dev/null +++ b/pgcli/pgexecute.py @@ -0,0 +1,868 @@ +import logging +import traceback +from collections import namedtuple +import re +import pgspecial as special +import psycopg +import psycopg.sql +from psycopg.conninfo import make_conninfo +import sqlparse + +from .packages.parseutils.meta import FunctionMetadata, ForeignKey + +_logger = logging.getLogger(__name__) + +ViewDef = namedtuple( + "ViewDef", "nspname relname relkind viewdef reloptions checkoption" +) + + +# we added this funcion to strip beginning comments +# because sqlparse didn't handle tem well. It won't be needed if sqlparse +# does parsing of this situation better + + +def remove_beginning_comments(command): + # Regular expression pattern to match comments + pattern = r"^(/\*.*?\*/|--.*?)(?:\n|$)" + + # Find and remove all comments from the beginning + cleaned_command = command + comments = [] + match = re.match(pattern, cleaned_command, re.DOTALL) + while match: + comments.append(match.group()) + cleaned_command = cleaned_command[len(match.group()) :].lstrip() + match = re.match(pattern, cleaned_command, re.DOTALL) + + return [cleaned_command, comments] + + +def register_typecasters(connection): + """Casts date and timestamp values to string, resolves issues with out-of-range + dates (e.g. BC) which psycopg can't handle""" + for forced_text_type in [ + "date", + "time", + "timestamp", + "timestamptz", + "bytea", + "json", + "jsonb", + ]: + connection.adapters.register_loader( + forced_text_type, psycopg.types.string.TextLoader + ) + + +# pg3: I don't know what is this +class ProtocolSafeCursor(psycopg.Cursor): + """This class wraps and suppresses Protocol Errors with pgbouncer database. + See https://github.com/dbcli/pgcli/pull/1097. + Pgbouncer database is a virtual database with its own set of commands.""" + + def __init__(self, *args, **kwargs): + self.protocol_error = False + self.protocol_message = "" + super().__init__(*args, **kwargs) + + def __iter__(self): + if self.protocol_error: + raise StopIteration + return super().__iter__() + + def fetchall(self): + if self.protocol_error: + return [(self.protocol_message,)] + return super().fetchall() + + def fetchone(self): + if self.protocol_error: + return (self.protocol_message,) + return super().fetchone() + + # def mogrify(self, query, params): + # args = [Literal(v).as_string(self.connection) for v in params] + # return query % tuple(args) + # + def execute(self, *args, **kwargs): + try: + super().execute(*args, **kwargs) + self.protocol_error = False + self.protocol_message = "" + except psycopg.errors.ProtocolViolation as ex: + self.protocol_error = True + self.protocol_message = str(ex) + _logger.debug("%s: %s" % (ex.__class__.__name__, ex)) + + +class PGExecute: + # The boolean argument to the current_schemas function indicates whether + # implicit schemas, e.g. pg_catalog + search_path_query = """ + SELECT * FROM unnest(current_schemas(true))""" + + schemata_query = """ + SELECT nspname + FROM pg_catalog.pg_namespace + ORDER BY 1 """ + + tables_query = """ + SELECT n.nspname schema_name, + c.relname table_name + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n + ON n.oid = c.relnamespace + WHERE c.relkind = ANY(%s) + ORDER BY 1,2;""" + + databases_query = """ + SELECT d.datname + FROM pg_catalog.pg_database d + ORDER BY 1""" + + full_databases_query = """ + SELECT d.datname as "Name", + pg_catalog.pg_get_userbyid(d.datdba) as "Owner", + pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding", + d.datcollate as "Collate", + d.datctype as "Ctype", + pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges" + FROM pg_catalog.pg_database d + ORDER BY 1""" + + socket_directory_query = """ + SELECT setting + FROM pg_settings + WHERE name = 'unix_socket_directories' + """ + + view_definition_query = """ + WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid) + SELECT nspname, relname, relkind, + pg_catalog.pg_get_viewdef(c.oid, true), + array_remove(array_remove(c.reloptions,'check_option=local'), + 'check_option=cascaded') AS reloptions, + CASE + WHEN 'check_option=local' = ANY (c.reloptions) THEN 'LOCAL'::text + WHEN 'check_option=cascaded' = ANY (c.reloptions) THEN 'CASCADED'::text + ELSE NULL + END AS checkoption + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid) + JOIN v ON (c.oid = v.v_oid)""" + + function_definition_query = """ + WITH f AS + (SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid) + SELECT pg_catalog.pg_get_functiondef(f.f_oid) + FROM f""" + + def __init__( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs, + ): + self._conn_params = {} + self._is_virtual_database = None + self.conn = None + self.dbname = None + self.user = None + self.password = None + self.host = None + self.port = None + self.server_version = None + self.extra_args = None + self.connect(database, user, password, host, port, dsn, **kwargs) + self.reset_expanded = None + + def is_virtual_database(self): + if self._is_virtual_database is None: + self._is_virtual_database = self.is_protocol_error() + return self._is_virtual_database + + def copy(self): + """Returns a clone of the current executor.""" + return self.__class__(**self._conn_params) + + def connect( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs, + ): + conn_params = self._conn_params.copy() + + new_params = { + "dbname": database, + "user": user, + "password": password, + "host": host, + "port": port, + "dsn": dsn, + } + new_params.update(kwargs) + + if new_params["dsn"]: + new_params = {"dsn": new_params["dsn"], "password": new_params["password"]} + + if new_params["password"]: + new_params["dsn"] = make_conninfo( + new_params["dsn"], password=new_params.pop("password") + ) + + conn_params.update({k: v for k, v in new_params.items() if v}) + + if "dsn" in conn_params: + other_params = {k: v for k, v in conn_params.items() if k != "dsn"} + conn_info = make_conninfo(conn_params["dsn"], **other_params) + else: + conn_info = make_conninfo(**conn_params) + conn = psycopg.connect(conn_info) + conn.cursor_factory = ProtocolSafeCursor + + self._conn_params = conn_params + if self.conn: + self.conn.close() + self.conn = conn + self.conn.autocommit = True + + # When we connect using a DSN, we don't really know what db, + # user, etc. we connected to. Let's read it. + # Note: moved this after setting autocommit because of #664. + dsn_parameters = conn.info.get_parameters() + + if dsn_parameters: + self.dbname = dsn_parameters.get("dbname") + self.user = dsn_parameters.get("user") + self.host = dsn_parameters.get("host") + self.port = dsn_parameters.get("port") + else: + self.dbname = conn_params.get("database") + self.user = conn_params.get("user") + self.host = conn_params.get("host") + self.port = conn_params.get("port") + + self.password = password + self.extra_args = kwargs + + if not self.host: + self.host = ( + "pgbouncer" + if self.is_virtual_database() + else self.get_socket_directory() + ) + + self.pid = conn.info.backend_pid + self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1") + self.server_version = conn.info.parameter_status("server_version") or "" + + # _set_wait_callback(self.is_virtual_database()) + + if not self.is_virtual_database(): + register_typecasters(conn) + + @property + def short_host(self): + if "," in self.host: + host, _, _ = self.host.partition(",") + else: + host = self.host + short_host, _, _ = host.partition(".") + return short_host + + def _select_one(self, cur, sql): + """ + Helper method to run a select and retrieve a single field value + :param cur: cursor + :param sql: string + :return: string + """ + cur.execute(sql) + return cur.fetchone() + + def failed_transaction(self): + return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR + + def valid_transaction(self): + status = self.conn.info.transaction_status + return ( + status == psycopg.pq.TransactionStatus.ACTIVE + or status == psycopg.pq.TransactionStatus.INTRANS + ) + + def run( + self, + statement, + pgspecial=None, + exception_formatter=None, + on_error_resume=False, + explain_mode=False, + ): + """Execute the sql in the database and return the results. + + :param statement: A string containing one or more sql statements + :param pgspecial: PGSpecial object + :param exception_formatter: A callable that accepts an Exception and + returns a formatted (title, rows, headers, status) tuple that can + act as a query result. If an exception_formatter is not supplied, + psycopg2 exceptions are always raised. + :param on_error_resume: Bool. If true, queries following an exception + (assuming exception_formatter has been supplied) continue to + execute. + + :return: Generator yielding tuples containing + (title, rows, headers, status, query, success, is_special) + """ + + # Remove spaces and EOL + statement = statement.strip() + if not statement: # Empty string + yield None, None, None, None, statement, False, False + + # sql parse doesn't split on a comment first + special + # so we're going to do it + + removed_comments = [] + sqlarr = [] + cleaned_command = "" + + # could skip if statement doesn't match ^-- or ^/* + cleaned_command, removed_comments = remove_beginning_comments(statement) + + sqlarr = sqlparse.split(cleaned_command) + + # now re-add the beginning comments if there are any, so that they show up in + # log files etc when running these commands + + if len(removed_comments) > 0: + sqlarr = removed_comments + sqlarr + + # run each sql query + for sql in sqlarr: + # Remove spaces, eol and semi-colons. + sql = sql.rstrip(";") + sql = sqlparse.format(sql, strip_comments=False).strip() + if not sql: + continue + try: + if explain_mode: + sql = self.explain_prefix() + sql + elif pgspecial: + # \G is treated specially since we have to set the expanded output. + if sql.endswith("\\G"): + if not pgspecial.expanded_output: + pgspecial.expanded_output = True + self.reset_expanded = True + sql = sql[:-2].strip() + + # First try to run each query as special + _logger.debug("Trying a pgspecial command. sql: %r", sql) + try: + cur = self.conn.cursor() + except psycopg.InterfaceError: + # edge case when connection is already closed, but we + # don't need cursor for special_cmd.arg_type == NO_QUERY. + # See https://github.com/dbcli/pgcli/issues/1014. + cur = None + try: + response = pgspecial.execute(cur, sql) + if cur and cur.protocol_error: + yield None, None, None, cur.protocol_message, statement, False, False + # this would close connection. We should reconnect. + self.connect() + continue + for result in response: + # e.g. execute_from_file already appends these + if len(result) < 7: + yield result + (sql, True, True) + else: + yield result + continue + except special.CommandNotFound: + pass + + # Not a special command, so execute as normal sql + yield self.execute_normal_sql(sql) + (sql, True, False) + except psycopg.DatabaseError as e: + _logger.error("sql: %r, error: %r", sql, e) + _logger.error("traceback: %r", traceback.format_exc()) + + if self._must_raise(e) or not exception_formatter: + raise + + yield None, None, None, exception_formatter(e), sql, False, False + + if not on_error_resume: + break + finally: + if self.reset_expanded: + pgspecial.expanded_output = False + self.reset_expanded = None + + def _must_raise(self, e): + """Return true if e is an error that should not be caught in ``run``. + + An uncaught error will prompt the user to reconnect; as long as we + detect that the connection is still open, we catch the error, as + reconnecting won't solve that problem. + + :param e: DatabaseError. An exception raised while executing a query. + + :return: Bool. True if ``run`` must raise this exception. + + """ + return self.conn.closed != 0 + + def execute_normal_sql(self, split_sql): + """Returns tuple (title, rows, headers, status)""" + _logger.debug("Regular sql statement. sql: %r", split_sql) + + title = "" + + def handle_notices(n): + nonlocal title + title = f"{n.message_primary}\n{n.message_detail}\n{title}" + + self.conn.add_notice_handler(handle_notices) + + if self.is_virtual_database() and "show help" in split_sql.lower(): + # see https://github.com/psycopg/psycopg/issues/303 + # special case "show help" in pgbouncer + res = self.conn.pgconn.exec_(split_sql.encode()) + return title, None, None, res.command_status.decode() + + cur = self.conn.cursor() + cur.execute(split_sql) + + # cur.description will be None for operations that do not return + # rows. + if cur.description: + headers = [x[0] for x in cur.description] + return title, cur, headers, cur.statusmessage + elif cur.protocol_error: + _logger.debug("Protocol error, unsupported command.") + return title, None, None, cur.protocol_message + else: + _logger.debug("No rows in result.") + return title, None, None, cur.statusmessage + + def search_path(self): + """Returns the current search path as a list of schema names""" + + try: + with self.conn.cursor() as cur: + _logger.debug("Search path query. sql: %r", self.search_path_query) + cur.execute(self.search_path_query) + return [x[0] for x in cur.fetchall()] + except psycopg.ProgrammingError: + fallback = "SELECT * FROM current_schemas(true)" + with self.conn.cursor() as cur: + _logger.debug("Search path query. sql: %r", fallback) + cur.execute(fallback) + return cur.fetchone()[0] + + def view_definition(self, spec): + """Returns the SQL defining views described by `spec`""" + + # 2: relkind, v or m (materialized) + # 4: reloptions, null + # 5: checkoption: local or cascaded + with self.conn.cursor() as cur: + sql = self.view_definition_query + _logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec) + try: + cur.execute(sql, (spec,)) + except psycopg.ProgrammingError: + raise RuntimeError(f"View {spec} does not exist.") + result = ViewDef(*cur.fetchone()) + if result.relkind == "m": + template = "CREATE OR REPLACE MATERIALIZED VIEW {name} AS \n{stmt}" + else: + template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}" + return ( + psycopg.sql.SQL(template) + .format( + name=psycopg.sql.Identifier(result.nspname, result.relname), + stmt=psycopg.sql.SQL(result.viewdef), + ) + .as_string(self.conn) + ) + + def function_definition(self, spec): + """Returns the SQL defining functions described by `spec`""" + + with self.conn.cursor() as cur: + sql = self.function_definition_query + _logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec) + try: + cur.execute(sql, (spec,)) + result = cur.fetchone() + return result[0] + except psycopg.ProgrammingError: + raise RuntimeError(f"Function {spec} does not exist.") + + def schemata(self): + """Returns a list of schema names in the database""" + + with self.conn.cursor() as cur: + _logger.debug("Schemata Query. sql: %r", self.schemata_query) + cur.execute(self.schemata_query) + return [x[0] for x in cur.fetchall()] + + def _relations(self, kinds=("r", "p", "f", "v", "m")): + """Get table or view name metadata + + :param kinds: list of postgres relkind filters: + 'r' - table + 'p' - partitioned table + 'f' - foreign table + 'v' - view + 'm' - materialized view + :return: (schema_name, rel_name) tuples + """ + + with self.conn.cursor() as cur: + # sql = cur.mogrify(self.tables_query, kinds) + # _logger.debug("Tables Query. sql: %r", sql) + cur.execute(self.tables_query, [kinds]) + yield from cur + + def tables(self): + """Yields (schema_name, table_name) tuples""" + yield from self._relations(kinds=["r", "p", "f"]) + + def views(self): + """Yields (schema_name, view_name) tuples. + + Includes both views and and materialized views + """ + yield from self._relations(kinds=["v", "m"]) + + def _columns(self, kinds=("r", "p", "f", "v", "m")): + """Get column metadata for tables and views + + :param kinds: kinds: list of postgres relkind filters: + 'r' - table + 'p' - partitioned table + 'f' - foreign table + 'v' - view + 'm' - materialized view + :return: list of (schema_name, relation_name, column_name, column_type) tuples + """ + + if self.conn.info.server_version >= 80400: + columns_query = """ + SELECT nsp.nspname schema_name, + cls.relname table_name, + att.attname column_name, + att.atttypid::regtype::text type_name, + att.atthasdef AS has_default, + pg_catalog.pg_get_expr(def.adbin, def.adrelid, true) as default + FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + LEFT OUTER JOIN pg_attrdef def + ON def.adrelid = att.attrelid + AND def.adnum = att.attnum + WHERE cls.relkind = ANY(%s) + AND NOT att.attisdropped + AND att.attnum > 0 + ORDER BY 1, 2, att.attnum""" + else: + columns_query = """ + SELECT nsp.nspname schema_name, + cls.relname table_name, + att.attname column_name, + typ.typname type_name, + NULL AS has_default, + NULL AS default + FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + INNER JOIN pg_catalog.pg_type typ + ON typ.oid = att.atttypid + WHERE cls.relkind = ANY(%s) + AND NOT att.attisdropped + AND att.attnum > 0 + ORDER BY 1, 2, att.attnum""" + + with self.conn.cursor() as cur: + # sql = cur.mogrify(columns_query, kinds) + # _logger.debug("Columns Query. sql: %r", sql) + cur.execute(columns_query, [kinds]) + yield from cur + + def table_columns(self): + yield from self._columns(kinds=["r", "p", "f"]) + + def view_columns(self): + yield from self._columns(kinds=["v", "m"]) + + def databases(self): + with self.conn.cursor() as cur: + _logger.debug("Databases Query. sql: %r", self.databases_query) + cur.execute(self.databases_query) + return [x[0] for x in cur.fetchall()] + + def full_databases(self): + with self.conn.cursor() as cur: + _logger.debug("Databases Query. sql: %r", self.full_databases_query) + cur.execute(self.full_databases_query) + headers = [x[0] for x in cur.description] + return cur.fetchall(), headers, cur.statusmessage + + def is_protocol_error(self): + query = "SELECT 1" + with self.conn.cursor() as cur: + _logger.debug("Simple Query. sql: %r", query) + cur.execute(query) + return bool(cur.protocol_error) + + def get_socket_directory(self): + with self.conn.cursor() as cur: + _logger.debug( + "Socket directory Query. sql: %r", self.socket_directory_query + ) + cur.execute(self.socket_directory_query) + result = cur.fetchone() + return result[0] if result else "" + + def foreignkeys(self): + """Yields ForeignKey named tuples""" + + if self.conn.info.server_version < 90000: + return + + with self.conn.cursor() as cur: + query = """ + SELECT s_p.nspname AS parentschema, + t_p.relname AS parenttable, + unnest(( + select + array_agg(attname ORDER BY i) + from + (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x + JOIN pg_catalog.pg_attribute c USING(attnum) + WHERE c.attrelid = fk.confrelid + )) AS parentcolumn, + s_c.nspname AS childschema, + t_c.relname AS childtable, + unnest(( + select + array_agg(attname ORDER BY i) + from + (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x + JOIN pg_catalog.pg_attribute c USING(attnum) + WHERE c.attrelid = fk.conrelid + )) AS childcolumn + FROM pg_catalog.pg_constraint fk + JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid + JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace + JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid + JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace + WHERE fk.contype = 'f'; + """ + _logger.debug("Functions Query. sql: %r", query) + cur.execute(query) + for row in cur: + yield ForeignKey(*row) + + def functions(self): + """Yields FunctionMetadata named tuples""" + + if self.conn.info.server_version >= 110000: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text return_type, + p.prokind = 'a' is_aggregate, + p.prokind = 'w' is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + pg_get_expr(proargdefaults, 0) AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + elif self.conn.info.server_version > 90000: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text return_type, + p.proisagg is_aggregate, + p.proiswindow is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + pg_get_expr(proargdefaults, 0) AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + elif self.conn.info.server_version >= 80400: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text, + p.proisagg is_aggregate, + false is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + NULL AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + else: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + NULL arg_types, + NULL arg_modes, + '' ret_type, + p.proisagg is_aggregate, + false is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + NULL AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + + with self.conn.cursor() as cur: + _logger.debug("Functions Query. sql: %r", query) + cur.execute(query) + for row in cur: + yield FunctionMetadata(*row) + + def datatypes(self): + """Yields tuples of (schema_name, type_name)""" + + with self.conn.cursor() as cur: + if self.conn.info.server_version > 90000: + query = """ + SELECT n.nspname schema_name, + t.typname type_name + FROM pg_catalog.pg_type t + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = t.typnamespace + WHERE ( t.typrelid = 0 -- non-composite types + OR ( -- composite type, but not a table + SELECT c.relkind = 'c' + FROM pg_catalog.pg_class c + WHERE c.oid = t.typrelid + ) + ) + AND NOT EXISTS( -- ignore array types + SELECT 1 + FROM pg_catalog.pg_type el + WHERE el.oid = t.typelem AND el.typarray = t.oid + ) + AND n.nspname <> 'pg_catalog' + AND n.nspname <> 'information_schema' + ORDER BY 1, 2; + """ + else: + query = """ + SELECT n.nspname schema_name, + pg_catalog.format_type(t.oid, NULL) type_name + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid)) + AND t.typname !~ '^_' + AND n.nspname <> 'pg_catalog' + AND n.nspname <> 'information_schema' + AND pg_catalog.pg_type_is_visible(t.oid) + ORDER BY 1, 2; + """ + _logger.debug("Datatypes Query. sql: %r", query) + cur.execute(query) + yield from cur + + def casing(self): + """Yields the most common casing for names used in db functions""" + with self.conn.cursor() as cur: + query = r""" + WITH Words AS ( + SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1) + FROM pg_catalog.pg_proc P + JOIN pg_catalog.pg_namespace N ON N.oid = P.pronamespace + JOIN pg_catalog.pg_language L ON L.oid = P.prolang + WHERE L.lanname IN ('sql', 'plpgsql') + AND N.nspname NOT IN ('pg_catalog', 'information_schema') + GROUP BY Word + ), + OrderWords AS ( + SELECT Word, + ROW_NUMBER() OVER(PARTITION BY LOWER(Word) ORDER BY Count DESC) + FROM Words + WHERE Word ~* '.*[a-z].*' + ), + Names AS ( + --Column names + SELECT attname AS Name + FROM pg_catalog.pg_attribute + UNION -- Table/view names + SELECT relname + FROM pg_catalog.pg_class + UNION -- Function names + SELECT proname + FROM pg_catalog.pg_proc + UNION -- Type names + SELECT typname + FROM pg_catalog.pg_type + UNION -- Schema names + SELECT nspname + FROM pg_catalog.pg_namespace + UNION -- Parameter names + SELECT unnest(proargnames) + FROM pg_proc + ) + SELECT Word + FROM OrderWords + WHERE LOWER(Word) IN (SELECT Name FROM Names) + AND Row_Number = 1; + """ + _logger.debug("Casing Query. sql: %r", query) + cur.execute(query) + for row in cur: + yield row[0] + + def explain_prefix(self): + return "EXPLAIN (ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) " diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py new file mode 100644 index 0000000..77874f4 --- /dev/null +++ b/pgcli/pgstyle.py @@ -0,0 +1,116 @@ +import logging + +import pygments.styles +from pygments.token import string_to_tokentype, Token +from pygments.style import Style as PygmentsStyle +from pygments.util import ClassNotFound +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles import merge_styles, Style + +logger = logging.getLogger(__name__) + +# map Pygments tokens (ptk 1.0) to class names (ptk 2.0). +TOKEN_TO_PROMPT_STYLE = { + Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", + Token.Menu.Completions.Completion: "completion-menu.completion", + Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", + Token.Menu.Completions.Meta: "completion-menu.meta.completion", + Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta", + Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess + Token.Menu.Completions.ProgressBar: "scrollbar", # best guess + Token.SelectedText: "selected", + Token.SearchMatch: "search", + Token.SearchMatch.Current: "search.current", + Token.Toolbar: "bottom-toolbar", + Token.Toolbar.Off: "bottom-toolbar.off", + Token.Toolbar.On: "bottom-toolbar.on", + Token.Toolbar.Search: "search-toolbar", + Token.Toolbar.Search.Text: "search-toolbar.text", + Token.Toolbar.System: "system-toolbar", + Token.Toolbar.Arg: "arg-toolbar", + Token.Toolbar.Arg.Text: "arg-toolbar.text", + Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid", + Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed", + Token.Output.Header: "output.header", + Token.Output.OddRow: "output.odd-row", + Token.Output.EvenRow: "output.even-row", + Token.Output.Null: "output.null", + Token.Literal.String: "literal.string", + Token.Literal.Number: "literal.number", + Token.Keyword: "keyword", + Token.Prompt: "prompt", + Token.Continuation: "continuation", +} + +# reverse dict for cli_helpers, because they still expect Pygments tokens. +PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} + + +def parse_pygments_style(token_name, style_object, style_dict): + """Parse token type and style string. + + :param token_name: str name of Pygments token. Example: "Token.String" + :param style_object: pygments.style.Style instance to use as base + :param style_dict: dict of token names and their styles, customized to this cli + + """ + token_type = string_to_tokentype(token_name) + try: + other_token_type = string_to_tokentype(style_dict[token_name]) + return token_type, style_object.styles[other_token_type] + except AttributeError: + return token_type, style_dict[token_name] + + +def style_factory(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name) + except ClassNotFound: + style = pygments.styles.get_style_by_name("native") + + prompt_styles = [] + # prompt-toolkit used pygments tokens for styling before, switched to style + # names in 2.0. Convert old token types to new style names, for backwards compatibility. + for token in cli_style: + if token.startswith("Token."): + # treat as pygments token (1.0) + token_type, style_value = parse_pygments_style(token, style, cli_style) + if token_type in TOKEN_TO_PROMPT_STYLE: + prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] + prompt_styles.append((prompt_style, style_value)) + else: + # we don't want to support tokens anymore + logger.error("Unhandled style / class name: %s", token) + else: + # treat as prompt style name (2.0). See default style names here: + # https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/src/prompt_toolkit/styles/defaults.py + prompt_styles.append((token, cli_style[token])) + + override_style = Style([("bottom-toolbar", "noreverse")]) + return merge_styles( + [style_from_pygments_cls(style), override_style, Style(prompt_styles)] + ) + + +def style_factory_output(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name).styles + except ClassNotFound: + style = pygments.styles.get_style_by_name("native").styles + + for token in cli_style: + if token.startswith("Token."): + token_type, style_value = parse_pygments_style(token, style, cli_style) + style.update({token_type: style_value}) + elif token in PROMPT_STYLE_TO_TOKEN: + token_type = PROMPT_STYLE_TO_TOKEN[token] + style.update({token_type: cli_style[token]}) + else: + # TODO: cli helpers will have to switch to ptk.Style + logger.error("Unhandled style / class name: %s", token) + + class OutputStyle(PygmentsStyle): + default_style = "" + styles = style + + return OutputStyle diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py new file mode 100644 index 0000000..4a12ff4 --- /dev/null +++ b/pgcli/pgtoolbar.py @@ -0,0 +1,71 @@ +from prompt_toolkit.key_binding.vi_state import InputMode +from prompt_toolkit.application import get_app + +vi_modes = { + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.INSERT_MULTIPLE: "M", +} +# REPLACE_SINGLE is available in prompt_toolkit >= 3.0.6 +if "REPLACE_SINGLE" in {e.name for e in InputMode}: + vi_modes[InputMode.REPLACE_SINGLE] = "R" + + +def _get_vi_mode(): + return vi_modes[get_app().vi_state.input_mode] + + +def create_toolbar_tokens_func(pgcli): + """Return a function that generates the toolbar tokens.""" + + def get_toolbar_tokens(): + result = [] + result.append(("class:bottom-toolbar", " ")) + + if pgcli.completer.smart_completion: + result.append(("class:bottom-toolbar.on", "[F2] Smart Completion: ON ")) + else: + result.append(("class:bottom-toolbar.off", "[F2] Smart Completion: OFF ")) + + if pgcli.multi_line: + result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON ")) + else: + result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF ")) + + if pgcli.multi_line: + if pgcli.multiline_mode == "safe": + result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) ")) + else: + result.append( + ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ") + ) + + if pgcli.vi_mode: + result.append( + ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ") ") + ) + else: + result.append(("class:bottom-toolbar", "[F4] Emacs-mode ")) + + if pgcli.explain_mode: + result.append(("class:bottom-toolbar", "[F5] Explain: ON ")) + else: + result.append(("class:bottom-toolbar", "[F5] Explain: OFF ")) + + if pgcli.pgexecute.failed_transaction(): + result.append( + ("class:bottom-toolbar.transaction.failed", " Failed transaction") + ) + + if pgcli.pgexecute.valid_transaction(): + result.append( + ("class:bottom-toolbar.transaction.valid", " Transaction") + ) + + if pgcli.completion_refresher.is_refreshing(): + result.append(("class:bottom-toolbar", " Refreshing completions...")) + + return result + + return get_toolbar_tokens diff --git a/pgcli/pyev.py b/pgcli/pyev.py new file mode 100644 index 0000000..2886c9c --- /dev/null +++ b/pgcli/pyev.py @@ -0,0 +1,439 @@ +import textwrap +import re +from click import style as color + +DESCRIPTIONS = { + "Append": "Used in a UNION to merge multiple record sets by appending them together.", + "Limit": "Returns a specified number of rows from a record set.", + "Sort": "Sorts a record set based on the specified sort key.", + "Nested Loop": "Merges two record sets by looping through every record in the first set and trying to find a match in the second set. All matching records are returned.", + "Merge Join": "Merges two record sets by first sorting them on a join key.", + "Hash": "Generates a hash table from the records in the input recordset. Hash is used by Hash Join.", + "Hash Join": "Joins to record sets by hashing one of them (using a Hash Scan).", + "Aggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()).", + "Hashaggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()). Hash Aggregate uses a hash to first organize the records by a key.", + "Sequence Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).", + "Seq Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).", + "Index Scan": "Finds relevant records based on an Index. Index Scans perform 2 read operations: one to read the index and another to read the actual value from the table.", + "Index Only Scan": "Finds relevant records based on an Index. Index Only Scans perform a single read operation from the index and do not read from the corresponding table.", + "Bitmap Heap Scan": "Searches through the pages returned by the Bitmap Index Scan for relevant rows.", + "Bitmap Index Scan": "Uses a Bitmap Index (index which uses 1 bit per page) to find all relevant pages. Results of this node are fed to the Bitmap Heap Scan.", + "CTEScan": "Performs a sequential scan of Common Table Expression (CTE) query results. Note that results of a CTE are materialized (calculated and temporarily stored).", + "ProjectSet": "ProjectSet appears when the SELECT or ORDER BY clause of the query. They basically just execute the set-returning function(s) for each tuple until none of the functions return any more records.", + "Result": "Returns result", +} + + +class Visualizer: + def __init__(self, terminal_width=100, color=True): + self.color = color + self.terminal_width = terminal_width + self.string_lines = [] + + def load(self, explain_dict): + self.plan = explain_dict.pop("Plan") + self.explain = explain_dict + self.process_all() + self.generate_lines() + + def process_all(self): + self.plan = self.process_plan(self.plan) + self.plan = self.calculate_outlier_nodes(self.plan) + + # + def process_plan(self, plan): + plan = self.calculate_planner_estimate(plan) + plan = self.calculate_actuals(plan) + self.calculate_maximums(plan) + # + for index in range(len(plan.get("Plans", []))): + _plan = plan["Plans"][index] + plan["Plans"][index] = self.process_plan(_plan) + return plan + + def prefix_format(self, v): + if self.color: + return color(v, fg="bright_black") + return v + + def tag_format(self, v): + if self.color: + return color(v, fg="white", bg="red") + return v + + def muted_format(self, v): + if self.color: + return color(v, fg="bright_black") + return v + + def bold_format(self, v): + if self.color: + return color(v, fg="white") + return v + + def good_format(self, v): + if self.color: + return color(v, fg="green") + return v + + def warning_format(self, v): + if self.color: + return color(v, fg="yellow") + return v + + def critical_format(self, v): + if self.color: + return color(v, fg="red") + return v + + def output_format(self, v): + if self.color: + return color(v, fg="cyan") + return v + + def calculate_planner_estimate(self, plan): + plan["Planner Row Estimate Factor"] = 0 + plan["Planner Row Estimate Direction"] = "Under" + + if plan["Plan Rows"] == plan["Actual Rows"]: + return plan + + if plan["Plan Rows"] != 0: + plan["Planner Row Estimate Factor"] = ( + plan["Actual Rows"] / plan["Plan Rows"] + ) + + if plan["Planner Row Estimate Factor"] < 10: + plan["Planner Row Estimate Factor"] = 0 + plan["Planner Row Estimate Direction"] = "Over" + if plan["Actual Rows"] != 0: + plan["Planner Row Estimate Factor"] = ( + plan["Plan Rows"] / plan["Actual Rows"] + ) + return plan + + # + def calculate_actuals(self, plan): + plan["Actual Duration"] = plan["Actual Total Time"] + plan["Actual Cost"] = plan["Total Cost"] + + for child in plan.get("Plans", []): + if child["Node Type"] != "CTEScan": + plan["Actual Duration"] = ( + plan["Actual Duration"] - child["Actual Total Time"] + ) + plan["Actual Cost"] = plan["Actual Cost"] - child["Total Cost"] + + if plan["Actual Cost"] < 0: + plan["Actual Cost"] = 0 + + plan["Actual Duration"] = plan["Actual Duration"] * plan["Actual Loops"] + return plan + + def calculate_outlier_nodes(self, plan): + plan["Costliest"] = plan["Actual Cost"] == self.explain["Max Cost"] + plan["Largest"] = plan["Actual Rows"] == self.explain["Max Rows"] + plan["Slowest"] = plan["Actual Duration"] == self.explain["Max Duration"] + + for index in range(len(plan.get("Plans", []))): + _plan = plan["Plans"][index] + plan["Plans"][index] = self.calculate_outlier_nodes(_plan) + return plan + + def calculate_maximums(self, plan): + if not self.explain.get("Max Rows"): + self.explain["Max Rows"] = plan["Actual Rows"] + elif self.explain.get("Max Rows") < plan["Actual Rows"]: + self.explain["Max Rows"] = plan["Actual Rows"] + + if not self.explain.get("Max Cost"): + self.explain["Max Cost"] = plan["Actual Cost"] + elif self.explain.get("Max Cost") < plan["Actual Cost"]: + self.explain["Max Cost"] = plan["Actual Cost"] + + if not self.explain.get("Max Duration"): + self.explain["Max Duration"] = plan["Actual Duration"] + elif self.explain.get("Max Duration") < plan["Actual Duration"]: + self.explain["Max Duration"] = plan["Actual Duration"] + + if not self.explain.get("Total Cost"): + self.explain["Total Cost"] = plan["Actual Cost"] + elif self.explain.get("Total Cost") < plan["Actual Cost"]: + self.explain["Total Cost"] = plan["Actual Cost"] + + # + def duration_to_string(self, value): + if value < 1: + return self.good_format("<1 ms") + elif value < 100: + return self.good_format("%.2f ms" % value) + elif value < 1000: + return self.warning_format("%.2f ms" % value) + elif value < 60000: + return self.critical_format( + "%.2f s" % (value / 1000.0), + ) + else: + return self.critical_format( + "%.2f m" % (value / 60000.0), + ) + + # } + # + def format_details(self, plan): + details = [] + + if plan.get("Scan Direction"): + details.append(plan["Scan Direction"]) + + if plan.get("Strategy"): + details.append(plan["Strategy"]) + + if len(details) > 0: + return self.muted_format(" [%s]" % ", ".join(details)) + + return "" + + def format_tags(self, plan): + tags = [] + + if plan["Slowest"]: + tags.append(self.tag_format("slowest")) + if plan["Costliest"]: + tags.append(self.tag_format("costliest")) + if plan["Largest"]: + tags.append(self.tag_format("largest")) + if plan.get("Planner Row Estimate Factor", 0) >= 100: + tags.append(self.tag_format("bad estimate")) + + return " ".join(tags) + + def get_terminator(self, index, plan): + if index == 0: + if len(plan.get("Plans", [])) == 0: + return "⌡► " + else: + return "├► " + else: + if len(plan.get("Plans", [])) == 0: + return " " + else: + return "│ " + + def wrap_string(self, line, width): + if width == 0: + return [line] + return textwrap.wrap(line, width) + + def intcomma(self, value): + sep = "," + if not isinstance(value, str): + value = int(value) + + orig = str(value) + + new = re.sub(r"^(-?\d+)(\d{3})", rf"\g<1>{sep}\g<2>", orig) + if orig == new: + return new + else: + return self.intcomma(new) + + def output_fn(self, current_prefix, string): + return "%s%s" % (self.prefix_format(current_prefix), string) + + def create_lines(self, plan, prefix, depth, width, last_child): + current_prefix = prefix + self.string_lines.append( + self.output_fn(current_prefix, self.prefix_format("│")) + ) + + joint = "├" + if last_child: + joint = "└" + # + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s%s %s" + % ( + self.prefix_format(joint + "─⌠"), + self.bold_format(plan["Node Type"]), + self.format_details(plan), + self.format_tags(plan), + ), + ) + ) + # + if last_child: + prefix += " " + else: + prefix += "│ " + + current_prefix = prefix + "│ " + + cols = width - len(current_prefix) + + for line in self.wrap_string( + DESCRIPTIONS.get(plan["Node Type"], "Not found : %s" % plan["Node Type"]), + cols, + ): + self.string_lines.append( + self.output_fn(current_prefix, "%s" % self.muted_format(line)) + ) + # + if plan.get("Actual Duration"): + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s (%.0f%%)" + % ( + "Duration:", + self.duration_to_string(plan["Actual Duration"]), + (plan["Actual Duration"] / self.explain["Execution Time"]) + * 100, + ), + ) + ) + + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s (%.0f%%)" + % ( + "Cost:", + self.intcomma(plan["Actual Cost"]), + (plan["Actual Cost"] / self.explain["Total Cost"]) * 100, + ), + ) + ) + + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s" % ("Rows:", self.intcomma(plan["Actual Rows"])), + ) + ) + + current_prefix = current_prefix + " " + + if plan.get("Join Type"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (plan["Join Type"], self.muted_format("join")), + ) + ) + + if plan.get("Relation Name"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s.%s" + % ( + self.muted_format("on"), + plan.get("Schema", "unknown"), + plan["Relation Name"], + ), + ) + ) + + if plan.get("Index Name"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("using"), plan["Index Name"]), + ) + ) + + if plan.get("Index Condition"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("condition"), plan["Index Condition"]), + ) + ) + + if plan.get("Filter"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s %s" + % ( + self.muted_format("filter"), + plan["Filter"], + self.muted_format( + "[-%s rows]" % self.intcomma(plan["Rows Removed by Filter"]) + ), + ), + ) + ) + + if plan.get("Hash Condition"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("on"), plan["Hash Condition"]), + ) + ) + + if plan.get("CTE Name"): + self.string_lines.append( + self.output_fn(current_prefix, "CTE %s" % plan["CTE Name"]) + ) + + if plan.get("Planner Row Estimate Factor") != 0: + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %sestimated %s %.2fx" + % ( + self.muted_format("rows"), + plan["Planner Row Estimate Direction"], + self.muted_format("by"), + plan["Planner Row Estimate Factor"], + ), + ) + ) + + current_prefix = prefix + + if len(plan.get("Output", [])) > 0: + for index, line in enumerate( + self.wrap_string(" + ".join(plan["Output"]), cols) + ): + self.string_lines.append( + self.output_fn( + current_prefix, + self.prefix_format(self.get_terminator(index, plan)) + + self.output_format(line), + ) + ) + + for index, nested_plan in enumerate(plan.get("Plans", [])): + self.create_lines( + nested_plan, prefix, depth + 1, width, index == len(plan["Plans"]) - 1 + ) + + def generate_lines(self): + self.string_lines = [ + "○ Total Cost: %s" % self.intcomma(self.explain["Total Cost"]), + "○ Planning Time: %s" + % self.duration_to_string(self.explain["Planning Time"]), + "○ Execution Time: %s" + % self.duration_to_string(self.explain["Execution Time"]), + self.prefix_format("┬"), + ] + self.create_lines( + self.plan, + "", + 0, + self.terminal_width, + len(self.plan.get("Plans", [])) == 1, + ) + + def get_list(self): + return "\n".join(self.string_lines) + + def print(self): + for lin in self.string_lines: + print(lin) diff --git a/post-install b/post-install new file mode 100644 index 0000000..d516a3f --- /dev/null +++ b/post-install @@ -0,0 +1,4 @@ +#!/bin/bash + +echo "Setting up symlink to pgcli" +ln -sf /usr/share/pgcli/bin/pgcli /usr/local/bin/pgcli diff --git a/post-remove b/post-remove new file mode 100644 index 0000000..1013eb4 --- /dev/null +++ b/post-remove @@ -0,0 +1,4 @@ +#!/bin/bash + +echo "Removing symlink to pgcli" +rm /usr/local/bin/pgcli diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000..4aa448d --- /dev/null +++ b/pylintrc @@ -0,0 +1,2 @@ +[MESSAGES CONTROL] +disable=missing-docstring,invalid-name \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8477d72 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[tool.black] +line-length = 88 +target-version = ['py38'] +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | \.cache + | \.pytest_cache + | _build + | buck-out + | build + | dist + | tests/data +)/ +''' diff --git a/release.py b/release.py new file mode 100644 index 0000000..42a72a9 --- /dev/null +++ b/release.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +"""A script to publish a release of pgcli to PyPI.""" + +import io +from optparse import OptionParser +import re +import subprocess +import sys + +import click + +DEBUG = False +CONFIRM_STEPS = False +DRY_RUN = False + + +def skip_step(): + """ + Asks for user's response whether to run a step. Default is yes. + :return: boolean + """ + global CONFIRM_STEPS + + if CONFIRM_STEPS: + return not click.confirm("--- Run this step?", default=True) + return False + + +def run_step(*args): + """ + Prints out the command and asks if it should be run. + If yes (default), runs it. + :param args: list of strings (command and args) + """ + global DRY_RUN + + cmd = args + print(" ".join(cmd)) + if skip_step(): + print("--- Skipping...") + elif DRY_RUN: + print("--- Pretending to run...") + else: + subprocess.check_output(cmd) + + +def version(version_file): + _version_re = re.compile( + r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)' + ) + + with io.open(version_file, encoding="utf-8") as f: + ver = _version_re.search(f.read()).group("version") + + return ver + + +def commit_for_release(version_file, ver): + run_step("git", "reset") + run_step("git", "add", "-u") + run_step("git", "commit", "--message", "Releasing version {}".format(ver)) + + +def create_git_tag(tag_name): + run_step("git", "tag", tag_name) + + +def create_distribution_files(): + run_step("python", "setup.py", "clean", "--all", "sdist", "bdist_wheel") + + +def upload_distribution_files(): + run_step("twine", "upload", "dist/*") + + +def push_to_github(): + run_step("git", "push", "origin", "main") + + +def push_tags_to_github(): + run_step("git", "push", "--tags", "origin") + + +def checklist(questions): + for question in questions: + if not click.confirm("--- {}".format(question), default=False): + sys.exit(1) + + +if __name__ == "__main__": + if DEBUG: + subprocess.check_output = lambda x: x + + checks = [ + "Have you updated the AUTHORS file?", + "Have you updated the `Usage` section of the README?", + ] + checklist(checks) + + ver = version("pgcli/__init__.py") + print("Releasing Version:", ver) + + parser = OptionParser() + parser.add_option( + "-c", + "--confirm-steps", + action="store_true", + dest="confirm_steps", + default=False, + help=( + "Confirm every step. If the step is not " "confirmed, it will be skipped." + ), + ) + parser.add_option( + "-d", + "--dry-run", + action="store_true", + dest="dry_run", + default=False, + help="Print out, but not actually run any steps.", + ) + + popts, pargs = parser.parse_args() + CONFIRM_STEPS = popts.confirm_steps + DRY_RUN = popts.dry_run + + if not click.confirm("Are you sure?", default=False): + sys.exit(1) + + commit_for_release("pgcli/__init__.py", ver) + create_git_tag("v{}".format(ver)) + create_distribution_files() + push_to_github() + push_tags_to_github() + upload_distribution_files() diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..15505a7 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,13 @@ +pytest>=2.7.0 +tox>=1.9.2 +behave>=1.2.4 +black>=23.3.0 +pexpect==3.3; platform_system != "Windows" +pre-commit>=1.16.0 +coverage>=5.0.4 +codecov>=1.5.1 +docutils>=0.13.1 +autopep8>=1.3.3 +twine>=1.11.0 +wheel>=0.33.6 +sshtunnel>=0.4.0 diff --git a/sanity_checks.txt b/sanity_checks.txt new file mode 100644 index 0000000..d8a4898 --- /dev/null +++ b/sanity_checks.txt @@ -0,0 +1,37 @@ +# vi: ft=vimwiki + +* Launch pgcli with different inputs. + * pgcli test_db + * pgcli postgres://localhost/test_db + * pgcli postgres://localhost:5432/test_db + * pgcli postgres://amjith@localhost:5432/test_db + * pgcli postgres://amjith:password@localhost:5432/test_db + * pgcli non-existent-db + +* Test special command + * \d + * \d table_name + * \dt + * \l + * \c amjith + * \q + +* Simple execution: + 1 Execute a simple 'select * from users;' test that will pass. + 2 Execute a syntax error: 'insert into users ( ;' + 3 Execute a simple test from step 1 again to see if it still passes. + * Change the database and try steps 1 - 3. + +* Test smart-completion + * Sele - Must auto-complete to SELECT + * SELECT * FROM - Must list the table names. + * INSERT INTO - Must list table names. + * \d - Must list table names. + * \c - Database names. + * SELECT * FROM table_name WHERE - column names (all of it). + +* Test naive-completion - turn off smart completion (using F2 key after launch) + * Sele - autocomplete to select. + * SELECT * FROM - autocomplete list should have everything. + + diff --git a/screenshots/image01.png b/screenshots/image01.png new file mode 100644 index 0000000..58520c5 Binary files /dev/null and b/screenshots/image01.png differ diff --git a/screenshots/image02.png b/screenshots/image02.png new file mode 100644 index 0000000..c321c86 Binary files /dev/null and b/screenshots/image02.png differ diff --git a/screenshots/kharkiv-destroyed.jpg b/screenshots/kharkiv-destroyed.jpg new file mode 100644 index 0000000..4f95783 Binary files /dev/null and b/screenshots/kharkiv-destroyed.jpg differ diff --git a/screenshots/pgcli.gif b/screenshots/pgcli.gif new file mode 100644 index 0000000..9c8e66d Binary files /dev/null and b/screenshots/pgcli.gif differ diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..9a398a4 --- /dev/null +++ b/setup.py @@ -0,0 +1,76 @@ +import platform +from setuptools import setup, find_packages + +from pgcli import __version__ + +description = "CLI for Postgres Database. With auto-completion and syntax highlighting." + +install_requirements = [ + "pgspecial>=2.0.0", + "click >= 4.1", + "Pygments>=2.0", # Pygments has to be Capitalcased. WTF? + # We still need to use pt-2 unless pt-3 released on Fedora32 + # see: https://github.com/dbcli/pgcli/pull/1197 + "prompt_toolkit>=2.0.6,<4.0.0", + "psycopg >= 3.0.14", + "sqlparse >=0.3.0,<0.5", + "configobj >= 5.0.6", + "pendulum>=2.1.0", + "cli_helpers[styles] >= 2.2.1", +] + + +# setproctitle is used to mask the password when running `ps` in command line. +# But this is not necessary in Windows since the password is never shown in the +# task manager. Also setproctitle is a hard dependency to install in Windows, +# so we'll only install it if we're not in Windows. +if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"): + install_requirements.append("setproctitle >= 1.1.9") + +# Windows will require the binary psycopg to run pgcli +if platform.system() == "Windows": + install_requirements.append("psycopg-binary >= 3.0.14") + + +setup( + name="pgcli", + author="Pgcli Core Team", + author_email="pgcli-dev@googlegroups.com", + version=__version__, + license="BSD", + url="http://pgcli.com", + packages=find_packages(), + package_data={"pgcli": ["pgclirc", "packages/pgliterals/pgliterals.json"]}, + description=description, + long_description=open("README.rst").read(), + install_requires=install_requirements, + dependency_links=[ + "http://github.com/psycopg/repo/tarball/master#egg=psycopg-3.0.10" + ], + extras_require={ + "keyring": ["keyring >= 12.2.0"], + "sshtunnel": ["sshtunnel >= 0.4.0"], + }, + python_requires=">=3.8", + entry_points=""" + [console_scripts] + pgcli=pgcli.main:cli + """, + classifiers=[ + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..33cddf2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,52 @@ +import os +import pytest +from utils import ( + POSTGRES_HOST, + POSTGRES_PORT, + POSTGRES_USER, + POSTGRES_PASSWORD, + create_db, + db_connection, + drop_tables, +) +import pgcli.pgexecute + + +@pytest.fixture(scope="function") +def connection(): + create_db("_test_db") + connection = db_connection("_test_db") + yield connection + + drop_tables(connection) + connection.close() + + +@pytest.fixture +def cursor(connection): + with connection.cursor() as cur: + return cur + + +@pytest.fixture +def executor(connection): + return pgcli.pgexecute.PGExecute( + database="_test_db", + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + dsn=None, + ) + + +@pytest.fixture +def exception_formatter(): + return lambda e: str(e) + + +@pytest.fixture(scope="session", autouse=True) +def temp_config(tmpdir_factory): + # this function runs on start of test session. + # use temporary directory for config home so user config will not be used + os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data")) diff --git a/tests/features/__init__.py b/tests/features/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/features/auto_vertical.feature b/tests/features/auto_vertical.feature new file mode 100644 index 0000000..aa95718 --- /dev/null +++ b/tests/features/auto_vertical.feature @@ -0,0 +1,12 @@ +Feature: auto_vertical mode: + on, off + + Scenario: auto_vertical on with small query + When we run dbcli with --auto-vertical-output + and we execute a small query + then we see small results in horizontal format + + Scenario: auto_vertical on with large query + When we run dbcli with --auto-vertical-output + and we execute a large query + then we see large results in vertical format diff --git a/tests/features/basic_commands.feature b/tests/features/basic_commands.feature new file mode 100644 index 0000000..ee497b9 --- /dev/null +++ b/tests/features/basic_commands.feature @@ -0,0 +1,81 @@ +Feature: run the cli, + call the help command, + exit the cli + + Scenario: run "\?" command + When we send "\?" command + then we see help output + + Scenario: run source command + When we send source command + then we see help output + + Scenario: run partial select command + When we send partial select command + then we see error message + then we see dbcli prompt + + Scenario: check our application_name + When we run query to check application_name + then we see found + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits + + Scenario: confirm exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "c" + then dbcli exits + + Scenario: cancel exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "a" + then we see dbcli prompt + when we rollback transaction + when we send "ctrl + d" + then dbcli exits + + Scenario: interrupt current query via "ctrl + c" + When we send sleep query + and we send "ctrl + c" + then we see cancelled query warning + when we check for any non-idle sleep queries + then we don't see any non-idle sleep queries + + Scenario: list databases + When we list databases + then we see list of databases + + Scenario: run the cli with --username + When we launch dbcli using --username + and we send "\?" command + then we see help output + + Scenario: run the cli with --user + When we launch dbcli using --user + and we send "\?" command + then we see help output + + Scenario: run the cli with --port + When we launch dbcli using --port + and we send "\?" command + then we see help output + + Scenario: run the cli with --password + When we launch dbcli using --password + then we send password + and we see dbcli prompt + when we send "\?" command + then we see help output + + Scenario: run the cli with dsn and password + When we launch dbcli using dsn_password + then we send password + and we see dbcli prompt + when we send "\?" command + then we see help output diff --git a/tests/features/crud_database.feature b/tests/features/crud_database.feature new file mode 100644 index 0000000..87da4e3 --- /dev/null +++ b/tests/features/crud_database.feature @@ -0,0 +1,17 @@ +Feature: manipulate databases: + create, drop, connect, disconnect + + Scenario: create and drop temporary database + When we create database + then we see database created + when we drop database + then we respond to the destructive warning: y + then we see database dropped + when we connect to dbserver + then we see database connected + + Scenario: connect and disconnect from test database + When we connect to test database + then we see database connected + when we connect to dbserver + then we see database connected diff --git a/tests/features/crud_table.feature b/tests/features/crud_table.feature new file mode 100644 index 0000000..8a43c5c --- /dev/null +++ b/tests/features/crud_table.feature @@ -0,0 +1,45 @@ +Feature: manipulate tables: + create, insert, update, select, delete from, drop + + Scenario: create, insert, select from, update, drop table + When we connect to test database + then we see database connected + when we create table + then we see table created + when we insert into table + then we see record inserted + when we select from table + then we see data selected: initial + when we update table + then we see record updated + when we select from table + then we see data selected: updated + when we delete from table + then we respond to the destructive warning: y + then we see record deleted + when we drop table + then we respond to the destructive warning: y + then we see table dropped + when we connect to dbserver + then we see database connected + + Scenario: transaction handling, with cancelling on a destructive warning. + When we connect to test database + then we see database connected + when we create table + then we see table created + when we begin transaction + then we see transaction began + when we insert into table + then we see record inserted + when we delete from table + then we respond to the destructive warning: n + when we select from table + then we see data selected: initial + when we rollback transaction + then we see transaction rolled back + when we select from table + then we see select output without data + when we drop table + then we respond to the destructive warning: y + then we see table dropped diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py new file mode 100644 index 0000000..595c6c2 --- /dev/null +++ b/tests/features/db_utils.py @@ -0,0 +1,87 @@ +from psycopg import connect + + +def create_db( + hostname="localhost", username=None, password=None, dbname=None, port=None +): + """Create test database. + + :param hostname: string + :param username: string + :param password: string + :param dbname: string + :param port: int + :return: + + """ + cn = create_cn(hostname, password, username, "postgres", port) + + cn.autocommit = True + with cn.cursor() as cr: + cr.execute(f"drop database if exists {dbname}") + cr.execute(f"create database {dbname}") + + cn.close() + + cn = create_cn(hostname, password, username, dbname, port) + return cn + + +def create_cn(hostname, password, username, dbname, port): + """ + Open connection to database. + :param hostname: + :param password: + :param username: + :param dbname: string + :return: psycopg2.connection + """ + cn = connect( + host=hostname, user=username, dbname=dbname, password=password, port=port + ) + + print(f"Created connection: {cn.info.get_parameters()}.") + return cn + + +def pgbouncer_available(hostname="localhost", password=None, username="postgres"): + cn = None + try: + cn = create_cn(hostname, password, username, "pgbouncer", 6432) + return True + except: + print("Pgbouncer is not available.") + finally: + if cn: + cn.close() + return False + + +def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None): + """ + Drop database. + :param hostname: string + :param username: string + :param password: string + :param dbname: string + """ + cn = create_cn(hostname, password, username, "postgres", port) + + # Needed for DB drop. + cn.autocommit = True + + with cn.cursor() as cr: + cr.execute(f"drop database if exists {dbname}") + + close_cn(cn) + + +def close_cn(cn=None): + """ + Close connection. + :param connection: psycopg2.connection + """ + if cn: + cn_params = cn.info.get_parameters() + cn.close() + print(f"Closed connection: {cn_params}.") diff --git a/tests/features/environment.py b/tests/features/environment.py new file mode 100644 index 0000000..50ac5fa --- /dev/null +++ b/tests/features/environment.py @@ -0,0 +1,227 @@ +import copy +import os +import sys +import db_utils as dbutils +import fixture_utils as fixutils +import pexpect +import tempfile +import shutil +import signal + + +from steps import wrappers + + +def before_all(context): + """Set env parameters.""" + env_old = copy.deepcopy(dict(os.environ)) + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["PAGER"] = "cat" + os.environ["EDITOR"] = "ex" + os.environ["VISUAL"] = "ex" + os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" + + context.package_root = os.path.abspath( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + ) + fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data") + + print("package root:", context.package_root) + print("fixture dir:", fixture_dir) + + os.environ["COVERAGE_PROCESS_START"] = os.path.join( + context.package_root, ".coveragerc" + ) + + context.exit_sent = False + + vi = "_".join([str(x) for x in sys.version_info[:3]]) + db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests") + db_name_full = f"{db_name}_{vi}" + + # Store get params from config. + context.conf = { + "host": context.config.userdata.get( + "pg_test_host", os.getenv("PGHOST", "localhost") + ), + "user": context.config.userdata.get( + "pg_test_user", os.getenv("PGUSER", "postgres") + ), + "pass": context.config.userdata.get( + "pg_test_pass", os.getenv("PGPASSWORD", None) + ), + "port": context.config.userdata.get( + "pg_test_port", os.getenv("PGPORT", "5432") + ), + "cli_command": ( + context.config.userdata.get("pg_cli_command", None) + or '{python} -c "{startup}"'.format( + python=sys.executable, + startup="; ".join( + [ + "import coverage", + "coverage.process_startup()", + "import pgcli.main", + "pgcli.main.cli(auto_envvar_prefix='BEHAVE')", + ] + ), + ) + ), + "dbname": db_name_full, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", + } + os.environ["PAGER"] = "{0} {1} {2}".format( + sys.executable, + os.path.join(context.package_root, "tests/features/wrappager.py"), + context.conf["pager_boundary"], + ) + + # Store old env vars. + context.pgenv = { + "PGDATABASE": os.environ.get("PGDATABASE", None), + "PGUSER": os.environ.get("PGUSER", None), + "PGHOST": os.environ.get("PGHOST", None), + "PGPASSWORD": os.environ.get("PGPASSWORD", None), + "PGPORT": os.environ.get("PGPORT", None), + "XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None), + "PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None), + } + + # Set new env vars. + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGUSER"] = context.conf["user"] + os.environ["PGHOST"] = context.conf["host"] + os.environ["PGPORT"] = context.conf["port"] + os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf") + + if context.conf["pass"]: + os.environ["PGPASSWORD"] = context.conf["pass"] + else: + if "PGPASSWORD" in os.environ: + del os.environ["PGPASSWORD"] + os.environ["BEHAVE_WARN"] = "moderate" + + context.cn = dbutils.create_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) + context.pgbouncer_available = dbutils.pgbouncer_available( + hostname=context.conf["host"], + password=context.conf["pass"], + username=context.conf["user"], + ) + context.fixture_data = fixutils.read_fixture_files() + + # use temporary directory as config home + context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_") + os.environ["XDG_CONFIG_HOME"] = context.env_config_home + show_env_changes(env_old, dict(os.environ)) + + +def show_env_changes(env_old, env_new): + """Print out all test-specific env values.""" + print("--- os.environ changed values: ---") + all_keys = env_old.keys() | env_new.keys() + for k in sorted(all_keys): + old_value = env_old.get(k, "") + new_value = env_new.get(k, "") + if new_value and old_value != new_value: + print(f'{k}="{new_value}"') + print("-" * 20) + + +def after_all(context): + """ + Unset env parameters. + """ + dbutils.close_cn(context.cn) + dbutils.drop_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) + + # Remove temp config directory + shutil.rmtree(context.env_config_home) + + # Restore env vars. + for k, v in context.pgenv.items(): + if k in os.environ and v is None: + del os.environ[k] + elif v: + os.environ[k] = v + + +def before_step(context, _): + context.atprompt = False + + +def is_known_problem(scenario): + """TODO: why is this not working in 3.12?""" + if sys.version_info >= (3, 12): + return scenario.name in ( + 'interrupt current query via "ctrl + c"', + "run the cli with --username", + "run the cli with --user", + "run the cli with --port", + ) + return False + + +def before_scenario(context, scenario): + if scenario.name == "list databases": + # not using the cli for that + return + if is_known_problem(scenario): + scenario.skip() + currentdb = None + if "pgbouncer" in scenario.feature.tags: + if context.pgbouncer_available: + os.environ["PGDATABASE"] = "pgbouncer" + os.environ["PGPORT"] = "6432" + currentdb = "pgbouncer" + else: + scenario.skip() + else: + # set env vars back to normal test database + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGPORT"] = context.conf["port"] + wrappers.run_cli(context, currentdb=currentdb) + wrappers.wait_prompt(context) + + +def after_scenario(context, scenario): + """Cleans up after each scenario completes.""" + if hasattr(context, "cli") and context.cli and not context.exit_sent: + # Quit nicely. + if not getattr(context, "atprompt", False): + dbname = context.currentdb + context.cli.expect_exact(f"{dbname}>", timeout=5) + try: + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") + except Exception as x: + print("Failed cleanup after scenario:") + print(x) + try: + context.cli.expect_exact(pexpect.EOF, timeout=5) + except pexpect.TIMEOUT: + print(f"--- after_scenario {scenario.name}: kill cli") + context.cli.kill(signal.SIGKILL) + if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help: + context.tmpfile_sql_help.close() + context.tmpfile_sql_help = None + + +# # TODO: uncomment to debug a failure +# def after_step(context, step): +# if step.status == "failed": +# import pdb; pdb.set_trace() diff --git a/tests/features/expanded.feature b/tests/features/expanded.feature new file mode 100644 index 0000000..e486048 --- /dev/null +++ b/tests/features/expanded.feature @@ -0,0 +1,29 @@ +Feature: expanded mode: + on, off, auto + + Scenario: expanded on + When we prepare the test data + and we set expanded on + and we select from table + then we see expanded data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped + + Scenario: expanded off + When we prepare the test data + and we set expanded off + and we select from table + then we see nonexpanded data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped + + Scenario: expanded auto + When we prepare the test data + and we set expanded auto + and we select from table + then we see auto data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped diff --git a/tests/features/fixture_data/help.txt b/tests/features/fixture_data/help.txt new file mode 100644 index 0000000..bebb976 --- /dev/null +++ b/tests/features/fixture_data/help.txt @@ -0,0 +1,25 @@ ++--------------------------+------------------------------------------------+ +| Command | Description | +|--------------------------+------------------------------------------------| +| \# | Refresh auto-completions. | +| \? | Show Help. | +| \T [format] | Change the table format used to output results | +| \c[onnect] database_name | Change to a new database. | +| \d [pattern] | List or describe tables, views and sequences. | +| \dT[S+] [pattern] | List data types | +| \df[+] [pattern] | List functions. | +| \di[+] [pattern] | List indexes. | +| \dn[+] [pattern] | List schemas. | +| \ds[+] [pattern] | List sequences. | +| \dt[+] [pattern] | List tables. | +| \du[+] [pattern] | List roles. | +| \dv[+] [pattern] | List views. | +| \e [file] | Edit the query with external editor. | +| \l | List databases. | +| \n[+] [name] | List or execute named queries. | +| \nd [name [query]] | Delete a named query. | +| \ns name query | Save a named query. | +| \refresh | Refresh auto-completions. | +| \timing | Toggle timing of commands. | +| \x | Toggle expanded output. | ++--------------------------+------------------------------------------------+ diff --git a/tests/features/fixture_data/help_commands.txt b/tests/features/fixture_data/help_commands.txt new file mode 100644 index 0000000..e076661 --- /dev/null +++ b/tests/features/fixture_data/help_commands.txt @@ -0,0 +1,64 @@ +Command +Description +\# +Refresh auto-completions. +\? +Show Commands. +\T [format] +Change the table format used to output results +\c[onnect] database_name +Change to a new database. +\copy [tablename] to/from [filename] +Copy data between a file and a table. +\d[+] [pattern] +List or describe tables, views and sequences. +\dT[S+] [pattern] +List data types +\db[+] [pattern] +List tablespaces. +\df[+] [pattern] +List functions. +\di[+] [pattern] +List indexes. +\dm[+] [pattern] +List materialized views. +\dn[+] [pattern] +List schemas. +\ds[+] [pattern] +List sequences. +\dt[+] [pattern] +List tables. +\du[+] [pattern] +List roles. +\dv[+] [pattern] +List views. +\dx[+] [pattern] +List extensions. +\e [file] +Edit the query with external editor. +\h +Show SQL syntax and help. +\i filename +Execute commands from file. +\l +List databases. +\n[+] [name] [param1 param2 ...] +List or execute named queries. +\nd [name] +Delete a named query. +\ns name query +Save a named query. +\o [filename] +Send all query results to file. +\pager [command] +Set PAGER. Print the query results via PAGER. +\pset [key] [value] +A limited version of traditional \pset +\refresh +Refresh auto-completions. +\sf[+] FUNCNAME +Show a function's definition. +\timing +Toggle timing of commands. +\x +Toggle expanded output. diff --git a/tests/features/fixture_data/mock_pg_service.conf b/tests/features/fixture_data/mock_pg_service.conf new file mode 100644 index 0000000..15f9811 --- /dev/null +++ b/tests/features/fixture_data/mock_pg_service.conf @@ -0,0 +1,4 @@ +[mock_postgres] +dbname=postgres +host=localhost +user=postgres diff --git a/tests/features/fixture_utils.py b/tests/features/fixture_utils.py new file mode 100644 index 0000000..70b603d --- /dev/null +++ b/tests/features/fixture_utils.py @@ -0,0 +1,28 @@ +import os +import codecs + + +def read_fixture_lines(filename): + """ + Read lines of text from file. + :param filename: string name + :return: list of strings + """ + lines = [] + for line in codecs.open(filename, "rb", encoding="utf-8"): + lines.append(line.strip()) + return lines + + +def read_fixture_files(): + """Read all files inside fixture_data directory.""" + current_dir = os.path.dirname(__file__) + fixture_dir = os.path.join(current_dir, "fixture_data/") + print(f"reading fixture data: {fixture_dir}") + fixture_dict = {} + for filename in os.listdir(fixture_dir): + if filename not in [".", ".."]: + fullname = os.path.join(fixture_dir, filename) + fixture_dict[filename] = read_fixture_lines(fullname) + + return fixture_dict diff --git a/tests/features/iocommands.feature b/tests/features/iocommands.feature new file mode 100644 index 0000000..dad7d10 --- /dev/null +++ b/tests/features/iocommands.feature @@ -0,0 +1,17 @@ +Feature: I/O commands + + Scenario: edit sql in file with external editor + When we start external editor providing a file name + and we type sql in the editor + and we exit the editor + then we see dbcli prompt + and we see the sql in prompt + + Scenario: tee output from query + When we tee output + and we wait for prompt + and we query "select 123456" + and we wait for prompt + and we stop teeing output + and we wait for prompt + then we see 123456 in tee output diff --git a/tests/features/named_queries.feature b/tests/features/named_queries.feature new file mode 100644 index 0000000..74201b9 --- /dev/null +++ b/tests/features/named_queries.feature @@ -0,0 +1,10 @@ +Feature: named queries: + save, use and delete named queries + + Scenario: save, use and delete named queries + When we connect to test database + then we see database connected + when we save a named query + then we see the named query saved + when we delete a named query + then we see the named query deleted diff --git a/tests/features/pgbouncer.feature b/tests/features/pgbouncer.feature new file mode 100644 index 0000000..14cc5ad --- /dev/null +++ b/tests/features/pgbouncer.feature @@ -0,0 +1,12 @@ +@pgbouncer +Feature: run pgbouncer, + call the help command, + exit the cli + + Scenario: run "show help" command + When we send "show help" command + then we see the pgbouncer help output + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits diff --git a/tests/features/specials.feature b/tests/features/specials.feature new file mode 100644 index 0000000..63c5cdc --- /dev/null +++ b/tests/features/specials.feature @@ -0,0 +1,6 @@ +Feature: Special commands + + Scenario: run refresh command + When we refresh completions + and we wait for prompt + then we see completions refresh started diff --git a/tests/features/steps/__init__.py b/tests/features/steps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/features/steps/auto_vertical.py b/tests/features/steps/auto_vertical.py new file mode 100644 index 0000000..d7cdccd --- /dev/null +++ b/tests/features/steps/auto_vertical.py @@ -0,0 +1,99 @@ +from textwrap import dedent +from behave import then, when +import wrappers + + +@when("we run dbcli with {arg}") +def step_run_cli_with_arg(context, arg): + wrappers.run_cli(context, run_args=arg.split("=")) + + +@when("we execute a small query") +def step_execute_small_query(context): + context.cli.sendline("select 1") + + +@when("we execute a large query") +def step_execute_large_query(context): + context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)]))) + + +@then("we see small results in horizontal format") +def step_see_small_results(context): + wrappers.expect_pager( + context, + dedent( + """\ + +----------+\r + | ?column? |\r + |----------|\r + | 1 |\r + +----------+\r + SELECT 1\r + """ + ), + timeout=5, + ) + + +@then("we see large results in vertical format") +def step_see_large_results(context): + wrappers.expect_pager( + context, + dedent( + """\ + -[ RECORD 1 ]-------------------------\r + ?column? | 1\r + ?column? | 2\r + ?column? | 3\r + ?column? | 4\r + ?column? | 5\r + ?column? | 6\r + ?column? | 7\r + ?column? | 8\r + ?column? | 9\r + ?column? | 10\r + ?column? | 11\r + ?column? | 12\r + ?column? | 13\r + ?column? | 14\r + ?column? | 15\r + ?column? | 16\r + ?column? | 17\r + ?column? | 18\r + ?column? | 19\r + ?column? | 20\r + ?column? | 21\r + ?column? | 22\r + ?column? | 23\r + ?column? | 24\r + ?column? | 25\r + ?column? | 26\r + ?column? | 27\r + ?column? | 28\r + ?column? | 29\r + ?column? | 30\r + ?column? | 31\r + ?column? | 32\r + ?column? | 33\r + ?column? | 34\r + ?column? | 35\r + ?column? | 36\r + ?column? | 37\r + ?column? | 38\r + ?column? | 39\r + ?column? | 40\r + ?column? | 41\r + ?column? | 42\r + ?column? | 43\r + ?column? | 44\r + ?column? | 45\r + ?column? | 46\r + ?column? | 47\r + ?column? | 48\r + ?column? | 49\r + SELECT 1\r + """ + ), + timeout=5, + ) diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py new file mode 100644 index 0000000..687bdc0 --- /dev/null +++ b/tests/features/steps/basic_commands.py @@ -0,0 +1,231 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +import pexpect +import subprocess +import tempfile + +from behave import when, then +from textwrap import dedent +import wrappers + + +@when("we list databases") +def step_list_databases(context): + cmd = ["pgcli", "--list"] + context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root) + + +@then("we see list of databases") +def step_see_list_databases(context): + assert b"List of databases" in context.cmd_output + assert b"postgres" in context.cmd_output + context.cmd_output = None + + +@when("we run dbcli") +def step_run_cli(context): + wrappers.run_cli(context) + + +@when("we launch dbcli using {arg}") +def step_run_cli_using_arg(context, arg): + prompt_check = False + currentdb = None + if arg == "--username": + arg = "--username={}".format(context.conf["user"]) + if arg == "--user": + arg = "--user={}".format(context.conf["user"]) + if arg == "--port": + arg = "--port={}".format(context.conf["port"]) + if arg == "--password": + arg = "--password" + prompt_check = False + # This uses the mock_pg_service.conf file in fixtures folder. + if arg == "dsn_password": + arg = "service=mock_postgres --password" + prompt_check = False + currentdb = "postgres" + wrappers.run_cli( + context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb + ) + + +@when("we wait for prompt") +def step_wait_prompt(context): + wrappers.wait_prompt(context) + + +@when('we send "ctrl + d"') +def step_ctrl_d(context): + """ + Send Ctrl + D to hopefully exit. + """ + step_try_to_ctrl_d(context) + context.cli.expect(pexpect.EOF, timeout=5) + context.exit_sent = True + + +@when('we try to send "ctrl + d"') +def step_try_to_ctrl_d(context): + """ + Send Ctrl + D, perhaps exiting, perhaps not (if a transaction is + ongoing). + """ + # turn off pager before exiting + context.cli.sendcontrol("c") + context.cli.sendline(r"\pset pager off") + wrappers.wait_prompt(context) + context.cli.sendcontrol("d") + + +@when('we send "ctrl + c"') +def step_ctrl_c(context): + """Send Ctrl + c to hopefully interrupt.""" + context.cli.sendcontrol("c") + + +@then("we see cancelled query warning") +def step_see_cancelled_query_warning(context): + """ + Make sure we receive the warning that the current query was cancelled. + """ + wrappers.expect_exact(context, "cancelled query", timeout=2) + + +@then("we see ongoing transaction message") +def step_see_ongoing_transaction_error(context): + """ + Make sure we receive the warning that a transaction is ongoing. + """ + context.cli.expect("A transaction is ongoing.", timeout=2) + + +@when("we send sleep query") +def step_send_sleep_15_seconds(context): + """ + Send query to sleep for 15 seconds. + """ + context.cli.sendline("select pg_sleep(15)") + + +@when("we check for any non-idle sleep queries") +def step_check_for_active_sleep_queries(context): + """ + Send query to check for any non-idle pg_sleep queries. + """ + context.cli.sendline( + "select state from pg_stat_activity where query not like '%pg_stat_activity%' and query like '%pg_sleep%' and state != 'idle';" + ) + + +@then("we don't see any non-idle sleep queries") +def step_no_active_sleep_queries(context): + """Confirm that any pg_sleep queries are either idle or not active.""" + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + + "\r" + + dedent( + """ + +-------+\r + | state |\r + |-------|\r + +-------+\r + SELECT 0\r + """ + ) + + context.conf["pager_boundary"], + timeout=5, + ) + + +@when(r'we send "\?" command') +def step_send_help(context): + r""" + Send \? to see help. + """ + context.cli.sendline(r"\?") + + +@when("we send partial select command") +def step_send_partial_select_command(context): + """ + Send `SELECT a` to see completion. + """ + context.cli.sendline("SELECT a") + + +@then("we see error message") +def step_see_error_message(context): + wrappers.expect_exact(context, 'column "a" does not exist', timeout=2) + + +@when("we send source command") +def step_send_source_command(context): + context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_") + context.tmpfile_sql_help.write(rb"\?") + context.tmpfile_sql_help.flush() + context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}") + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + + +@when("we run query to check application_name") +def step_check_application_name(context): + context.cli.sendline( + "SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;" + ) + + +@then("we see found") +def step_see_found(context): + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + + "\r" + + dedent( + """ + +----------+\r + | ?column? |\r + |----------|\r + | found |\r + +----------+\r + SELECT 1\r + """ + ) + + context.conf["pager_boundary"], + timeout=5, + ) + + +@then("we respond to the destructive warning: {response}") +def step_resppond_to_destructive_command(context, response): + """Respond to destructive command.""" + wrappers.expect_exact( + context, + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", + timeout=2, + ) + context.cli.sendline(response.strip()) + + +@then("we send password") +def step_send_password(context): + wrappers.expect_exact(context, "Password for", timeout=5) + context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER") + + +@when('we send "{text}"') +def step_send_text(context, text): + context.cli.sendline(text) + # Try to detect whether we are exiting. If so, set `exit_sent` + # so that `after_scenario` correctly cleans up. + try: + context.cli.expect(pexpect.EOF, timeout=0.2) + except pexpect.TIMEOUT: + pass + else: + context.exit_sent = True diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py new file mode 100644 index 0000000..87cdc85 --- /dev/null +++ b/tests/features/steps/crud_database.py @@ -0,0 +1,93 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" +import pexpect + +from behave import when, then +import wrappers + + +@when("we create database") +def step_db_create(context): + """ + Send create database. + """ + context.cli.sendline("create database {};".format(context.conf["dbname_tmp"])) + + context.response = {"database_name": context.conf["dbname_tmp"]} + + +@when("we drop database") +def step_db_drop(context): + """ + Send drop database. + """ + context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"])) + + +@when("we connect to test database") +def step_db_connect_test(context): + """ + Send connect to database. + """ + db_name = context.conf["dbname"] + context.cli.sendline(f"\\connect {db_name}") + + +@when("we connect to dbserver") +def step_db_connect_dbserver(context): + """ + Send connect to database. + """ + context.cli.sendline("\\connect postgres") + context.currentdb = "postgres" + + +@then("dbcli exits") +def step_wait_exit(context): + """ + Make sure the cli exits. + """ + wrappers.expect_exact(context, pexpect.EOF, timeout=5) + + +@then("we see dbcli prompt") +def step_see_prompt(context): + """ + Wait to see the prompt. + """ + db_name = getattr(context, "currentdb", context.conf["dbname"]) + wrappers.expect_exact(context, f"{db_name}>", timeout=5) + context.atprompt = True + + +@then("we see help output") +def step_see_help(context): + for expected_line in context.fixture_data["help_commands.txt"]: + wrappers.expect_exact(context, expected_line, timeout=2) + + +@then("we see database created") +def step_see_db_created(context): + """ + Wait to see create database output. + """ + wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5) + + +@then("we see database dropped") +def step_see_db_dropped(context): + """ + Wait to see drop database output. + """ + wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2) + + +@then("we see database connected") +def step_see_db_connected(context): + """ + Wait to see drop database output. + """ + wrappers.expect_exact(context, "You are now connected to database", timeout=2) diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py new file mode 100644 index 0000000..27d543e --- /dev/null +++ b/tests/features/steps/crud_table.py @@ -0,0 +1,185 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +from textwrap import dedent +import wrappers + + +INITIAL_DATA = "xxx" +UPDATED_DATA = "yyy" + + +@when("we create table") +def step_create_table(context): + """ + Send create table. + """ + context.cli.sendline("create table a(x text);") + + +@when("we insert into table") +def step_insert_into_table(context): + """ + Send insert into table. + """ + context.cli.sendline(f"""insert into a(x) values('{INITIAL_DATA}');""") + + +@when("we update table") +def step_update_table(context): + """ + Send insert into table. + """ + context.cli.sendline( + f"""update a set x = '{UPDATED_DATA}' where x = '{INITIAL_DATA}';""" + ) + + +@when("we select from table") +def step_select_from_table(context): + """ + Send select from table. + """ + context.cli.sendline("select * from a;") + + +@when("we delete from table") +def step_delete_from_table(context): + """ + Send deete from table. + """ + context.cli.sendline(f"""delete from a where x = '{UPDATED_DATA}';""") + + +@when("we drop table") +def step_drop_table(context): + """ + Send drop table. + """ + context.cli.sendline("drop table a;") + + +@when("we alter the table") +def step_alter_table(context): + """ + Alter the table by adding a column. + """ + context.cli.sendline("""alter table a add column y varchar;""") + + +@when("we begin transaction") +def step_begin_transaction(context): + """ + Begin transaction + """ + context.cli.sendline("begin;") + + +@when("we rollback transaction") +def step_rollback_transaction(context): + """ + Rollback transaction + """ + context.cli.sendline("rollback;") + + +@then("we see table created") +def step_see_table_created(context): + """ + Wait to see create table output. + """ + wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2) + + +@then("we see record inserted") +def step_see_record_inserted(context): + """ + Wait to see insert output. + """ + wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2) + + +@then("we see record updated") +def step_see_record_updated(context): + """ + Wait to see update output. + """ + wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2) + + +@then("we see data selected: {data}") +def step_see_data_selected(context, data): + """ + Wait to see select output with initial or updated data. + """ + x = UPDATED_DATA if data == "updated" else INITIAL_DATA + wrappers.expect_pager( + context, + dedent( + f"""\ + +-----+\r + | x |\r + |-----|\r + | {x} |\r + +-----+\r + SELECT 1\r + """ + ), + timeout=1, + ) + + +@then("we see select output without data") +def step_see_no_data_selected(context): + """ + Wait to see select output without data. + """ + wrappers.expect_pager( + context, + dedent( + """\ + +---+\r + | x |\r + |---|\r + +---+\r + SELECT 0\r + """ + ), + timeout=1, + ) + + +@then("we see record deleted") +def step_see_data_deleted(context): + """ + Wait to see delete output. + """ + wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2) + + +@then("we see table dropped") +def step_see_table_dropped(context): + """ + Wait to see drop output. + """ + wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2) + + +@then("we see transaction began") +def step_see_transaction_began(context): + """ + Wait to see transaction began. + """ + wrappers.expect_pager(context, "BEGIN\r\n", timeout=2) + + +@then("we see transaction rolled back") +def step_see_transaction_rolled_back(context): + """ + Wait to see transaction rollback. + """ + wrappers.expect_pager(context, "ROLLBACK\r\n", timeout=2) diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py new file mode 100644 index 0000000..302cab9 --- /dev/null +++ b/tests/features/steps/expanded.py @@ -0,0 +1,70 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +from behave import when, then +from textwrap import dedent +import wrappers + + +@when("we prepare the test data") +def step_prepare_data(context): + """Create table, insert a record.""" + context.cli.sendline("drop table if exists a;") + wrappers.expect_exact( + context, + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", + timeout=2, + ) + context.cli.sendline("y") + + wrappers.wait_prompt(context) + context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));") + wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2) + context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""") + wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2) + + +@when("we set expanded {mode}") +def step_set_expanded(context, mode): + """Set expanded to mode.""" + context.cli.sendline("\\" + f"x {mode}") + wrappers.expect_exact(context, "Expanded display is", timeout=2) + wrappers.wait_prompt(context) + + +@then("we see {which} data selected") +def step_see_data(context, which): + """Select data from expanded test table.""" + if which == "expanded": + wrappers.expect_pager( + context, + dedent( + """\ + -[ RECORD 1 ]-------------------------\r + x | 1\r + y | 1.0\r + z | 1.0000\r + SELECT 1\r + """ + ), + timeout=1, + ) + else: + wrappers.expect_pager( + context, + dedent( + """\ + +---+-----+--------+\r + | x | y | z |\r + |---+-----+--------|\r + | 1 | 1.0 | 1.0000 |\r + +---+-----+--------+\r + SELECT 1\r + """ + ), + timeout=1, + ) diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py new file mode 100644 index 0000000..a614490 --- /dev/null +++ b/tests/features/steps/iocommands.py @@ -0,0 +1,80 @@ +import os +import os.path + +from behave import when, then +import wrappers + + +@when("we start external editor providing a file name") +def step_edit_file(context): + """Edit file with external editor.""" + context.editor_file_name = os.path.join( + context.package_root, "test_file_{0}.sql".format(context.conf["vi"]) + ) + if os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name))) + wrappers.expect_exact( + context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2 + ) + wrappers.expect_exact(context, ":", timeout=2) + + +@when("we type sql in the editor") +def step_edit_type_sql(context): + context.cli.sendline("i") + context.cli.sendline("select * from abc") + context.cli.sendline(".") + wrappers.expect_exact(context, ":", timeout=2) + + +@when("we exit the editor") +def step_edit_quit(context): + context.cli.sendline("x") + wrappers.expect_exact(context, "written", timeout=2) + + +@then("we see the sql in prompt") +def step_edit_done_sql(context): + for match in "select * from abc".split(" "): + wrappers.expect_exact(context, match, timeout=1) + # Cleanup the command line. + context.cli.sendcontrol("c") + # Cleanup the edited file. + if context.editor_file_name and os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.atprompt = True + + +@when("we tee output") +def step_tee_ouptut(context): + context.tee_file_name = os.path.join( + context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]) + ) + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name))) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + wrappers.expect_exact(context, "Writing to file", timeout=5) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + wrappers.expect_exact(context, "Time", timeout=5) + + +@when('we query "select 123456"') +def step_query_select_123456(context): + context.cli.sendline("select 123456") + + +@when("we stop teeing output") +def step_notee_output(context): + context.cli.sendline(r"\o") + wrappers.expect_exact(context, "Time", timeout=5) + + +@then("we see 123456 in tee output") +def step_see_123456_in_ouput(context): + with open(context.tee_file_name) as f: + assert "123456" in f.read() + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.atprompt = True diff --git a/tests/features/steps/named_queries.py b/tests/features/steps/named_queries.py new file mode 100644 index 0000000..3f52859 --- /dev/null +++ b/tests/features/steps/named_queries.py @@ -0,0 +1,57 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when("we save a named query") +def step_save_named_query(context): + """ + Send \ns command + """ + context.cli.sendline("\\ns foo SELECT 12345") + + +@when("we use a named query") +def step_use_named_query(context): + """ + Send \n command + """ + context.cli.sendline("\\n foo") + + +@when("we delete a named query") +def step_delete_named_query(context): + """ + Send \nd command + """ + context.cli.sendline("\\nd foo") + + +@then("we see the named query saved") +def step_see_named_query_saved(context): + """ + Wait to see query saved. + """ + wrappers.expect_exact(context, "Saved.", timeout=2) + + +@then("we see the named query executed") +def step_see_named_query_executed(context): + """ + Wait to see select output. + """ + wrappers.expect_exact(context, "12345", timeout=1) + wrappers.expect_exact(context, "SELECT 1", timeout=1) + + +@then("we see the named query deleted") +def step_see_named_query_deleted(context): + """ + Wait to see query deleted. + """ + wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1) diff --git a/tests/features/steps/pgbouncer.py b/tests/features/steps/pgbouncer.py new file mode 100644 index 0000000..f156982 --- /dev/null +++ b/tests/features/steps/pgbouncer.py @@ -0,0 +1,22 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when('we send "show help" command') +def step_send_help_command(context): + context.cli.sendline("show help") + + +@then("we see the pgbouncer help output") +def see_pgbouncer_help(context): + wrappers.expect_exact( + context, + "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION", + timeout=3, + ) diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py new file mode 100644 index 0000000..a85f371 --- /dev/null +++ b/tests/features/steps/specials.py @@ -0,0 +1,31 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when("we refresh completions") +def step_refresh_completions(context): + """ + Send refresh command. + """ + context.cli.sendline("\\refresh") + + +@then("we see completions refresh started") +def step_see_refresh_started(context): + """ + Wait to see refresh output. + """ + wrappers.expect_pager( + context, + [ + "Auto-completion refresh started in the background.\r\n", + "Auto-completion refresh restarted.\r\n", + ], + timeout=2, + ) diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py new file mode 100644 index 0000000..3ebcc92 --- /dev/null +++ b/tests/features/steps/wrappers.py @@ -0,0 +1,71 @@ +import re +import pexpect +from pgcli.main import COLOR_CODE_REGEX +import textwrap + +from io import StringIO + + +def expect_exact(context, expected, timeout): + timedout = False + try: + context.cli.expect_exact(expected, timeout=timeout) + except pexpect.TIMEOUT: + timedout = True + if timedout: + # Strip color codes out of the output. + actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before) + raise Exception( + textwrap.dedent( + """\ + Expected: + --- + {0!r} + --- + Actual: + --- + {1!r} + --- + Full log: + --- + {2!r} + --- + """ + ).format(expected, actual, context.logfile.getvalue()) + ) + + +def expect_pager(context, expected, timeout): + formatted = expected if isinstance(expected, list) else [expected] + formatted = [ + f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n" + for t in formatted + ] + + expect_exact( + context, + formatted, + timeout=timeout, + ) + + +def run_cli(context, run_args=None, prompt_check=True, currentdb=None): + """Run the process using pexpect.""" + run_args = run_args or [] + cli_cmd = context.conf.get("cli_command") + cmd_parts = [cli_cmd] + run_args + cmd = " ".join(cmd_parts) + context.cli = pexpect.spawnu(cmd, cwd=context.package_root) + context.logfile = StringIO() + context.cli.logfile = context.logfile + context.exit_sent = False + context.currentdb = currentdb or context.conf["dbname"] + context.cli.sendline(r"\pset pager always") + if prompt_check: + wait_prompt(context) + + +def wait_prompt(context): + """Make sure prompt is displayed.""" + prompt_str = "{0}>".format(context.currentdb) + expect_exact(context, [prompt_str + " ", prompt_str, pexpect.EOF], timeout=3) diff --git a/tests/features/wrappager.py b/tests/features/wrappager.py new file mode 100644 index 0000000..51d4909 --- /dev/null +++ b/tests/features/wrappager.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +import sys + + +def wrappager(boundary): + print(boundary) + while 1: + buf = sys.stdin.read(2048) + if not buf: + break + sys.stdout.write(buf) + print(boundary) + + +if __name__ == "__main__": + wrappager(sys.argv[1]) diff --git a/tests/formatter/__init__.py b/tests/formatter/__init__.py new file mode 100644 index 0000000..9bad579 --- /dev/null +++ b/tests/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py new file mode 100644 index 0000000..016ed95 --- /dev/null +++ b/tests/formatter/test_sqlformatter.py @@ -0,0 +1,111 @@ +# coding=utf-8 + +from pgcli.packages.formatter.sqlformatter import escape_for_sql_statement + +from cli_helpers.tabular_output import TabularOutputFormatter +from pgcli.packages.formatter.sqlformatter import adapter, register_new_formatter + + +def test_escape_for_sql_statement_bytes(): + bts = b"837124ab3e8dc0f" + escaped_bytes = escape_for_sql_statement(bts) + assert escaped_bytes == "X'383337313234616233653864633066'" + + +def test_escape_for_sql_statement_number(): + num = 2981 + escaped_bytes = escape_for_sql_statement(num) + assert escaped_bytes == "'2981'" + + +def test_escape_for_sql_statement_str(): + example_str = "example str" + escaped_bytes = escape_for_sql_statement(example_str) + assert escaped_bytes == "'example str'" + + +def test_output_sql_insert(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + None, + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-insert" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = [l for l in output] + expected = [ + 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES', + " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, " + + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')", + ";", + ] + assert expected == output_list + + +def test_output_sql_update(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + "", + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-update" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = [l for l in output] + print(output_list) + expected = [ + 'UPDATE "user" SET', + " \"name\" = 'Jackson'", + ", \"email\" = 'jackson_test@gmail.com'", + ", \"phone\" = '132454789'", + ", \"description\" = ''", + ", \"created_at\" = '2022-09-09 19:44:32.712343+08'", + ", \"updated_at\" = '2022-09-09 19:44:32.712343+08'", + "WHERE \"id\" = '1';", + ] + assert expected == output_list diff --git a/tests/metadata.py b/tests/metadata.py new file mode 100644 index 0000000..4ebcccd --- /dev/null +++ b/tests/metadata.py @@ -0,0 +1,255 @@ +from functools import partial +from itertools import product +from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +from unittest.mock import Mock +import pytest + +parametrize = pytest.mark.parametrize + +qual = ["if_more_than_one_table", "always"] +no_qual = ["if_more_than_one_table", "never"] + + +def escape(name): + if not name.islower() or name in ("select", "localtimestamp"): + return '"' + name + '"' + return name + + +def completion(display_meta, text, pos=0): + return Completion(text, start_position=pos, display_meta=display_meta) + + +def function(text, pos=0, display=None): + return Completion( + text, display=display or text, start_position=pos, display_meta="function" + ) + + +def get_result(completer, text, position=None): + position = len(text) if position is None else position + return completer.get_completions( + Document(text=text, cursor_position=position), Mock() + ) + + +def result_set(completer, text, position=None): + return set(get_result(completer, text, position)) + + +# The code below is quivalent to +# def schema(text, pos=0): +# return completion('schema', text, pos) +# and so on +schema = partial(completion, "schema") +table = partial(completion, "table") +view = partial(completion, "view") +column = partial(completion, "column") +keyword = partial(completion, "keyword") +datatype = partial(completion, "datatype") +alias = partial(completion, "table alias") +name_join = partial(completion, "name join") +fk_join = partial(completion, "fk join") +join = partial(completion, "join") + + +def wildcard_expansion(cols, pos=-1): + return Completion(cols, start_position=pos, display_meta="columns", display="*") + + +class MetaData: + def __init__(self, metadata): + self.metadata = metadata + + def builtin_functions(self, pos=0): + return [function(f, pos) for f in self.completer.functions] + + def builtin_datatypes(self, pos=0): + return [datatype(dt, pos) for dt in self.completer.datatypes] + + def keywords(self, pos=0): + return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()] + + def specials(self, pos=0): + return [ + Completion(text=k, start_position=pos, display_meta=v.description) + for k, v in self.completer.pgspecial.commands.items() + ] + + def columns(self, tbl, parent="public", typ="tables", pos=0): + if typ == "functions": + fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0] + cols = fun[1] + else: + cols = self.metadata[typ][parent][tbl] + return [column(escape(col), pos) for col in cols] + + def datatypes(self, parent="public", pos=0): + return [ + datatype(escape(x), pos) + for x in self.metadata.get("datatypes", {}).get(parent, []) + ] + + def tables(self, parent="public", pos=0): + return [ + table(escape(x), pos) + for x in self.metadata.get("tables", {}).get(parent, []) + ] + + def views(self, parent="public", pos=0): + return [ + view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, []) + ] + + def functions(self, parent="public", pos=0): + return [ + function( + escape(x[0]) + + "(" + + ", ".join( + arg_name + " := " + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ("b", "i") + ) + + ")", + pos, + escape(x[0]) + + "(" + + ", ".join( + arg_name + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ("b", "i") + ) + + ")", + ) + for x in self.metadata.get("functions", {}).get(parent, []) + ] + + def schemas(self, pos=0): + schemas = {sch for schs in self.metadata.values() for sch in schs} + return [schema(escape(s), pos=pos) for s in schemas] + + def functions_and_keywords(self, parent="public", pos=0): + return ( + self.functions(parent, pos) + + self.builtin_functions(pos) + + self.keywords(pos) + ) + + # Note that the filtering parameters here only apply to the columns + def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0): + return self.functions_and_keywords(pos=pos) + self.columns( + tbl, parent, typ, pos + ) + + def from_clause_items(self, parent="public", pos=0): + return ( + self.functions(parent, pos) + + self.views(parent, pos) + + self.tables(parent, pos) + ) + + def schemas_and_from_clause_items(self, parent="public", pos=0): + return self.from_clause_items(parent, pos) + self.schemas(pos) + + def types(self, parent="public", pos=0): + return self.datatypes(parent, pos) + self.tables(parent, pos) + + @property + def completer(self): + return self.get_completer() + + def get_completers(self, casing): + """ + Returns a function taking three bools `casing`, `filtr`, `aliasing` and + the list `qualify`, all defaulting to None. + Returns a list of completers. + These parameters specify the allowed values for the corresponding + completer parameters, `None` meaning any, i.e. (None, None, None, None) + results in all 24 possible completers, whereas e.g. + (True, False, True, ['never']) results in the one completer with + casing, without `search_path` filtering of objects, with table + aliasing, and without column qualification. + """ + + def _cfg(_casing, filtr, aliasing, qualify): + cfg = {"settings": {}} + if _casing: + cfg["casing"] = casing + cfg["settings"]["search_path_filter"] = filtr + cfg["settings"]["generate_aliases"] = aliasing + cfg["settings"]["qualify_columns"] = qualify + return cfg + + def _cfgs(casing, filtr, aliasing, qualify): + casings = [True, False] if casing is None else [casing] + filtrs = [True, False] if filtr is None else [filtr] + aliases = [True, False] if aliasing is None else [aliasing] + qualifys = qualify or ["always", "if_more_than_one_table", "never"] + return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)] + + def completers(casing=None, filtr=None, aliasing=None, qualify=None): + get_comp = self.get_completer + return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)] + + return completers + + def _make_col(self, sch, tbl, col): + defaults = self.metadata.get("defaults", {}).get(sch, {}) + return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col))) + + def get_completer(self, settings=None, casing=None): + metadata = self.metadata + from pgcli.pgcompleter import PGCompleter + from pgspecial import PGSpecial + + comp = PGCompleter( + smart_completion=True, settings=settings, pgspecial=PGSpecial() + ) + + schemata, tables, tbl_cols, views, view_cols = [], [], [], [], [] + + for sch, tbls in metadata["tables"].items(): + schemata.append(sch) + + for tbl, cols in tbls.items(): + tables.append((sch, tbl)) + # Let all columns be text columns + tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols]) + + for sch, tbls in metadata.get("views", {}).items(): + for tbl, cols in tbls.items(): + views.append((sch, tbl)) + # Let all columns be text columns + view_cols.extend([self._make_col(sch, tbl, col) for col in cols]) + + functions = [ + FunctionMetadata(sch, *func_meta, arg_defaults=None) + for sch, funcs in metadata["functions"].items() + for func_meta in funcs + ] + + datatypes = [ + (sch, typ) + for sch, datatypes in metadata["datatypes"].items() + for typ in datatypes + ] + + foreignkeys = [ + ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks + ] + + comp.extend_schemata(schemata) + comp.extend_relations(tables, kind="tables") + comp.extend_relations(views, kind="views") + comp.extend_columns(tbl_cols, kind="tables") + comp.extend_columns(view_cols, kind="views") + comp.extend_functions(functions) + comp.extend_datatypes(datatypes) + comp.extend_foreignkeys(foreignkeys) + comp.set_search_path(["public"]) + comp.extend_casing(casing or []) + + return comp diff --git a/tests/parseutils/test_ctes.py b/tests/parseutils/test_ctes.py new file mode 100644 index 0000000..3e89cca --- /dev/null +++ b/tests/parseutils/test_ctes.py @@ -0,0 +1,137 @@ +import pytest +from sqlparse import parse +from pgcli.packages.parseutils.ctes import ( + token_start_pos, + extract_ctes, + extract_column_names as _extract_column_names, +) + + +def extract_column_names(sql): + p = parse(sql)[0] + return _extract_column_names(p) + + +def test_token_str_pos(): + sql = "SELECT * FROM xxx" + p = parse(sql)[0] + idx = p.token_index(p.tokens[-1]) + assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ") + + sql = "SELECT * FROM \nxxx" + p = parse(sql)[0] + idx = p.token_index(p.tokens[-1]) + assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n") + + +def test_single_column_name_extraction(): + sql = "SELECT abc FROM xxx" + assert extract_column_names(sql) == ("abc",) + + +def test_aliased_single_column_name_extraction(): + sql = "SELECT abc def FROM xxx" + assert extract_column_names(sql) == ("def",) + + +def test_aliased_expression_name_extraction(): + sql = "SELECT 99 abc FROM xxx" + assert extract_column_names(sql) == ("abc",) + + +def test_multiple_column_name_extraction(): + sql = "SELECT abc, def FROM xxx" + assert extract_column_names(sql) == ("abc", "def") + + +def test_missing_column_name_handled_gracefully(): + sql = "SELECT abc, 99 FROM xxx" + assert extract_column_names(sql) == ("abc",) + + sql = "SELECT abc, 99, def FROM xxx" + assert extract_column_names(sql) == ("abc", "def") + + +def test_aliased_multiple_column_name_extraction(): + sql = "SELECT abc def, ghi jkl FROM xxx" + assert extract_column_names(sql) == ("def", "jkl") + + +def test_table_qualified_column_name_extraction(): + sql = "SELECT abc.def, ghi.jkl FROM xxx" + assert extract_column_names(sql) == ("def", "jkl") + + +@pytest.mark.parametrize( + "sql", + [ + "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y", + "DELETE FROM foo WHERE x > y RETURNING x, y", + "UPDATE foo SET x = 9 RETURNING x, y", + ], +) +def test_extract_column_names_from_returning_clause(sql): + assert extract_column_names(sql) == ("x", "y") + + +def test_simple_cte_extraction(): + sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a" + start_pos = len("WITH a AS ") + stop_pos = len("WITH a AS (SELECT abc FROM xxx)") + ctes, remainder = extract_ctes(sql) + + assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),) + assert remainder.strip() == "SELECT * FROM a" + + +def test_cte_extraction_around_comments(): + sql = """--blah blah blah + WITH a AS (SELECT abc def FROM x) + SELECT * FROM a""" + start_pos = len( + """--blah blah blah + WITH a AS """ + ) + stop_pos = len( + """--blah blah blah + WITH a AS (SELECT abc def FROM x)""" + ) + + ctes, remainder = extract_ctes(sql) + assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),) + assert remainder.strip() == "SELECT * FROM a" + + +def test_multiple_cte_extraction(): + sql = """WITH + x AS (SELECT abc, def FROM x), + y AS (SELECT ghi, jkl FROM y) + SELECT * FROM a, b""" + + start1 = len( + """WITH + x AS """ + ) + + stop1 = len( + """WITH + x AS (SELECT abc, def FROM x)""" + ) + + start2 = len( + """WITH + x AS (SELECT abc, def FROM x), + y AS """ + ) + + stop2 = len( + """WITH + x AS (SELECT abc, def FROM x), + y AS (SELECT ghi, jkl FROM y)""" + ) + + ctes, remainder = extract_ctes(sql) + assert tuple(ctes) == ( + ("x", ("abc", "def"), start1, stop1), + ("y", ("ghi", "jkl"), start2, stop2), + ) diff --git a/tests/parseutils/test_function_metadata.py b/tests/parseutils/test_function_metadata.py new file mode 100644 index 0000000..0350e2a --- /dev/null +++ b/tests/parseutils/test_function_metadata.py @@ -0,0 +1,19 @@ +from pgcli.packages.parseutils.meta import FunctionMetadata + + +def test_function_metadata_eq(): + f1 = FunctionMetadata( + "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + f2 = FunctionMetadata( + "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + f3 = FunctionMetadata( + "s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + assert f1 == f2 + assert f1 != f3 + assert not (f1 != f2) + assert not (f1 == f3) + assert hash(f1) == hash(f2) + assert hash(f1) != hash(f3) diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py new file mode 100644 index 0000000..349cbd0 --- /dev/null +++ b/tests/parseutils/test_parseutils.py @@ -0,0 +1,310 @@ +import pytest +from pgcli.packages.parseutils import ( + is_destructive, + parse_destructive_warning, + BASE_KEYWORDS, + ALL_KEYWORDS, +) +from pgcli.packages.parseutils.tables import extract_tables +from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote + + +def test_empty_string(): + tables = extract_tables("") + assert tables == () + + +def test_simple_select_single_table(): + tables = extract_tables("select * from abc") + assert tables == ((None, "abc", None, False),) + + +@pytest.mark.parametrize( + "sql", ['select * from "abc"."def"', 'select * from abc."def"'] +) +def test_simple_select_single_table_schema_qualified_quoted_table(sql): + tables = extract_tables(sql) + assert tables == (("abc", "def", '"def"', False),) + + +@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def']) +def test_simple_select_single_table_schema_qualified(sql): + tables = extract_tables(sql) + assert tables == (("abc", "def", None, False),) + + +def test_simple_select_single_table_double_quoted(): + tables = extract_tables('select * from "Abc"') + assert tables == ((None, "Abc", None, False),) + + +def test_simple_select_multiple_tables(): + tables = extract_tables("select * from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_simple_select_multiple_tables_double_quoted(): + tables = extract_tables('select * from "Abc", "Def"') + assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)} + + +def test_simple_select_single_table_deouble_quoted_aliased(): + tables = extract_tables('select * from "Abc" a') + assert tables == ((None, "Abc", "a", False),) + + +def test_simple_select_multiple_tables_deouble_quoted_aliased(): + tables = extract_tables('select * from "Abc" a, "Def" d') + assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)} + + +def test_simple_select_multiple_tables_schema_qualified(): + tables = extract_tables("select * from abc.def, ghi.jkl") + assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)} + + +def test_simple_select_with_cols_single_table(): + tables = extract_tables("select a,b from abc") + assert tables == ((None, "abc", None, False),) + + +def test_simple_select_with_cols_single_table_schema_qualified(): + tables = extract_tables("select a,b from abc.def") + assert tables == (("abc", "def", None, False),) + + +def test_simple_select_with_cols_multiple_tables(): + tables = extract_tables("select a,b from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_simple_select_with_cols_multiple_qualified_tables(): + tables = extract_tables("select a,b from abc.def, def.ghi") + assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)} + + +def test_select_with_hanging_comma_single_table(): + tables = extract_tables("select a, from abc") + assert tables == ((None, "abc", None, False),) + + +def test_select_with_hanging_comma_multiple_tables(): + tables = extract_tables("select a, from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_select_with_hanging_period_multiple_tables(): + tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2") + assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)} + + +def test_simple_insert_single_table(): + tables = extract_tables('insert into abc (id, name) values (1, "def")') + + # sqlparse mistakenly assigns an alias to the table + # AND mistakenly identifies the field list as + # assert tables == ((None, 'abc', 'abc', False),) + + assert tables == ((None, "abc", "abc", False),) + + +@pytest.mark.xfail +def test_simple_insert_single_table_schema_qualified(): + tables = extract_tables('insert into abc.def (id, name) values (1, "def")') + assert tables == (("abc", "def", None, False),) + + +def test_simple_update_table_no_schema(): + tables = extract_tables("update abc set id = 1") + assert tables == ((None, "abc", None, False),) + + +def test_simple_update_table_with_schema(): + tables = extract_tables("update abc.def set id = 1") + assert tables == (("abc", "def", None, False),) + + +@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) +def test_join_table(join_type): + sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num" + tables = extract_tables(sql) + assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)} + + +def test_join_table_schema_qualified(): + tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num") + assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)} + + +def test_incomplete_join_clause(): + sql = """select a.x, b.y + from abc a join bcd b + on a.id = """ + tables = extract_tables(sql) + assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False)) + + +def test_join_as_table(): + tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5") + assert tables == ((None, "my_table", "m", False),) + + +def test_multiple_joins(): + sql = """select * from t1 + inner join t2 ON + t1.id = t2.t1_id + inner join t3 ON + t2.id = t3.""" + tables = extract_tables(sql) + assert tables == ( + (None, "t1", None, False), + (None, "t2", None, False), + (None, "t3", None, False), + ) + + +def test_subselect_tables(): + sql = "SELECT * FROM (SELECT FROM abc" + tables = extract_tables(sql) + assert tables == ((None, "abc", None, False),) + + +@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"]) +def test_extract_no_tables(text): + tables = extract_tables(text) + assert tables == tuple() + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo({arg_list})") + assert tables == ((None, "foo", None, True),) + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_schema_qualified_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})") + assert tables == (("foo", "bar", None, True),) + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_aliased_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar") + assert tables == ((None, "foo", "bar", True),) + + +def test_simple_table_and_function(): + tables = extract_tables("SELECT * FROM foo JOIN bar()") + assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)} + + +def test_complex_table_and_function(): + tables = extract_tables( + """SELECT * FROM foo.bar baz + JOIN bar.qux(x, y, z) quux""" + ) + assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)} + + +def test_find_prev_keyword_using(): + q = "select * from tbl1 inner join tbl2 using (col1, " + kw, q2 = find_prev_keyword(q) + assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using (" + + +@pytest.mark.parametrize( + "sql", + [ + "select * from foo where bar", + "select * from foo where bar = 1 and baz or ", + "select * from foo where bar = 1 and baz between qux and ", + ], +) +def test_find_prev_keyword_where(sql): + kw, stripped = find_prev_keyword(sql) + assert kw.value == "where" and stripped == "select * from foo where" + + +@pytest.mark.parametrize( + "sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "] +) +def test_find_prev_keyword_open_parens(sql): + kw, _ = find_prev_keyword(sql) + assert kw.value == "(" + + +@pytest.mark.parametrize( + "sql", + [ + "", + "$$ foo $$", + "$$ 'foo' $$", + '$$ "foo" $$', + "$$ $a$ $$", + "$a$ $$ $a$", + "foo bar $$ baz $$", + ], +) +def test_is_open_quote__closed(sql): + assert not is_open_quote(sql) + + +@pytest.mark.parametrize( + "sql", + [ + "$$", + ";;;$$", + "foo $$ bar $$; foo $$", + "$$ foo $a$", + "foo 'bar baz", + "$a$ foo ", + '$$ "foo" ', + "$$ $a$ ", + "foo bar $$ baz", + ], +) +def test_is_open_quote__open(sql): + assert is_open_quote(sql) + + +@pytest.mark.parametrize( + ("sql", "keywords", "expected"), + [ + ("update abc set x = 1", ALL_KEYWORDS, True), + ("update abc set x = 1 where y = 2", ALL_KEYWORDS, True), + ("update abc set x = 1", BASE_KEYWORDS, True), + ("update abc set x = 1 where y = 2", BASE_KEYWORDS, False), + ("select x, y, z from abc", ALL_KEYWORDS, False), + ("drop abc", ALL_KEYWORDS, True), + ("alter abc", ALL_KEYWORDS, True), + ("delete abc", ALL_KEYWORDS, True), + ("truncate abc", ALL_KEYWORDS, True), + ("insert into abc values (1, 2, 3)", ALL_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", BASE_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ], +) +def test_is_destructive(sql, keywords, expected): + assert is_destructive(sql, keywords) == expected + + +@pytest.mark.parametrize( + ("warning_level", "expected"), + [ + ("true", ALL_KEYWORDS), + ("false", []), + ("all", ALL_KEYWORDS), + ("moderate", BASE_KEYWORDS), + ("off", []), + ("", []), + (None, []), + (ALL_KEYWORDS, ALL_KEYWORDS), + (BASE_KEYWORDS, BASE_KEYWORDS), + ("insert", ["insert"]), + ("drop,alter,delete", ["drop", "alter", "delete"]), + (["drop", "alter", "delete"], ["drop", "alter", "delete"]), + ], +) +def test_parse_destructive_warning(warning_level, expected): + assert parse_destructive_warning(warning_level) == expected diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 0000000..f787740 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts=--capture=sys --showlocals \ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..a517a89 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,40 @@ +import pytest +from unittest import mock +from pgcli import auth + + +@pytest.mark.parametrize("enabled,call_count", [(True, 1), (False, 0)]) +def test_keyring_initialize(enabled, call_count): + logger = mock.MagicMock() + + with mock.patch("importlib.import_module", return_value=True) as import_method: + auth.keyring_initialize(enabled, logger=logger) + assert import_method.call_count == call_count + + +def test_keyring_get_password_ok(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.get_password", return_value="abc123"): + assert auth.keyring_get_password("test") == "abc123" + + +def test_keyring_get_password_exception(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch( + "pgcli.auth.keyring.get_password", side_effect=Exception("Boom!") + ): + assert auth.keyring_get_password("test") == "" + + +def test_keyring_set_password_ok(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.set_password"): + auth.keyring_set_password("test", "abc123") + + +def test_keyring_set_password_exception(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch( + "pgcli.auth.keyring.set_password", side_effect=Exception("Boom!") + ): + auth.keyring_set_password("test", "abc123") diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py new file mode 100644 index 0000000..a5529d6 --- /dev/null +++ b/tests/test_completion_refresher.py @@ -0,0 +1,95 @@ +import time +import pytest +from unittest.mock import Mock, patch + + +@pytest.fixture +def refresher(): + from pgcli.completion_refresher import CompletionRefresher + + return CompletionRefresher() + + +def test_ctor(refresher): + """ + Refresher object should contain a few handlers + :param refresher: + :return: + """ + assert len(refresher.refreshers) > 0 + actual_handlers = list(refresher.refreshers.keys()) + expected_handlers = [ + "schemata", + "tables", + "views", + "types", + "databases", + "casing", + "functions", + ] + assert expected_handlers == actual_handlers + + +def test_refresh_called_once(refresher): + """ + + :param refresher: + :return: + """ + callbacks = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + special = Mock() + + with patch.object(refresher, "_bg_refresh") as bg_refresh: + actual = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual) == 1 + assert len(actual[0]) == 4 + assert actual[0][3] == "Auto-completion refresh started in the background." + bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None) + + +def test_refresh_called_twice(refresher): + """ + If refresh is called a second time, it should be restarted + :param refresher: + :return: + """ + callbacks = Mock() + + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + special = Mock() + + def dummy_bg_refresh(*args): + time.sleep(3) # seconds + + refresher._bg_refresh = dummy_bg_refresh + + actual1 = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual1) == 1 + assert len(actual1[0]) == 4 + assert actual1[0][3] == "Auto-completion refresh started in the background." + + actual2 = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual2) == 1 + assert len(actual2[0]) == 4 + assert actual2[0][3] == "Auto-completion refresh restarted." + + +def test_refresh_with_callbacks(refresher): + """ + Callbacks must be called + :param refresher: + """ + callbacks = [Mock()] + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + pgexecute.extra_args = {} + special = Mock() + + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert callbacks[0].call_count == 1 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..08fe74e --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,43 @@ +import io +import os +import stat + +import pytest + +from pgcli.config import ensure_dir_exists, skip_initial_comment + + +def test_ensure_file_parent(tmpdir): + subdir = tmpdir.join("subdir") + rcfile = subdir.join("rcfile") + ensure_dir_exists(str(rcfile)) + + +def test_ensure_existing_dir(tmpdir): + rcfile = str(tmpdir.mkdir("subdir").join("rcfile")) + + # should just not raise + ensure_dir_exists(rcfile) + + +def test_ensure_other_create_error(tmpdir): + subdir = tmpdir.join('subdir"') + rcfile = subdir.join("rcfile") + + # trigger an oserror that isn't "directory already exists" + os.chmod(str(tmpdir), stat.S_IREAD) + + with pytest.raises(OSError): + ensure_dir_exists(str(rcfile)) + + +@pytest.mark.parametrize( + "text, skipped_lines", + ( + ("abc\n", 1), + ("#[section]\ndef\n[section]", 2), + ("[section]", 0), + ), +) +def test_skip_initial_comment(text, skipped_lines): + assert skip_initial_comment(io.StringIO(text)) == skipped_lines diff --git a/tests/test_fuzzy_completion.py b/tests/test_fuzzy_completion.py new file mode 100644 index 0000000..8f8f2cd --- /dev/null +++ b/tests/test_fuzzy_completion.py @@ -0,0 +1,87 @@ +import pytest + + +@pytest.fixture +def completer(): + import pgcli.pgcompleter as pgcompleter + + return pgcompleter.PGCompleter() + + +def test_ranking_ignores_identifier_quotes(completer): + """When calculating result rank, identifier quotes should be ignored. + + The result ranking algorithm ignores identifier quotes. Without this + correction, the match "user", which Postgres requires to be quoted + since it is also a reserved word, would incorrectly fall below the + match user_action because the literal quotation marks in "user" + alter the position of the match. + + This test checks that the fuzzy ranking algorithm correctly ignores + quotation marks when computing match ranks. + + """ + + text = "user" + collection = ["user_action", '"user"'] + matches = completer.find_matches(text, collection) + assert len(matches) == 2 + + +def test_ranking_based_on_shortest_match(completer): + """Fuzzy result rank should be based on shortest match. + + Result ranking in fuzzy searching is partially based on the length + of matches: shorter matches are considered more relevant than + longer ones. When searching for the text 'user', the length + component of the match 'user_group' could be either 4 ('user') or + 7 ('user_gr'). + + This test checks that the fuzzy ranking algorithm uses the shorter + match when calculating result rank. + + """ + + text = "user" + collection = ["api_user", "user_group"] + matches = completer.find_matches(text, collection) + + assert matches[1].priority > matches[0].priority + + +@pytest.mark.parametrize( + "collection", + [["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]], +) +def test_should_break_ties_using_lexical_order(completer, collection): + """Fuzzy result rank should use lexical order to break ties. + + When fuzzy matching, if multiple matches have the same match length and + start position, present them in lexical (rather than arbitrary) order. For + example, if we have tables 'user', 'user_action', and 'user_group', a + search for the text 'user' should present these tables in this order. + + The input collections to this test are out of order; each run checks that + the search text 'user' results in the input tables being reordered + lexically. + + """ + + text = "user" + matches = completer.find_matches(text, collection) + + assert matches[1].priority > matches[0].priority + + +def test_matching_should_be_case_insensitive(completer): + """Fuzzy matching should keep matches even if letter casing doesn't match. + + This test checks that variations of the text which have different casing + are still matched. + """ + + text = "foo" + collection = ["Foo", "FOO", "fOO"] + matches = completer.find_matches(text, collection) + + assert len(matches) == 3 diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..cbf20a6 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,490 @@ +import os +import platform +from unittest import mock + +import pytest + +try: + import setproctitle +except ImportError: + setproctitle = None + +from pgcli.main import ( + obfuscate_process_password, + format_output, + PGCli, + OutputSettings, + COLOR_CODE_REGEX, +) +from pgcli.pgexecute import PGExecute +from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS +from utils import dbtest, run +from collections import namedtuple + + +@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows") +@pytest.mark.skipif(not setproctitle, reason="setproctitle not available") +def test_obfuscate_process_password(): + original_title = setproctitle.getproctitle() + + setproctitle.setproctitle("pgcli user=root password=secret host=localhost") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx host=localhost" + assert title == expected + + setproctitle.setproctitle("pgcli user=root password=top secret host=localhost") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx host=localhost" + assert title == expected + + setproctitle.setproctitle("pgcli user=root password=top secret") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx" + assert title == expected + + setproctitle.setproctitle("pgcli postgres://root:secret@localhost/db") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli postgres://root:xxxx@localhost/db" + assert title == expected + + setproctitle.setproctitle(original_title) + + +def test_format_output(): + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g") + results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) + expected = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(results) == expected + + +def test_format_output_truncate_on(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10 + ) + results = format_output( + None, + [("first field value", "second field value")], + ["head1", "head2"], + None, + settings, + ) + expected = [ + "+------------+------------+", + "| head1 | head2 |", + "|------------+------------|", + "| first f... | second ... |", + "+------------+------------+", + ] + assert list(results) == expected + + +def test_format_output_truncate_off(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None + ) + long_field_value = ("first field " * 100).strip() + results = format_output(None, [(long_field_value,)], ["head1"], None, settings) + lines = list(results) + assert lines[3] == f"| {long_field_value} |" + + +@dbtest +def test_format_array_output(executor): + statement = """ + SELECT + array[1, 2, 3]::bigint[] as bigint_array, + '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, + '{å,魚,текст}'::text[] as 配列 + UNION ALL + SELECT '{}', NULL, array[NULL] + """ + results = run(executor, statement) + expected = [ + "+--------------+----------------------+--------------+", + "| bigint_array | nested_numeric_array | 配列 |", + "|--------------+----------------------+--------------|", + "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |", + "| {} | | {} |", + "+--------------+----------------------+--------------+", + "SELECT 2", + ] + assert list(results) == expected + + +@dbtest +def test_format_array_output_expanded(executor): + statement = """ + SELECT + array[1, 2, 3]::bigint[] as bigint_array, + '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, + '{å,魚,текст}'::text[] as 配列 + UNION ALL + SELECT '{}', NULL, array[NULL] + """ + results = run(executor, statement, expanded=True) + expected = [ + "-[ RECORD 1 ]-------------------------", + "bigint_array | {1,2,3}", + "nested_numeric_array | {{1,2},{3,4}}", + "配列 | {å,魚,текст}", + "-[ RECORD 2 ]-------------------------", + "bigint_array | {}", + "nested_numeric_array | ", + "配列 | {}", + "SELECT 2", + ] + assert "\n".join(results) == "\n".join(expected) + + +def test_format_output_auto_expand(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100 + ) + table_results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) + table = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(table_results) == table + expanded_results = format_output( + "Title", + [("abc", "def")], + ["head1", "head2"], + "test status", + settings._replace(max_width=1), + ) + expanded = [ + "Title", + "-[ RECORD 1 ]-------------------------", + "head1 | abc", + "head2 | def", + "test status", + ] + assert "\n".join(expanded_results) == "\n".join(expanded) + + +termsize = namedtuple("termsize", ["rows", "columns"]) +test_line = "-" * 10 +test_data = [ + (10, 10, "\n".join([test_line] * 7)), + (10, 10, "\n".join([test_line] * 6)), + (10, 10, "\n".join([test_line] * 5)), + (10, 10, "-" * 11), + (10, 10, "-" * 10), + (10, 10, "-" * 9), +] + +# 4 lines are reserved at the bottom of the terminal for pgcli's prompt +use_pager_when_on = [True, True, False, True, False, False] + +# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL +test_ids = [ + "Output longer than terminal height", + "Output equal to terminal height", + "Output shorter than terminal height", + "Output longer than terminal width", + "Output equal to terminal width", + "Output shorter than terminal width", +] + + +@pytest.fixture +def pset_pager_mocks(): + cli = PGCli() + cli.watch_command = None + with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch( + "pgcli.main.click.echo_via_pager" + ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app: + yield cli, mock_echo, mock_echo_via_pager, mock_app + + +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) +def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF): + cli.echo_via_pager(text) + + mock_echo.assert_called() + mock_echo_via_pager.assert_not_called() + + +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) +def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS): + cli.echo_via_pager(text) + + mock_echo.assert_not_called() + mock_echo_via_pager.assert_called() + + +pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)] + + +@pytest.mark.parametrize( + "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids +) +def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT): + cli.echo_via_pager(text) + + if use_pager: + mock_echo.assert_not_called() + mock_echo_via_pager.assert_called() + else: + mock_echo_via_pager.assert_not_called() + mock_echo.assert_called() + + +@pytest.mark.parametrize( + "text,expected_length", + [ + ( + "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", + 78, + ), + ("=\u001b[m=", 2), + ("-\u001b]23\u0007-", 2), + ], +) +def test_color_pattern(text, expected_length, pset_pager_mocks): + cli = pset_pager_mocks[0] + assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length + + +@dbtest +def test_i_works(tmpdir, executor): + sqlfile = tmpdir.join("test.sql") + sqlfile.write("SELECT NOW()") + rcfile = str(tmpdir.join("rcfile")) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) + statement = r"\i {0}".format(sqlfile) + run(executor, statement, pgspecial=cli.pgspecial) + + +@dbtest +def test_echo_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\echo asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_qecho_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\qecho asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_watch_works(executor): + cli = PGCli(pgexecute=executor) + + def run_with_watch( + query, target_call_count=1, expected_output="", expected_timing=None + ): + """ + :param query: Input to the CLI + :param target_call_count: Number of times the user lets the command run before Ctrl-C + :param expected_output: Substring expected to be found for each executed query + :param expected_timing: value `time.sleep` expected to be called with on every invocation + """ + with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch( + "pgcli.main.sleep" + ) as mock_sleep: + mock_sleep.side_effect = [None] * (target_call_count - 1) + [ + KeyboardInterrupt + ] + cli.handle_watch_command(query) + # Validate that sleep was called with the right timing + for i in range(target_call_count - 1): + assert mock_sleep.call_args_list[i][0][0] == expected_timing + # Validate that the output of the query was expected + assert mock_echo.call_count == target_call_count + for i in range(target_call_count): + assert expected_output in mock_echo.call_args_list[i][0][0] + + # With no history, it errors. + with mock.patch("pgcli.main.click.secho") as mock_secho: + cli.handle_watch_command(r"\watch 2") + mock_secho.assert_called() + assert ( + r"\watch cannot be used with an empty query" + in mock_secho.call_args_list[0][0][0] + ) + + # Usage 1: Run a query and then re-run it with \watch across two prompts. + run_with_watch("SELECT 111", expected_output="111") + run_with_watch( + "\\watch 10", target_call_count=2, expected_output="111", expected_timing=10 + ) + + # Usage 2: Run a query and \watch via the same prompt. + run_with_watch( + "SELECT 222; \\watch 4", + target_call_count=3, + expected_output="222", + expected_timing=4, + ) + + # Usage 3: Re-run the last watched command with a new timing + run_with_watch( + "\\watch 5", target_call_count=4, expected_output="222", expected_timing=5 + ) + + +def test_missing_rc_dir(tmpdir): + rcfile = str(tmpdir.join("subdir").join("rcfile")) + + PGCli(pgclirc_file=rcfile) + assert os.path.exists(rcfile) + + +def test_quoted_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B") + mock_connect.assert_called_with( + database="testdb[", host="baz.com", user="bar^", passwd="]foo" + ) + + +def test_pg_service_file(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf: + service_conf.write( + """File begins with a comment + that is not a comment + # or maybe a comment after all + because psql is crazy + + [myservice] + host=a_host + user=a_user + port=5433 + password=much_secure + dbname=a_dbname + + [my_other_service] + host=b_host + user=b_user + port=5435 + dbname=b_dbname + """ + ) + os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath + cli.connect_service("myservice", "another_user") + mock_connect.assert_called_with( + database="a_dbname", + host="a_host", + user="another_user", + port="5433", + passwd="much_secure", + ) + + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + os.environ["PGPASSWORD"] = "very_secure" + cli.connect_service("my_other_service", None) + mock_pgexecute.assert_called_with( + "b_dbname", + "b_user", + "very_secure", + "b_host", + "5435", + "", + application_name="pgcli", + ) + del os.environ["PGPASSWORD"] + del os.environ["PGSERVICEFILE"] + + +def test_ssl_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?" + "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" + ) + mock_connect.assert_called_with( + database="testdb[", + host="baz.com", + user="bar^", + passwd="]foo", + sslmode="verify-full", + sslcert="my.pem", + sslkey="my-key.pem", + sslrootcert="ca.pem", + ) + + +def test_port_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb") + mock_connect.assert_called_with( + database="testdb", host="baz.com", user="bar", passwd="foo", port="2543" + ) + + +def test_multihost_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb" + ) + mock_connect.assert_called_with( + database="testdb", + host="baz1.com,baz2.com,baz3.com", + user="bar", + passwd="foo", + port="2543,2543,2543", + ) + + +def test_application_name_db_uri(tmpdir): + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar@baz.com/?application_name=cow") + mock_pgexecute.assert_called_with( + "bar", "bar", "", "baz.com", "", "", application_name="cow" + ) diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py new file mode 100644 index 0000000..5b93661 --- /dev/null +++ b/tests/test_naive_completion.py @@ -0,0 +1,133 @@ +import pytest +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +from utils import completions_to_set + + +@pytest.fixture +def completer(): + import pgcli.pgcompleter as pgcompleter + + return pgcompleter.PGCompleter(smart_completion=False) + + +@pytest.fixture +def complete_event(): + from unittest.mock import Mock + + return Mock() + + +def test_empty_string_completion(completer, complete_event): + text = "" + position = 0 + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set(map(Completion, completer.all_completions)) + + +def test_select_keyword_completion(completer, complete_event): + text = "SEL" + position = len("SEL") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set([Completion(text="SELECT", start_position=-3)]) + + +def test_function_name_completion(completer, complete_event): + text = "SELECT MA" + position = len("SELECT MA") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set( + [ + Completion(text="MATERIALIZED VIEW", start_position=-2), + Completion(text="MAX", start_position=-2), + Completion(text="MAXEXTENTS", start_position=-2), + Completion(text="MAKE_DATE", start_position=-2), + Completion(text="MAKE_TIME", start_position=-2), + Completion(text="MAKE_TIMESTAMPTZ", start_position=-2), + Completion(text="MAKE_INTERVAL", start_position=-2), + Completion(text="MASKLEN", start_position=-2), + Completion(text="MAKE_TIMESTAMP", start_position=-2), + ] + ) + + +def test_column_name_completion(completer, complete_event): + text = "SELECT FROM users" + position = len("SELECT ") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set(map(Completion, completer.all_completions)) + + +def test_alter_well_known_keywords_completion(completer, complete_event): + text = "ALTER " + position = len(text) + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event, + smart_completion=True, + ) + ) + assert result > completions_to_set( + [ + Completion(text="DATABASE", display_meta="keyword"), + Completion(text="TABLE", display_meta="keyword"), + Completion(text="SYSTEM", display_meta="keyword"), + ] + ) + assert ( + completions_to_set([Completion(text="CREATE", display_meta="keyword")]) + not in result + ) + + +def test_special_name_completion(completer, complete_event): + text = "\\" + position = len("\\") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + # Special commands will NOT be suggested during naive completion mode. + assert result == completions_to_set([]) + + +def test_datatype_name_completion(completer, complete_event): + text = "SELECT price::IN" + position = len("SELECT price::IN") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event, + smart_completion=True, + ) + ) + assert result == completions_to_set( + [ + Completion(text="INET", display_meta="datatype"), + Completion(text="INT", display_meta="datatype"), + Completion(text="INT2", display_meta="datatype"), + Completion(text="INT4", display_meta="datatype"), + Completion(text="INT8", display_meta="datatype"), + Completion(text="INTEGER", display_meta="datatype"), + Completion(text="INTERNAL", display_meta="datatype"), + Completion(text="INTERVAL", display_meta="datatype"), + ] + ) diff --git a/tests/test_pgcompleter.py b/tests/test_pgcompleter.py new file mode 100644 index 0000000..909fa0b --- /dev/null +++ b/tests/test_pgcompleter.py @@ -0,0 +1,76 @@ +import pytest +from pgcli import pgcompleter + + +def test_load_alias_map_file_missing_file(): + with pytest.raises( + pgcompleter.InvalidMapFile, + match=r"Cannot read alias_map_file - /path/to/non-existent/file.json does not exist$", + ): + pgcompleter.load_alias_map_file("/path/to/non-existent/file.json") + + +def test_load_alias_map_file_invalid_json(tmp_path): + fpath = tmp_path / "foo.json" + fpath.write_text("this is not valid json") + with pytest.raises(pgcompleter.InvalidMapFile, match=r".*is not valid json$"): + pgcompleter.load_alias_map_file(str(fpath)) + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("SomE_Table", "SET"), + ("SOmeTabLe", "SOTL"), + ("someTable", "T"), + ], +) +def test_generate_alias_uses_upper_case_letters_from_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("some_tab_le", "stl"), + ("s_ome_table", "sot"), + ("sometable", "s"), + ], +) +def test_generate_alias_uses_first_char_and_every_preceded_by_underscore( + table_name, alias +): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("some_table", {"some_table": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_can_use_alias_map(table_name, alias_map, alias): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("SomeTable", {"SomeTable": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_prefers_alias_over_upper_case_name( + table_name, alias_map, alias +): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("Some_tablE", "SE"), + ("SomeTab_le", "ST"), + ], +) +def test_generate_alias_prefers_upper_case_name_over_underscore_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py new file mode 100644 index 0000000..636795b --- /dev/null +++ b/tests/test_pgexecute.py @@ -0,0 +1,773 @@ +from textwrap import dedent + +import psycopg +import pytest +from unittest.mock import patch, MagicMock +from pgspecial.main import PGSpecial, NO_QUERY +from utils import run, dbtest, requires_json, requires_jsonb + +from pgcli.main import PGCli +from pgcli.packages.parseutils.meta import FunctionMetadata + + +def function_meta_data( + func_name, + schema_name="public", + arg_names=None, + arg_types=None, + arg_modes=None, + return_type=None, + is_aggregate=False, + is_window=False, + is_set_returning=False, + is_extension=False, + arg_defaults=None, +): + return FunctionMetadata( + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, + ) + + +@dbtest +def test_conn(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + assert run(executor, """select * from test""", join=True) == dedent( + """\ + +-----+ + | a | + |-----| + | abc | + +-----+ + SELECT 1""" + ) + + +@dbtest +def test_copy(executor): + executor_copy = executor.copy() + run(executor_copy, """create table test(a text)""") + run(executor_copy, """insert into test values('abc')""") + assert run(executor_copy, """select * from test""", join=True) == dedent( + """\ + +-----+ + | a | + |-----| + | abc | + +-----+ + SELECT 1""" + ) + + +@dbtest +def test_bools_are_treated_as_strings(executor): + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + assert run(executor, """select * from test""", join=True) == dedent( + """\ + +------+ + | a | + |------| + | True | + +------+ + SELECT 1""" + ) + + +@dbtest +def test_expanded_slash_G(executor, pgspecial): + # Tests whether we reset the expanded output after a \G. + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + results = run(executor, r"""select * from test \G""", pgspecial=pgspecial) + assert pgspecial.expanded_output == False + + +@dbtest +def test_schemata_table_views_and_columns_query(executor): + run(executor, "create table a(x text, y text)") + run(executor, "create table b(z text)") + run(executor, "create view d as select 1 as e") + run(executor, "create schema schema1") + run(executor, "create table schema1.c (w text DEFAULT 'meow')") + run(executor, "create schema schema2") + + # schemata + # don't enforce all members of the schemas since they may include postgres + # temporary schemas + assert set(executor.schemata()) >= { + "public", + "pg_catalog", + "information_schema", + "schema1", + "schema2", + } + assert executor.search_path() == ["pg_catalog", "public"] + + # tables + assert set(executor.tables()) >= { + ("public", "a"), + ("public", "b"), + ("schema1", "c"), + } + + assert set(executor.table_columns()) >= { + ("public", "a", "x", "text", False, None), + ("public", "a", "y", "text", False, None), + ("public", "b", "z", "text", False, None), + ("schema1", "c", "w", "text", True, "'meow'::text"), + } + + # views + assert set(executor.views()) >= {("public", "d")} + + assert set(executor.view_columns()) >= { + ("public", "d", "e", "integer", False, None) + } + + +@dbtest +def test_foreign_key_query(executor): + run(executor, "create schema schema1") + run(executor, "create schema schema2") + run(executor, "create table schema1.parent(parentid int PRIMARY KEY)") + run( + executor, + "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)", + ) + + assert set(executor.foreignkeys()) >= { + ("schema1", "parent", "parentid", "schema2", "child", "motherid") + } + + +@dbtest +def test_functions_query(executor): + run( + executor, + """create function func1() returns int + language sql as $$select 1$$""", + ) + run(executor, "create schema schema1") + run( + executor, + """create function schema1.func2() returns int + language sql as $$select 2$$""", + ) + + run( + executor, + """create function func3() + returns table(x int, y int) language sql + as $$select 1, 2 from generate_series(1,5)$$;""", + ) + + run( + executor, + """create function func4(x int) returns setof int language sql + as $$select generate_series(1,5)$$;""", + ) + + funcs = set(executor.functions()) + assert funcs >= { + function_meta_data(func_name="func1", return_type="integer"), + function_meta_data( + func_name="func3", + arg_names=["x", "y"], + arg_types=["integer", "integer"], + arg_modes=["t", "t"], + return_type="record", + is_set_returning=True, + ), + function_meta_data( + schema_name="public", + func_name="func4", + arg_names=("x",), + arg_types=("integer",), + return_type="integer", + is_set_returning=True, + ), + function_meta_data( + schema_name="schema1", func_name="func2", return_type="integer" + ), + } + + +@dbtest +def test_datatypes_query(executor): + run(executor, "create type foo AS (a int, b text)") + + types = list(executor.datatypes()) + assert types == [("public", "foo")] + + +@dbtest +def test_database_list(executor): + databases = executor.databases() + assert "_test_db" in databases + + +@dbtest +def test_invalid_syntax(executor, exception_formatter): + result = run(executor, "invalid syntax!", exception_formatter=exception_formatter) + assert 'syntax error at or near "invalid"' in result[0] + + +@dbtest +def test_invalid_column_name(executor, exception_formatter): + result = run( + executor, "select invalid command", exception_formatter=exception_formatter + ) + assert 'column "invalid" does not exist' in result[0] + + +@pytest.fixture(params=[True, False]) +def expanded(request): + return request.param + + +@dbtest +def test_unicode_support_in_output(executor, expanded): + run(executor, "create table unicodechars(t text)") + run(executor, "insert into unicodechars (t) values ('é')") + + # See issue #24, this raises an exception without proper handling + assert "é" in run( + executor, "select * from unicodechars", join=True, expanded=expanded + ) + + +@dbtest +def test_not_is_special(executor, pgspecial): + """is_special is set to false for database queries.""" + query = "select 1" + result = list(executor.run(query, pgspecial=pgspecial)) + success, is_special = result[0][5:] + assert success == True + assert is_special == False + + +@dbtest +def test_execute_from_file_no_arg(executor, pgspecial): + r"""\i without a filename returns an error.""" + result = list(executor.run(r"\i", pgspecial=pgspecial)) + status, sql, success, is_special = result[0][3:] + assert "missing required argument" in status + assert success == False + assert is_special == True + + +@dbtest +@patch("pgcli.main.os") +def test_execute_from_file_io_error(os, executor, pgspecial): + r"""\i with an os_error returns an error.""" + # Inject an OSError. + os.path.expanduser.side_effect = OSError("test") + + # Check the result. + result = list(executor.run(r"\i test", pgspecial=pgspecial)) + status, sql, success, is_special = result[0][3:] + assert status == "test" + assert success == False + assert is_special == True + + +@dbtest +def test_execute_from_commented_file_that_executes_another_file( + executor, pgspecial, tmpdir +): + # https://github.com/dbcli/pgcli/issues/1336 + sqlfile1 = tmpdir.join("test01.sql") + sqlfile1.write("-- asdf \n\\h") + sqlfile2 = tmpdir.join("test00.sql") + sqlfile2.write("--An useless comment;\nselect now();\n-- another useless comment") + + rcfile = str(tmpdir.join("rcfile")) + print(rcfile) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) + assert cli != None + statement = "--comment\n\\h" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result != None + assert result[0].find("ALTER TABLE") + + +@dbtest +def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir): + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + # https://github.com/dbcli/pgcli/issues/1362 + statement = "--comment\n\\h" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "--comment1\n--comment2\n\\h" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "/*comment*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """/*comment1 + comment2*/ + \h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """/*comment1 + comment2*/ + /*comment 3 + comment4*/ + \\h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = " /*comment*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "/*comment\ncomment line2*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = " /*comment\ncomment line2*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """\\h /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + print(result) + assert result != None + assert result[0].find("No help") >= 0 + + # TODO: we probably don't want to do this but sqlparse is not parsing things well + # we relly want it to find help but right now, sqlparse isn't dropping the /*comment*/ + # style comments after command + + statement = """/*comment1*/ + \h + /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[0].find("No help") >= 0 + + # TODO: same for this one + statement = """/*comment1 + comment3 + comment2*/ + \\h + /*comment4 + comment5 + comment6*/""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[0].find("No help") >= 0 + + +@dbtest +def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir): + # https://github.com/dbcli/pgcli/issues/1403 + + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + # this simulates the original error (1403) without having to add/drop tables + # since it was just an error on reading input files and not the actual + # command itself + + # test that the statement works + statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # test the statement with a \n in the middle + statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # test the statement with a newline in the middle + statement = """VALUES (1, 'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # now add a single comment line + statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + VALUES (1,'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # two comment lines + statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + --comment2 + VALUES (1,'one'), (2, 'two'), (3, 'three'); + """ + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + # + comments after the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three'); +--comment4 +--comment5""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + +@dbtest +def test_multiple_queries_same_line(executor): + result = run(executor, "select 'foo'; select 'bar'") + assert len(result) == 12 # 2 * (output+status) * 3 lines + assert "foo" in result[3] + assert "bar" in result[9] + + +@dbtest +def test_multiple_queries_with_special_command_same_line(executor, pgspecial): + result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial) + assert len(result) == 11 # 2 * (output+status) * 3 lines + assert "foo" in result[3] + # This is a lame check. :( + assert "Schema" in result[7] + + +@dbtest +def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter): + result = run( + executor, + "select 'fooé'; invalid syntax é", + exception_formatter=exception_formatter, + ) + assert "fooé" in result[3] + assert 'syntax error at or near "invalid"' in result[-1] + + +@pytest.fixture +def pgspecial(): + return PGCli().pgspecial + + +@dbtest +def test_special_command_help(executor, pgspecial): + result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|") + assert "Command" in result[1] + assert "Description" in result[2] + + +@dbtest +def test_bytea_field_support_in_output(executor): + run(executor, "create table binarydata(c bytea)") + run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))") + + assert "\\xdeadbeef" in run(executor, "select * from binarydata", join=True) + + +@dbtest +def test_unicode_support_in_unknown_type(executor): + assert "日本語" in run(executor, "SELECT '日本語' AS japanese;", join=True) + + +@dbtest +def test_unicode_support_in_enum_type(executor): + run(executor, "CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')") + run(executor, "CREATE TABLE person (name TEXT, current_mood mood)") + run(executor, "INSERT INTO person VALUES ('Moe', '日本語')") + assert "日本語" in run(executor, "SELECT * FROM person", join=True) + + +@requires_json +def test_json_renders_without_u_prefix(executor, expanded): + run(executor, "create table jsontest(d json)") + run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""") + result = run( + executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded + ) + + assert '{"name": "Éowyn"}' in result + + +@requires_jsonb +def test_jsonb_renders_without_u_prefix(executor, expanded): + run(executor, "create table jsonbtest(d jsonb)") + run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""") + result = run( + executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded + ) + + assert '{"name": "Éowyn"}' in result + + +@dbtest +def test_date_time_types(executor): + run(executor, "SET TIME ZONE UTC") + assert ( + run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] + == "| 00:00:00 |" + ) + assert ( + run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split( + "\n" + )[3] + == "| 00:00:00+14:59 |" + ) + assert ( + run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[ + 3 + ] + == "| 4713-01-01 BC |" + ) + assert ( + run( + executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True + ).split("\n")[3] + == "| 4713-01-01 00:00:00 BC |" + ) + assert ( + run( + executor, + "SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))", + join=True, + ).split("\n")[3] + == "| 4713-01-01 00:00:00+00 BC |" + ) + assert ( + run( + executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True + ).split("\n")[3] + == "| -123456789 days, 12:23:56 |" + ) + + +@dbtest +@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"]) +def test_large_numbers_render_directly(executor, value): + run(executor, "create table numbertest(a numeric)") + run(executor, f"insert into numbertest (a) values ({value})") + assert value in run(executor, "select * from numbertest", join=True) + + +@dbtest +@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"]) +@pytest.mark.parametrize("verbose", ["", "+"]) +@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"]) +def test_describe_special(executor, command, verbose, pattern, pgspecial): + # We don't have any tests for the output of any of the special commands, + # but we can at least make sure they run without error + sql = r"\{command}{verbose} {pattern}".format(**locals()) + list(executor.run(sql, pgspecial=pgspecial)) + + +@dbtest +@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"]) +def test_raises_with_no_formatter(executor, sql): + with pytest.raises(psycopg.ProgrammingError): + list(executor.run(sql)) + + +@dbtest +def test_on_error_resume(executor, exception_formatter): + sql = "select 1; error; select 1;" + result = list( + executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter) + ) + assert len(result) == 3 + + +@dbtest +def test_on_error_stop(executor, exception_formatter): + sql = "select 1; error; select 1;" + result = list( + executor.run( + sql, on_error_resume=False, exception_formatter=exception_formatter + ) + ) + assert len(result) == 2 + + +# @dbtest +# def test_unicode_notices(executor): +# sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;" +# result = list(executor.run(sql)) +# assert result[0][0] == u'NOTICE: 有人更改\n' + + +@dbtest +def test_nonexistent_function_definition(executor): + with pytest.raises(RuntimeError): + result = executor.view_definition("there_is_no_such_function") + + +@dbtest +def test_function_definition(executor): + run( + executor, + """ + CREATE OR REPLACE FUNCTION public.the_number_three() + RETURNS int + LANGUAGE sql + AS $function$ + select 3; + $function$ + """, + ) + result = executor.function_definition("the_number_three") + + +@dbtest +def test_view_definition(executor): + run(executor, "create table tbl1 (a text, b numeric)") + run(executor, "create view vw1 AS SELECT * FROM tbl1") + run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1") + result = executor.view_definition("vw1") + assert 'VIEW "public"."vw1" AS' in result + assert "FROM tbl1" in result + # import pytest; pytest.set_trace() + result = executor.view_definition("mvw1") + assert "MATERIALIZED VIEW" in result + + +@dbtest +def test_nonexistent_view_definition(executor): + with pytest.raises(RuntimeError): + result = executor.view_definition("there_is_no_such_view") + with pytest.raises(RuntimeError): + result = executor.view_definition("mvw1") + + +@dbtest +def test_short_host(executor): + with patch.object(executor, "host", "localhost"): + assert executor.short_host == "localhost" + with patch.object(executor, "host", "localhost.example.org"): + assert executor.short_host == "localhost" + with patch.object( + executor, "host", "localhost1.example.org,localhost2.example.org" + ): + assert executor.short_host == "localhost1" + + +class VirtualCursor: + """Mock a cursor to virtual database like pgbouncer.""" + + def __init__(self): + self.protocol_error = False + self.protocol_message = "" + self.description = None + self.status = None + self.statusmessage = "Error" + + def execute(self, *args, **kwargs): + self.protocol_error = True + self.protocol_message = "Command not supported" + + +@dbtest +def test_exit_without_active_connection(executor): + quit_handler = MagicMock() + pgspecial = PGSpecial() + pgspecial.register( + quit_handler, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) + + with patch.object( + executor.conn, "cursor", side_effect=psycopg.InterfaceError("I'm broken!") + ): + # we should be able to quit the app, even without active connection + run(executor, "\\q", pgspecial=pgspecial) + quit_handler.assert_called_once() + + # an exception should be raised when running a query without active connection + with pytest.raises(psycopg.InterfaceError): + run(executor, "select 1", pgspecial=pgspecial) + + +@dbtest +def test_virtual_database(executor): + virtual_connection = MagicMock() + virtual_connection.cursor.return_value = VirtualCursor() + with patch.object(executor, "conn", virtual_connection): + result = run(executor, "select 1") + assert "Command not supported" in result diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py new file mode 100644 index 0000000..cd99e32 --- /dev/null +++ b/tests/test_pgspecial.py @@ -0,0 +1,78 @@ +import pytest +from pgcli.packages.sqlcompletion import ( + suggest_type, + Special, + Database, + Schema, + Table, + View, + Function, + Datatype, +) + + +def test_slash_suggests_special(): + suggestions = suggest_type("\\", "\\") + assert set(suggestions) == {Special()} + + +def test_slash_d_suggests_special(): + suggestions = suggest_type("\\d", "\\d") + assert set(suggestions) == {Special()} + + +def test_dn_suggests_schemata(): + suggestions = suggest_type("\\dn ", "\\dn ") + assert suggestions == (Schema(),) + + suggestions = suggest_type("\\dn xxx", "\\dn xxx") + assert suggestions == (Schema(),) + + +def test_d_suggests_tables_views_and_schemas(): + suggestions = suggest_type(r"\d ", r"\d ") + assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)} + + suggestions = suggest_type(r"\d xxx", r"\d xxx") + assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)} + + +def test_d_dot_suggests_schema_qualified_tables_or_views(): + suggestions = suggest_type(r"\d myschema.", r"\d myschema.") + assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")} + + suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx") + assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")} + + +def test_df_suggests_schema_or_function(): + suggestions = suggest_type("\\df xxx", "\\df xxx") + assert set(suggestions) == {Function(schema=None, usage="special"), Schema()} + + suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx") + assert suggestions == (Function(schema="myschema", usage="special"),) + + +def test_leading_whitespace_ok(): + cmd = "\\dn " + whitespace = " " + suggestions = suggest_type(whitespace + cmd, whitespace + cmd) + assert suggestions == suggest_type(cmd, cmd) + + +def test_dT_suggests_schema_or_datatypes(): + text = "\\dT " + suggestions = suggest_type(text, text) + assert set(suggestions) == {Schema(), Datatype(schema=None)} + + +def test_schema_qualified_dT_suggests_datatypes(): + text = "\\dT foo." + suggestions = suggest_type(text, text) + assert suggestions == (Datatype(schema="foo"),) + + +@pytest.mark.parametrize("command", ["\\c ", "\\connect "]) +def test_c_suggests_databases(command): + suggestions = suggest_type(command, command) + assert suggestions == (Database(),) diff --git a/tests/test_prioritization.py b/tests/test_prioritization.py new file mode 100644 index 0000000..f5b6700 --- /dev/null +++ b/tests/test_prioritization.py @@ -0,0 +1,20 @@ +from pgcli.packages.prioritization import PrevalenceCounter + + +def test_prevalence_counter(): + counter = PrevalenceCounter() + sql = """SELECT * FROM foo WHERE bar GROUP BY baz; + select * from foo; + SELECT * FROM foo WHERE bar GROUP + BY baz""" + counter.update(sql) + + keywords = ["SELECT", "FROM", "GROUP BY"] + expected = [3, 3, 2] + kw_counts = [counter.keyword_count(x) for x in keywords] + assert kw_counts == expected + assert counter.keyword_count("NOSUCHKEYWORD") == 0 + + names = ["foo", "bar", "baz"] + name_counts = [counter.name_count(x) for x in names] + assert name_counts == [3, 2, 2] diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py new file mode 100644 index 0000000..91abe37 --- /dev/null +++ b/tests/test_prompt_utils.py @@ -0,0 +1,17 @@ +import click + +from pgcli.packages.prompt_utils import confirm_destructive_query + + +def test_confirm_destructive_query_notty(): + stdin = click.get_text_stream("stdin") + if not stdin.isatty(): + sql = "drop database foo;" + assert confirm_destructive_query(sql, [], None) is None + + +def test_confirm_destructive_query_with_alias(): + stdin = click.get_text_stream("stdin") + if not stdin.isatty(): + sql = "drop database foo;" + assert confirm_destructive_query(sql, ["drop"], "test") is None diff --git a/tests/test_rowlimit.py b/tests/test_rowlimit.py new file mode 100644 index 0000000..da916b4 --- /dev/null +++ b/tests/test_rowlimit.py @@ -0,0 +1,79 @@ +import pytest +from unittest.mock import Mock + +from pgcli.main import PGCli + + +# We need this fixtures because we need PGCli object to be created +# after test collection so it has config loaded from temp directory + + +@pytest.fixture(scope="module") +def default_pgcli_obj(): + return PGCli() + + +@pytest.fixture(scope="module") +def DEFAULT(default_pgcli_obj): + return default_pgcli_obj.row_limit + + +@pytest.fixture(scope="module") +def LIMIT(DEFAULT): + return DEFAULT + 1000 + + +@pytest.fixture(scope="module") +def over_default(DEFAULT): + over_default_cursor = Mock() + over_default_cursor.configure_mock(rowcount=DEFAULT + 10) + return over_default_cursor + + +@pytest.fixture(scope="module") +def over_limit(LIMIT): + over_limit_cursor = Mock() + over_limit_cursor.configure_mock(rowcount=LIMIT + 10) + return over_limit_cursor + + +@pytest.fixture(scope="module") +def low_count(): + low_count_cursor = Mock() + low_count_cursor.configure_mock(rowcount=1) + return low_count_cursor + + +def test_row_limit_with_LIMIT_clause(LIMIT, over_limit): + cli = PGCli(row_limit=LIMIT) + stmt = "SELECT * FROM students LIMIT 1000" + + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + +def test_row_limit_without_LIMIT_clause(LIMIT, over_limit): + cli = PGCli(row_limit=LIMIT) + stmt = "SELECT * FROM students" + + result = cli._should_limit_output(stmt, over_limit) + assert result is True + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + +def test_row_limit_on_non_select(over_limit): + cli = PGCli() + stmt = "UPDATE students SET name='Boby'" + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py new file mode 100644 index 0000000..5c9c9af --- /dev/null +++ b/tests/test_smart_completion_multiple_schemata.py @@ -0,0 +1,757 @@ +import itertools +from metadata import ( + MetaData, + alias, + name_join, + fk_join, + join, + schema, + table, + function, + wildcard_expansion, + column, + get_result, + result_set, + qual, + no_qual, + parametrize, +) +from utils import completions_to_set + +metadata = { + "tables": { + "public": { + "users": ["id", "email", "first_name", "last_name"], + "orders": ["id", "ordered_date", "status", "datestamp"], + "select": ["id", "localtime", "ABC"], + }, + "custom": { + "users": ["id", "phone_number"], + "Users": ["userid", "username"], + "products": ["id", "product_name", "price"], + "shipments": ["id", "address", "user_id"], + }, + "Custom": {"projects": ["projectid", "name"]}, + "blog": { + "entries": ["entryid", "entrytitle", "entrytext"], + "tags": ["tagid", "name"], + "entrytags": ["entryid", "tagid"], + "entacclog": ["entryid", "username", "datestamp"], + }, + }, + "functions": { + "public": [ + ["func1", [], [], [], "", False, False, False, False], + ["func2", [], [], [], "", False, False, False, False], + ], + "custom": [ + ["func3", [], [], [], "", False, False, False, False], + [ + "set_returning_func", + ["x"], + ["integer"], + ["o"], + "integer", + False, + False, + True, + False, + ], + ], + "Custom": [["func4", [], [], [], "", False, False, False, False]], + "blog": [ + [ + "extract_entry_symbols", + ["_entryid", "symbol"], + ["integer", "text"], + ["i", "o"], + "", + False, + False, + True, + False, + ], + [ + "enter_entry", + ["_title", "_text", "entryid"], + ["text", "text", "integer"], + ["i", "i", "o"], + "", + False, + False, + False, + False, + ], + ], + }, + "datatypes": {"public": ["typ1", "typ2"], "custom": ["typ3", "typ4"]}, + "foreignkeys": { + "custom": [("public", "users", "id", "custom", "shipments", "user_id")], + "blog": [ + ("blog", "entries", "entryid", "blog", "entacclog", "entryid"), + ("blog", "entries", "entryid", "blog", "entrytags", "entryid"), + ("blog", "tags", "tagid", "blog", "entrytags", "tagid"), + ], + }, + "defaults": { + "public": { + ("orders", "id"): "nextval('orders_id_seq'::regclass)", + ("orders", "datestamp"): "now()", + ("orders", "status"): "'PENDING'::text", + } + }, +} + +testdata = MetaData(metadata) +cased_schemas = [schema(x) for x in ("public", "blog", "CUSTOM", '"Custom"')] +casing = ( + "SELECT", + "Orders", + "User_Emails", + "CUSTOM", + "Func1", + "Entries", + "Tags", + "EntryTags", + "EntAccLog", + "EntryID", + "EntryTitle", + "EntryText", +) +completers = testdata.get_completers(casing) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("table", ["users", '"users"']) +def test_suggested_column_names_from_shadowed_visible_table(completer, table): + result = get_result(completer, "SELECT FROM " + table, len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT from custom.users", + "WITH users as (SELECT 1 AS foo) SELECT from custom.users", + ], +) +def test_suggested_column_names_from_qualified_shadowed_table(completer, text): + result = get_result(completer, text, position=text.find(" ") + 1) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"]) +def test_suggested_column_names_from_cte(completer, text): + result = completions_to_set(get_result(completer, text, text.find(" ") + 1)) + assert result == completions_to_set( + [column("foo")] + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users JOIN custom.shipments ON ", + """SELECT * + FROM public.users + JOIN custom.shipments ON """, + ], +) +def test_suggested_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + alias("users"), + alias("shipments"), + name_join("shipments.id = users.id"), + fk_join("shipments.user_id = users.id"), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +@parametrize( + ("query", "tbl"), + itertools.product( + ( + "SELECT * FROM public.{0} RIGHT OUTER JOIN ", + """SELECT * + FROM {0} + JOIN """, + ), + ("users", '"users"', "Users"), + ), +) +def test_suggested_joins(completer, query, tbl): + result = get_result(completer, query.format(tbl)) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_column_names_from_schema_qualifed_table(completer): + result = get_result(completer, "SELECT from custom.products", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize( + "text", + [ + "INSERT INTO orders(", + "INSERT INTO orders (", + "INSERT INTO public.orders(", + "INSERT INTO public.orders (", + ], +) +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_columns_with_insert(completer, text): + assert completions_to_set(get_result(completer, text)) == completions_to_set( + testdata.columns("orders") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_column_names_in_function(completer): + result = get_result( + completer, "SELECT MAX( from custom.products", len("SELECT MAX(") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + ["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'], +) +@parametrize("use_leading_double_quote", [False, True]) +def test_suggested_table_names_with_schema_dot( + completer, text, use_leading_double_quote +): + if use_leading_double_quote: + text += '"' + start_position = -1 + else: + start_position = 0 + + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.from_clause_items("custom", start_position) + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", ['SELECT * FROM "Custom".']) +@parametrize("use_leading_double_quote", [False, True]) +def test_suggested_table_names_with_schema_dot2( + completer, text, use_leading_double_quote +): + if use_leading_double_quote: + text += '"' + start_position = -1 + else: + start_position = 0 + + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.from_clause_items("Custom", start_position) + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_column_names_with_qualified_alias(completer): + result = get_result(completer, "SELECT p. from custom.products p", len("SELECT p.")) + assert completions_to_set(result) == completions_to_set( + testdata.columns("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_multiple_column_names(completer): + result = get_result( + completer, "SELECT id, from custom.products", len("SELECT id, ") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_multiple_column_names_with_alias(completer): + result = get_result( + completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ", + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id", + ], +) +def test_suggestions_after_on(completer, text): + position = len( + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON " + ) + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + alias("x"), + alias("y"), + name_join("y.price = x.price"), + name_join("y.product_name = x.product_name"), + name_join("y.id = x.id"), + ] + ) + + +@parametrize("completer", completers()) +def test_suggested_aliases_after_on_right_side(completer): + text = "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set([alias("x"), alias("y")]) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_table_names_after_from(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_schema_qualified_function_name(completer): + text = "SELECT custom.func" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("func3()", -len("func")), + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_schema_qualified_function_name_after_from(completer): + text = "SELECT * FROM custom.set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_unqualified_function_name_not_returned(completer): + text = "SELECT * FROM set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set([]) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_unqualified_function_name_in_search_path(completer): + completer.search_path = ["public", "custom"] + text = "SELECT * FROM set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT 1::custom.", + "CREATE TABLE foo (bar custom.", + "CREATE FUNCTION foo (bar INT, baz custom.", + "ALTER TABLE foo ALTER COLUMN bar TYPE custom.", + ], +) +def test_schema_qualified_type_name(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set(testdata.types("custom")) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggest_columns_from_aliased_set_returning_function(completer): + result = get_result( + completer, "select f. from custom.set_returning_func() f", len("select f.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", "custom", "functions") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT * FROM custom.set_returning_func()", + "SELECT * FROM Custom.set_returning_func()", + "SELECT * FROM Custom.Set_Returning_Func()", + ], +) +def test_wildcard_column_expansion_with_function(completer, text): + position = len("SELECT *") + + completions = get_result(completer, text, position) + + col_list = "x" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_alias_qualifier(completer): + text = "SELECT p.* FROM custom.products p" + position = len("SELECT p.*") + + completions = get_result(completer, text, position) + + col_list = "id, p.product_name, p.price" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + """ + SELECT count(1) FROM users; + CREATE FUNCTION foo(custom.products _products) returns custom.shipments + LANGUAGE SQL + AS $foo$ + SELECT 1 FROM custom.shipments; + INSERT INTO public.orders(*) values(-1, now(), 'preliminary'); + SELECT 2 FROM custom.users; + $foo$; + SELECT count(1) FROM custom.shipments; + """, + "INSERT INTO public.orders(*", + "INSERT INTO public.Orders(*", + "INSERT INTO public.orders (*", + "INSERT INTO public.Orders (*", + "INSERT INTO orders(*", + "INSERT INTO Orders(*", + "INSERT INTO orders (*", + "INSERT INTO Orders (*", + "INSERT INTO public.orders(*)", + "INSERT INTO public.Orders(*)", + "INSERT INTO public.orders (*)", + "INSERT INTO public.Orders (*)", + "INSERT INTO orders(*)", + "INSERT INTO Orders(*)", + "INSERT INTO orders (*)", + "INSERT INTO Orders (*)", + ], +) +def test_wildcard_column_expansion_with_insert(completer, text): + position = text.index("*") + 1 + completions = get_result(completer, text, position) + + expected = [wildcard_expansion("ordered_date, status")] + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_table_qualifier(completer): + text = 'SELECT "select".* FROM public."select"' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select"."localtime", "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=qual)) +def test_wildcard_column_expansion_with_two_tables(completer): + text = 'SELECT * FROM public."select" JOIN custom.users ON true' + position = len("SELECT *") + + completions = get_result(completer, text, position) + + cols = ( + '"select".id, "select"."localtime", "select"."ABC", ' + "users.id, users.phone_number" + ) + expected = [wildcard_expansion(cols)] + assert completions == expected + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_two_tables_and_parent(completer): + text = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select"."localtime", "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT U. FROM custom.Users U", + "SELECT U. FROM custom.USERS U", + "SELECT U. FROM custom.users U", + 'SELECT U. FROM "custom".Users U', + 'SELECT U. FROM "custom".USERS U', + 'SELECT U. FROM "custom".users U', + ], +) +def test_suggest_columns_from_unquoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + testdata.columns("users", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U'] +) +def test_suggest_columns_from_quoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + testdata.columns("Users", "custom") + ) + + +texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "] + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +@parametrize("text", texts) +def test_schema_or_visible_table_completion(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(aliasing=True, casing=False, filtr=True)) +@parametrize("text", texts) +def test_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + [ + table("users u"), + table("orders o" if text == "SELECT * FROM " else "orders o2"), + table('"select" s'), + function("func1() f"), + function("func2() f"), + ] + ) + + +@parametrize("completer", completers(aliasing=True, casing=True, filtr=True)) +@parametrize("text", texts) +def test_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + cased_schemas + + [ + table("users u"), + table("Orders O" if text == "SELECT * FROM " else "Orders O2"), + table('"select" s'), + function("Func1() F"), + function("func2() f"), + ] + ) + + +@parametrize("completer", completers(aliasing=False, casing=True, filtr=True)) +@parametrize("text", texts) +def test_table_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + cased_schemas + + [ + table("users"), + table("Orders"), + table('"select"'), + function("Func1()"), + function("func2()"), + ] + ) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_alias_search_without_aliases2(completer): + text = "SELECT * FROM blog.et" + result = get_result(completer, text) + assert result[0] == table("EntryTags", -2) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_alias_search_without_aliases1(completer): + text = "SELECT * FROM blog.e" + result = get_result(completer, text) + assert result[0] == table("Entries", -1) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_alias_search_with_aliases2(completer): + text = "SELECT * FROM blog.et" + result = get_result(completer, text) + assert result[0] == table("EntryTags ET", -2) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_alias_search_with_aliases1(completer): + text = "SELECT * FROM blog.e" + result = get_result(completer, text) + assert result[0] == table("Entries E", -1) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_join_alias_search_with_aliases1(completer): + text = "SELECT * FROM blog.Entries E JOIN blog.e" + result = get_result(completer, text) + assert result[:2] == [ + table("Entries E2", -1), + join("EntAccLog EAL ON EAL.EntryID = E.EntryID", -1), + ] + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_join_alias_search_without_aliases1(completer): + text = "SELECT * FROM blog.Entries JOIN blog.e" + result = get_result(completer, text) + assert result[:2] == [ + table("Entries", -1), + join("EntAccLog ON EntAccLog.EntryID = Entries.EntryID", -1), + ] + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_join_alias_search_with_aliases2(completer): + text = "SELECT * FROM blog.Entries E JOIN blog.et" + result = get_result(completer, text) + assert result[0] == join("EntryTags ET ON ET.EntryID = E.EntryID", -2) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_join_alias_search_without_aliases2(completer): + text = "SELECT * FROM blog.Entries JOIN blog.et" + result = get_result(completer, text) + assert result[0] == join("EntryTags ON EntryTags.EntryID = Entries.EntryID", -2) + + +@parametrize("completer", completers()) +def test_function_alias_search_without_aliases(completer): + text = "SELECT blog.ees" + result = get_result(completer, text) + first = result[0] + assert first.start_position == -3 + assert first.text == "extract_entry_symbols()" + assert first.display_text == "extract_entry_symbols(_entryid)" + + +@parametrize("completer", completers()) +def test_function_alias_search_with_aliases(completer): + text = "SELECT blog.ee" + result = get_result(completer, text) + first = result[0] + assert first.start_position == -2 + assert first.text == "enter_entry(_title := , _text := )" + assert first.display_text == "enter_entry(_title, _text)" + + +@parametrize("completer", completers(filtr=True, casing=True, qualify=no_qual)) +def test_column_alias_search(completer): + result = get_result(completer, "SELECT et FROM blog.Entries E", len("SELECT et")) + cols = ("EntryText", "EntryTitle", "EntryID") + assert result[:3] == [column(c, -2) for c in cols] + + +@parametrize("completer", completers(casing=True)) +def test_column_alias_search_qualified(completer): + result = get_result( + completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei") + ) + cols = ("EntryID", "EntryTitle") + assert result[:3] == [column(c, -2) for c in cols] + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) +def test_schema_object_order(completer): + result = get_result(completer, "SELECT * FROM u") + assert result[:3] == [ + table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users") + ] + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) +def test_all_schema_objects(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("orders", '"select"', "custom.shipments")] + + [function(x + "()") for x in ("func2",)] + ) + + +@parametrize("completer", completers(filtr=False, aliasing=False, casing=True)) +def test_all_schema_objects_with_casing(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")] + + [function(x + "()") for x in ("func2",)] + ) + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=True)) +def test_all_schema_objects_with_aliases(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("orders o", '"select" s', "custom.shipments s")] + + [function(x) for x in ("func2() f",)] + ) + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=True)) +def test_set_schema(completer): + text = "SET SCHEMA " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")] + ) diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py new file mode 100644 index 0000000..db1fe0a --- /dev/null +++ b/tests/test_smart_completion_public_schema_only.py @@ -0,0 +1,1112 @@ +from metadata import ( + MetaData, + alias, + name_join, + fk_join, + join, + keyword, + schema, + table, + view, + function, + column, + wildcard_expansion, + get_result, + result_set, + qual, + no_qual, + parametrize, +) +from prompt_toolkit.completion import Completion +from utils import completions_to_set + + +metadata = { + "tables": { + "users": ["id", "parentid", "email", "first_name", "last_name"], + "Users": ["userid", "username"], + "orders": ["id", "ordered_date", "status", "email"], + "select": ["id", "insert", "ABC"], + }, + "views": {"user_emails": ["id", "email"], "functions": ["function"]}, + "functions": [ + ["custom_fun", [], [], [], "", False, False, False, False], + ["_custom_fun", [], [], [], "", False, False, False, False], + ["custom_func1", [], [], [], "", False, False, False, False], + ["custom_func2", [], [], [], "", False, False, False, False], + [ + "set_returning_func", + ["x", "y"], + ["integer", "integer"], + ["b", "b"], + "", + False, + False, + True, + False, + ], + ], + "datatypes": ["custom_type1", "custom_type2"], + "foreignkeys": [ + ("public", "users", "id", "public", "users", "parentid"), + ("public", "users", "id", "public", "Users", "userid"), + ], +} + +metadata = {k: {"public": v} for k, v in metadata.items()} + +testdata = MetaData(metadata) + +cased_users_col_names = ["ID", "PARENTID", "Email", "First_Name", "last_name"] +cased_users2_col_names = ["UserID", "UserName"] +cased_func_names = [ + "Custom_Fun", + "_custom_fun", + "Custom_Func1", + "custom_func2", + "set_returning_func", +] +cased_tbls = ["Users", "Orders"] +cased_views = ["User_Emails", "Functions"] +casing = ( + ["SELECT", "PUBLIC"] + + cased_func_names + + cased_tbls + + cased_views + + cased_users_col_names + + cased_users2_col_names +) +# Lists for use in assertions +cased_funcs = [ + function(f) + for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()") +] + [function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")] +cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])] +cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls +cased_users_cols = [column(c) for c in cased_users_col_names] +aliased_rels = ( + [table(t) for t in ("users u", '"Users" U', "orders o", '"select" s')] + + [view("user_emails ue"), view("functions f")] + + [ + function(f) + for f in ( + "_custom_fun() cf", + "custom_fun() cf", + "custom_func1() cf", + "custom_func2() cf", + ) + ] + + [ + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ) + ] +) +cased_aliased_rels = ( + [table(t) for t in ("Users U", '"Users" U', "Orders O", '"select" s')] + + [view("User_Emails UE"), view("Functions F")] + + [ + function(f) + for f in ( + "_custom_fun() cf", + "Custom_Fun() CF", + "Custom_Func1() CF", + "custom_func2() cf", + ) + ] + + [ + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ) + ] +) +completers = testdata.get_completers(casing) + + +# Just to make sure that this doesn't crash +@parametrize("completer", completers()) +def test_function_column_name(completer): + for l in range( + len("SELECT * FROM Functions WHERE function:"), + len("SELECT * FROM Functions WHERE function:text") + 1, + ): + assert [] == get_result( + completer, "SELECT * FROM Functions WHERE function:text"[:l] + ) + + +@parametrize("action", ["ALTER", "DROP", "CREATE", "CREATE OR REPLACE"]) +@parametrize("completer", completers()) +def test_drop_alter_function(completer, action): + assert get_result(completer, action + " FUNCTION set_ret") == [ + function("set_returning_func(x integer, y integer)", -len("set_ret")) + ] + + +@parametrize("completer", completers()) +def test_empty_string_completion(completer): + result = get_result(completer, "") + assert completions_to_set( + testdata.keywords() + testdata.specials() + ) == completions_to_set(result) + + +@parametrize("completer", completers()) +def test_select_keyword_completion(completer): + result = get_result(completer, "SEL") + assert completions_to_set(result) == completions_to_set([keyword("SELECT", -3)]) + + +@parametrize("completer", completers()) +def test_builtin_function_name_completion(completer): + result = get_result(completer, "SELECT MA") + assert completions_to_set(result) == completions_to_set( + [ + function("MAKE_DATE", -2), + function("MAKE_INTERVAL", -2), + function("MAKE_TIME", -2), + function("MAKE_TIMESTAMP", -2), + function("MAKE_TIMESTAMPTZ", -2), + function("MASKLEN", -2), + function("MAX", -2), + keyword("MAXEXTENTS", -2), + keyword("MATERIALIZED VIEW", -2), + ] + ) + + +@parametrize("completer", completers()) +def test_builtin_function_matches_only_at_start(completer): + text = "SELECT IN" + + result = [c.text for c in get_result(completer, text)] + + assert "MIN" not in result + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_user_function_name_completion(completer): + result = get_result(completer, "SELECT cu") + assert completions_to_set(result) == completions_to_set( + [ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + function("CURRENT_DATE", -2), + function("CURRENT_TIMESTAMP", -2), + function("CUME_DIST", -2), + function("CURRENT_TIME", -2), + keyword("CURRENT", -2), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_user_function_name_completion_matches_anywhere(completer): + result = get_result(completer, "SELECT om") + assert completions_to_set(result) == completions_to_set( + [ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + ] + ) + + +@parametrize("completer", completers(casing=True)) +def test_list_functions_for_special(completer): + result = get_result(completer, r"\df ") + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + [function(f) for f in cased_func_names] + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_column_names_from_visible_table(completer): + result = get_result(completer, "SELECT from users", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=True, qualify=no_qual)) +def test_suggested_cased_column_names(completer): + result = get_result(completer, "SELECT from users", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + cased_funcs + + cased_users_cols + + testdata.builtin_functions() + + testdata.keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize("text", ["SELECT from users", "INSERT INTO Orders SELECT from users"]) +def test_suggested_auto_qualified_column_names(text, completer): + position = text.index(" ") + 1 + cols = [column(c.lower()) for c in cased_users_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cols + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=qual)) +@parametrize( + "text", + [ + 'SELECT from users U NATURAL JOIN "Users"', + 'INSERT INTO Orders SELECT from users U NATURAL JOIN "Users"', + ], +) +def test_suggested_auto_qualified_column_names_two_tables(text, completer): + position = text.index(" ") + 1 + cols = [column("U." + c.lower()) for c in cased_users_col_names] + cols += [column('"Users".' + c.lower()) for c in cased_users2_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cols + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=True, qualify=["always"])) +@parametrize("text", ["UPDATE users SET ", "INSERT INTO users("]) +def test_no_column_qualification(text, completer): + cols = [column(c) for c in cased_users_col_names] + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set(cols) + + +@parametrize("completer", completers(casing=True, qualify=["always"])) +def test_suggested_cased_always_qualified_column_names(completer): + text = "SELECT from users" + position = len("SELECT ") + cols = [column("users." + c) for c in cased_users_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cased_funcs + cols + testdata.builtin_functions() + testdata.keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_column_names_in_function(completer): + result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX(")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_column_names_with_table_dot(completer): + result = get_result(completer, "SELECT users. from users", len("SELECT users.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_column_names_with_alias(completer): + result = get_result(completer, "SELECT u. from users u", len("SELECT u.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_multiple_column_names(completer): + result = get_result(completer, "SELECT id, from users u", len("SELECT id, ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_multiple_column_names_with_alias(completer): + result = get_result( + completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=True)) +def test_suggested_cased_column_names_with_alias(completer): + result = get_result( + completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set(cased_users_cols) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_multiple_column_names_with_dot(completer): + result = get_result( + completer, + "SELECT users.id, users. from users u", + len("SELECT users.id, users."), + ) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_after_three_way_join(completer): + text = """SELECT * FROM users u1 + INNER JOIN users u2 ON u1.id = u2.id + INNER JOIN users u3 ON u2.id = u3.""" + result = get_result(completer, text) + assert column("id") in result + + +join_condition_texts = [ + 'INSERT INTO orders SELECT * FROM users U JOIN "Users" U2 ON ', + """INSERT INTO public.orders(orderid) + SELECT * FROM users U JOIN "Users" U2 ON """, + 'SELECT * FROM users U JOIN "Users" U2 ON ', + 'SELECT * FROM users U INNER join "Users" U2 ON ', + 'SELECT * FROM USERS U right JOIN "Users" U2 ON ', + 'SELECT * FROM users U LEFT JOIN "Users" U2 ON ', + 'SELECT * FROM Users U FULL JOIN "Users" U2 ON ', + 'SELECT * FROM users U right outer join "Users" U2 ON ', + 'SELECT * FROM Users U LEFT OUTER JOIN "Users" U2 ON ', + 'SELECT * FROM users U FULL OUTER JOIN "Users" U2 ON ', + """SELECT * + FROM users U + FULL OUTER JOIN "Users" U2 ON + """, +] + + +@parametrize("completer", completers(casing=False)) +@parametrize("text", join_condition_texts) +def test_suggested_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("U"), alias("U2"), fk_join("U2.userid = U.id")] + ) + + +@parametrize("completer", completers(casing=True)) +@parametrize("text", join_condition_texts) +def test_cased_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + """SELECT * + FROM users + CROSS JOIN "Users" + NATURAL JOIN users u + JOIN "Users" u2 ON + """ + ], +) +def test_suggested_join_conditions_with_same_table_twice(completer, text): + result = get_result(completer, text) + assert result == [ + fk_join("u2.userid = u.id"), + fk_join("u2.userid = users.id"), + name_join('u2.userid = "Users".userid'), + name_join('u2.username = "Users".username'), + alias("u"), + alias("u2"), + alias("users"), + alias('"Users"'), + ] + + +@parametrize("completer", completers()) +@parametrize("text", ["SELECT * FROM users JOIN users u2 on foo."]) +def test_suggested_join_conditions_with_invalid_qualifier(completer, text): + result = get_result(completer, text) + assert result == [] + + +@parametrize("completer", completers(casing=False)) +@parametrize( + ("text", "ref"), + [ + ("SELECT * FROM users JOIN NonTable on ", "NonTable"), + ("SELECT * FROM users JOIN nontable nt on ", "nt"), + ], +) +def test_suggested_join_conditions_with_invalid_table(completer, text, ref): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("users"), alias(ref)] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + 'SELECT * FROM "Users" u JOIN u', + 'SELECT * FROM "Users" u JOIN uid', + 'SELECT * FROM "Users" u JOIN userid', + 'SELECT * FROM "Users" u JOIN id', + ], +) +def test_suggested_joins_fuzzy(completer, text): + result = get_result(completer, text) + last_word = text.split()[-1] + expected = join("users ON users.id = u.userid", -len(last_word)) + assert expected in result + + +join_texts = [ + "SELECT * FROM Users JOIN ", + """INSERT INTO "Users" + SELECT * + FROM Users + INNER JOIN """, + """INSERT INTO public."Users"(username) + SELECT * + FROM Users + INNER JOIN """, + """SELECT * + FROM Users + INNER JOIN """, +] + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", join_texts) +def test_suggested_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [ + join('"Users" ON "Users".userid = Users.id'), + join("users users2 ON users2.id = Users.parentid"), + join("users users2 ON users2.parentid = Users.id"), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=False)) +@parametrize("text", join_texts) +def test_cased_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + + cased_rels + + [ + join('"Users" ON "Users".UserID = Users.ID'), + join("Users Users2 ON Users2.ID = Users.PARENTID"), + join("Users Users2 ON Users2.PARENTID = Users.ID"), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", join_texts) +def test_aliased_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + aliased_rels + + [ + join('"Users" U ON U.userid = Users.id'), + join("users u ON u.id = Users.parentid"), + join("users u ON u.parentid = Users.id"), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + 'SELECT * FROM public."Users" JOIN ', + 'SELECT * FROM public."Users" RIGHT OUTER JOIN ', + """SELECT * + FROM public."Users" + LEFT JOIN """, + ], +) +def test_suggested_joins_quoted_schema_qualified_table(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [join('public.users ON users.id = "Users".userid')] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT u.name, o.id FROM users u JOIN orders o ON ", + "SELECT u.name, o.id FROM users u JOIN orders o ON JOIN orders o2 ON", + ], +) +def test_suggested_aliases_after_on(completer, text): + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + alias("u"), + name_join("o.id = u.id"), + name_join("o.email = u.email"), + alias("o"), + ] + ) + + +@parametrize("completer", completers()) +@parametrize( + "text", + [ + "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ", + "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = JOIN orders o2 ON", + ], +) +def test_suggested_aliases_after_on_right_side(completer, text): + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set([alias("u"), alias("o")]) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT users.name, orders.id FROM users JOIN orders ON ", + "SELECT users.name, orders.id FROM users JOIN orders ON JOIN orders orders2 ON", + ], +) +def test_suggested_tables_after_on(completer, text): + position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + name_join("orders.id = users.id"), + name_join("orders.email = users.email"), + alias("users"), + alias("orders"), + ] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = JOIN orders orders2 ON", + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ", + ], +) +def test_suggested_tables_after_on_right_side(completer, text): + position = len( + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " + ) + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [alias("users"), alias("orders")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users INNER JOIN orders USING (", + "SELECT * FROM users INNER JOIN orders USING(", + ], +) +def test_join_using_suggests_common_columns(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users u1 JOIN users u2 USING (email) JOIN user_emails ue USING()", + "SELECT * FROM users u1 JOIN users u2 USING(email) JOIN user_emails ue USING ()", + "SELECT * FROM users u1 JOIN user_emails ue USING () JOIN users u2 ue USING(first_name, last_name)", + "SELECT * FROM users u1 JOIN user_emails ue USING() JOIN users u2 ue USING (first_name, last_name)", + ], +) +def test_join_using_suggests_from_last_table(completer, text): + position = text.index("()") + 1 + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users INNER JOIN orders USING (id,", + "SELECT * FROM users INNER JOIN orders USING(id,", + ], +) +def test_join_using_suggests_columns_after_first_column(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + "SELECT * FROM ", + "SELECT * FROM users CROSS JOIN ", + "SELECT * FROM users natural join ", + ], +) +def test_table_names_after_from(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + assert [c.text for c in result] == [ + "public", + "orders", + '"select"', + "users", + '"Users"', + "functions", + "user_emails", + "_custom_fun()", + "custom_fun()", + "custom_func1()", + "custom_func2()", + "set_returning_func(x := , y := )", + ] + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_auto_escaped_col_names(completer): + result = get_result(completer, 'SELECT from "select"', len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("select") + ) + + +@parametrize("completer", completers(aliasing=False)) +def test_allow_leading_double_quote_in_last_word(completer): + result = get_result(completer, 'SELECT * from "sele') + + expected = table('"select"', -5) + + assert expected in result + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT 1::", + "CREATE TABLE foo (bar ", + "CREATE FUNCTION foo (bar INT, baz ", + "ALTER TABLE foo ALTER COLUMN bar TYPE ", + ], +) +def test_suggest_datatype(text, completer): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + testdata.types() + testdata.builtin_datatypes() + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_escaped_table_alias(completer): + result = get_result(completer, 'select * from "select" s where s.') + assert completions_to_set(result) == completions_to_set(testdata.columns("select")) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggest_columns_from_set_returning_function(completer): + result = get_result(completer, "select from set_returning_func()", len("select ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_aliased_set_returning_function(completer): + result = get_result( + completer, "select f. from set_returning_func() f", len("select f.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_join_functions_using_suggests_common_columns(completer): + text = """SELECT * FROM set_returning_func() f1 + INNER JOIN set_returning_func() f2 USING (""" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_join_functions_on_suggests_columns_and_join_conditions(completer): + text = """SELECT * FROM set_returning_func() f1 + INNER JOIN set_returning_func() f2 ON f1.""" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [name_join("y = f2.y"), name_join("x = f2.x")] + + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers()) +def test_learn_keywords(completer): + history = "CREATE VIEW v AS SELECT 1" + completer.extend_query_history(history) + + # Now that we've used `VIEW` once, it should be suggested ahead of other + # keywords starting with v. + text = "create v" + completions = get_result(completer, text) + assert completions[0].text == "VIEW" + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_learn_table_names(completer): + history = "SELECT * FROM users; SELECT * FROM orders; SELECT * FROM users" + completer.extend_query_history(history) + + text = "SELECT * FROM " + completions = get_result(completer, text) + + # `users` should be higher priority than `orders` (used more often) + users = table("users") + orders = table("orders") + + assert completions.index(users) < completions.index(orders) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_columns_before_keywords(completer): + text = "SELECT * FROM orders WHERE s" + completions = get_result(completer, text) + + col = column("status", -1) + kw = keyword("SELECT", -1) + + assert completions.index(col) < completions.index(kw) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT * FROM users", + "INSERT INTO users SELECT * FROM users u", + """INSERT INTO users(id, parentid, email, first_name, last_name) + SELECT * + FROM users u""", + ], +) +def test_wildcard_column_expansion(completer, text): + position = text.find("*") + 1 + + completions = get_result(completer, text, position) + + col_list = "id, parentid, email, first_name, last_name" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT u.* FROM users u", + "INSERT INTO public.users SELECT u.* FROM users u", + """INSERT INTO users(id, parentid, email, first_name, last_name) + SELECT u.* + FROM users u""", + ], +) +def test_wildcard_column_expansion_with_alias(completer, text): + position = text.find("*") + 1 + + completions = get_result(completer, text, position) + + col_list = "id, u.parentid, u.email, u.first_name, u.last_name" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text,expected", + [ + ( + "SELECT users.* FROM users", + "id, users.parentid, users.email, users.first_name, users.last_name", + ), + ( + "SELECT Users.* FROM Users", + "id, Users.parentid, Users.email, Users.first_name, Users.last_name", + ), + ], +) +def test_wildcard_column_expansion_with_table_qualifier(completer, text, expected): + position = len("SELECT users.*") + + completions = get_result(completer, text, position) + + expected = [wildcard_expansion(expected)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False, qualify=qual)) +def test_wildcard_column_expansion_with_two_tables(completer): + text = 'SELECT * FROM "select" JOIN users u ON true' + position = len("SELECT *") + + completions = get_result(completer, text, position) + + cols = ( + '"select".id, "select".insert, "select"."ABC", ' + "u.id, u.parentid, u.email, u.first_name, u.last_name" + ) + expected = [wildcard_expansion(cols)] + assert completions == expected + + +@parametrize("completer", completers(casing=False)) +def test_wildcard_column_expansion_with_two_tables_and_parent(completer): + text = 'SELECT "select".* FROM "select" JOIN users u ON true' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select".insert, "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + ["SELECT U. FROM Users U", "SELECT U. FROM USERS U", "SELECT U. FROM users U"], +) +def test_suggest_columns_from_unquoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_quoted_table(completer): + result = get_result(completer, 'SELECT U. FROM "Users" U', len("SELECT U.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("Users")) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", ["SELECT * FROM ", "SELECT * FROM Orders o CROSS JOIN "]) +def test_schema_or_visible_table_completion(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", ["SELECT * FROM "]) +def test_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + aliased_rels + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) +def test_duplicate_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + [ + table("orders o2"), + table("users u"), + table('"Users" U'), + table('"select" s'), + view("user_emails ue"), + view("functions f"), + function("_custom_fun() cf"), + function("custom_fun() cf"), + function("custom_func1() cf"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=True)) +@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) +def test_duplicate_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + schema("PUBLIC"), + table("Orders O2"), + table("Users U"), + table('"Users" U'), + table('"select" s'), + view("User_Emails UE"), + view("Functions F"), + function("_custom_fun() cf"), + function("Custom_Fun() CF"), + function("Custom_Func1() CF"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=True)) +@parametrize("text", ["SELECT * FROM "]) +def test_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + cased_aliased_rels + ) + + +@parametrize("completer", completers(casing=True, aliasing=False)) +@parametrize("text", ["SELECT * FROM "]) +def test_table_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + cased_rels + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "INSERT INTO users ()", + "INSERT INTO users()", + "INSERT INTO users () SELECT * FROM orders;", + "INSERT INTO users() SELECT * FROM users u cross join orders o", + ], +) +def test_insert(completer, text): + position = text.find("(") + 1 + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_suggest_cte_names(completer): + text = """ + WITH cte1 AS (SELECT a, b, c FROM foo), + cte2 AS (SELECT d, e, f FROM bar) + SELECT * FROM + """ + result = get_result(completer, text) + expected = completions_to_set( + [ + Completion("cte1", 0, display_meta="table"), + Completion("cte2", 0, display_meta="table"), + ] + ) + assert expected <= completions_to_set(result) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggest_columns_from_cte(completer): + result = get_result( + completer, + "WITH cte AS (SELECT foo, bar FROM baz) SELECT FROM cte", + len("WITH cte AS (SELECT foo, bar FROM baz) SELECT "), + ) + expected = [ + Completion("foo", 0, display_meta="column"), + Completion("bar", 0, display_meta="column"), + ] + testdata.functions_and_keywords() + + assert completions_to_set(expected) == completions_to_set(result) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte WHERE cte.", + "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte c WHERE c.", + ], +) +def test_cte_qualified_columns(completer, text): + result = get_result(completer, text) + expected = [Completion("foo", 0, display_meta="column")] + assert completions_to_set(expected) == completions_to_set(result) + + +@parametrize( + "keyword_casing,expected,texts", + [ + ("upper", "SELECT", ("", "s", "S", "Sel")), + ("lower", "select", ("", "s", "S", "Sel")), + ("auto", "SELECT", ("", "S", "SEL", "seL")), + ("auto", "select", ("s", "sel", "SEl")), + ], +) +def test_keyword_casing_upper(keyword_casing, expected, texts): + for text in texts: + completer = testdata.get_completer({"keyword_casing": keyword_casing}) + completions = get_result(completer, text) + assert expected in [cpl.text for cpl in completions] + + +@parametrize("completer", completers()) +def test_keyword_after_alter(completer): + text = "ALTER TABLE users ALTER " + expected = Completion("COLUMN", start_position=0, display_meta="keyword") + completions = get_result(completer, text) + assert expected in completions + + +@parametrize("completer", completers()) +def test_set_schema(completer): + text = "SET SCHEMA " + result = get_result(completer, text) + expected = completions_to_set([schema("'public'")]) + assert completions_to_set(result) == expected + + +@parametrize("completer", completers()) +def test_special_name_completion(completer): + result = get_result(completer, "\\t") + assert completions_to_set(result) == completions_to_set( + [ + Completion( + text="\\timing", + start_position=-2, + display_meta="Toggle timing of commands.", + ) + ] + ) diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py new file mode 100644 index 0000000..1034bbe --- /dev/null +++ b/tests/test_sqlcompletion.py @@ -0,0 +1,964 @@ +from pgcli.packages.sqlcompletion import ( + suggest_type, + Special, + Database, + Schema, + Table, + Column, + View, + Keyword, + FromClauseItem, + Function, + Datatype, + Alias, + JoinCondition, + Join, +) +from pgcli.packages.parseutils.tables import TableReference +import pytest + + +def cols_etc( + table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None +): + """Returns the expected select-clause suggestions for a single-table + select.""" + return { + Column( + table_refs=(TableReference(schema, table, alias, is_function),), + qualifiable=True, + ), + Function(schema=parent), + Keyword(last_keyword), + } + + +def test_select_suggests_cols_with_visible_table_scope(): + suggestions = suggest_type("SELECT FROM tabl", "SELECT ") + assert set(suggestions) == cols_etc("tabl", last_keyword="SELECT") + + +def test_select_suggests_cols_with_qualified_table_scope(): + suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ") + assert set(suggestions) == cols_etc("tabl", "sch", last_keyword="SELECT") + + +def test_cte_does_not_crash(): + sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;" + for i in range(len(sql)): + suggestions = suggest_type(sql[: i + 1], sql[: i + 1]) + + +@pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE ']) +def test_where_suggests_columns_functions_quoted_table(expression): + expected = cols_etc("tabl", alias='"tabl"', last_keyword="WHERE") + suggestions = suggest_type(expression, expression) + assert expected == set(suggestions) + + +@pytest.mark.parametrize( + "expression", + [ + "INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ", + "INSERT INTO OtherTabl SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE (", + "SELECT * FROM tabl WHERE foo = ", + "SELECT * FROM tabl WHERE bar OR ", + "SELECT * FROM tabl WHERE foo = 1 AND ", + "SELECT * FROM tabl WHERE (bar > 10 AND ", + "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (", + "SELECT * FROM tabl WHERE 10 < ", + "SELECT * FROM tabl WHERE foo BETWEEN ", + "SELECT * FROM tabl WHERE foo BETWEEN foo AND ", + ], +) +def test_where_suggests_columns_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +@pytest.mark.parametrize( + "expression", + ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "], +) +def test_where_in_suggests_columns(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "]) +def test_after_as(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == set() + + +def test_where_equals_any_suggests_columns_or_keywords(): + text = "SELECT * FROM tabl WHERE foo = ANY(" + suggestions = suggest_type(text, text) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +def test_lparen_suggests_cols_and_funcs(): + suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") + assert set(suggestion) == { + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("("), + } + + +def test_select_suggests_cols_and_funcs(): + suggestions = suggest_type("SELECT ", "SELECT ") + assert set(suggestions) == { + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +@pytest.mark.parametrize( + "expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "] +) +def test_suggests_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()} + + +@pytest.mark.parametrize("expression", ["SELECT * FROM "]) +def test_suggest_tables_views_schemas_and_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM foo JOIN bar on bar.barid = foo.barid JOIN ", + "SELECT * FROM foo JOIN bar USING (barid) JOIN ", + ], +) +def test_suggest_after_join_with_two_tables(expression): + suggestions = suggest_type(expression, expression) + tables = tuple([(None, "foo", None, False), (None, "bar", None, False)]) + assert set(suggestions) == { + FromClauseItem(schema=None, table_refs=tables), + Join(tables, None), + Schema(), + } + + +@pytest.mark.parametrize( + "expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"] +) +def test_suggest_after_join_with_one_table(expression): + suggestions = suggest_type(expression, expression) + tables = ((None, "foo", None, False),) + assert set(suggestions) == { + FromClauseItem(schema=None, table_refs=tables), + Join(((None, "foo", None, False),), None), + Schema(), + } + + +@pytest.mark.parametrize( + "expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."] +) +def test_suggest_qualified_tables_and_views(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} + + +@pytest.mark.parametrize("expression", ["UPDATE sch."]) +def test_suggest_qualified_aliasable_tables_and_views(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM sch.", + 'SELECT * FROM sch."', + 'SELECT * FROM sch."foo', + 'SELECT * FROM "sch".', + 'SELECT * FROM "sch"."', + ], +) +def test_suggest_qualified_tables_views_and_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {FromClauseItem(schema="sch")} + + +@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."]) +def test_suggest_qualified_tables_views_functions_and_joins(expression): + suggestions = suggest_type(expression, expression) + tbls = tuple([(None, "foo", None, False)]) + assert set(suggestions) == { + FromClauseItem(schema="sch", table_refs=tbls), + Join(tbls, "sch"), + } + + +def test_truncate_suggests_tables_and_schemas(): + suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") + assert set(suggestions) == {Table(schema=None), Schema()} + + +def test_truncate_suggests_qualified_tables(): + suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") + assert set(suggestions) == {Table(schema="sch")} + + +@pytest.mark.parametrize( + "text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "] +) +def test_distinct_suggests_cols(text): + suggestions = suggest_type(text, text) + assert set(suggestions) == { + Column(table_refs=(), local_tables=(), qualifiable=True), + Function(schema=None), + Keyword("DISTINCT"), + } + + +@pytest.mark.parametrize( + "text, text_before, last_keyword", + [ + ("SELECT DISTINCT FROM tbl x JOIN tbl1 y", "SELECT DISTINCT", "SELECT"), + ( + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ", + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ", + "ORDER BY", + ), + ], +) +def test_distinct_and_order_by_suggestions_with_aliases( + text, text_before, last_keyword +): + suggestions = suggest_type(text, text_before) + assert set(suggestions) == { + Column( + table_refs=( + TableReference(None, "tbl", "x", False), + TableReference(None, "tbl1", "y", False), + ), + local_tables=(), + qualifiable=True, + ), + Function(schema=None), + Keyword(last_keyword), + } + + +@pytest.mark.parametrize( + "text, text_before", + [ + ("SELECT DISTINCT x. FROM tbl x JOIN tbl1 y", "SELECT DISTINCT x."), + ( + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.", + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.", + ), + ], +) +def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before): + suggestions = suggest_type(text, text_before) + assert set(suggestions) == { + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + } + + +def test_function_arguments_with_alias_given(): + suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.") + + assert set(suggestions) == { + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + } + + +def test_col_comma_suggests_cols(): + suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") + assert set(suggestions) == { + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +def test_table_comma_suggests_tables_and_schemas(): + suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +def test_into_suggests_tables_and_schemas(): + suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") + assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()} + + +@pytest.mark.parametrize( + "text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"] +) +def test_insert_into_lparen_suggests_cols(text): + suggestions = suggest_type(text, "INSERT INTO abc (") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_insert_into_lparen_partial_text_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_insert_into_lparen_comma_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_partially_typed_col_name_suggests_col_names(): + suggestions = suggest_type( + "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n" + ) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): + suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.") + assert set(suggestions) == { + Column(table_refs=((None, "tabl", None, False),)), + Table(schema="tabl"), + View(schema="tabl"), + Function(schema="tabl"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT t1. FROM tabl1 t1", + "SELECT t1. FROM tabl1 t1, tabl2 t2", + 'SELECT t1. FROM "tabl1" t1', + 'SELECT t1. FROM "tabl1" t1, "tabl2" t2', + ], +) +def test_dot_suggests_cols_of_an_alias(sql): + suggestions = suggest_type(sql, "SELECT t1.") + assert set(suggestions) == { + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM tabl1 t1 WHERE t1.", + "SELECT * FROM tabl1 t1, tabl2 t2 WHERE t1.", + 'SELECT * FROM "tabl1" t1 WHERE t1.', + 'SELECT * FROM "tabl1" t1, tabl2 t2 WHERE t1.', + ], +) +def test_dot_suggests_cols_of_an_alias_where(sql): + suggestions = suggest_type(sql, sql) + assert set(suggestions) == { + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + } + + +def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): + suggestions = suggest_type( + "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2." + ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl2", "t2", False),)), + Table(schema="t2"), + View(schema="t2"), + Function(schema="t2"), + } + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (", + "SELECT * FROM foo WHERE EXISTS (", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (", + ], +) +def test_sub_select_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == (Keyword(),) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (S", + "SELECT * FROM foo WHERE EXISTS (S", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (S", + ], +) +def test_sub_select_partial_text_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == (Keyword(),) + + +def test_outer_table_reference_in_exists_subquery_suggests_columns(): + q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." + suggestions = suggest_type(q, q) + assert set(suggestions) == { + Column(table_refs=((None, "foo", "f", False),)), + Table(schema="f"), + View(schema="f"), + Function(schema="f"), + } + + +@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "]) +def test_sub_select_table_name_completion(expression): + suggestion = suggest_type(expression, expression) + assert set(suggestion) == {FromClauseItem(schema=None), Schema()} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ", + ], +) +def test_sub_select_table_name_completion_with_outer_table(expression): + suggestion = suggest_type(expression, expression) + tbls = tuple([(None, "foo", None, False)]) + assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} + + +def test_sub_select_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT " + ) + assert set(suggestions) == { + Column(table_refs=((None, "abc", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +@pytest.mark.xfail +def test_sub_select_multiple_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, " + ) + assert set(suggestions) == cols_etc("abc") + + +def test_sub_select_dot_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t." + ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl", "t", False),)), + Table(schema="t"), + View(schema="t"), + Function(schema="t"), + } + + +@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER")) +@pytest.mark.parametrize("tbl_alias", ("", "foo")) +def test_join_suggests_tables_and_schemas(tbl_alias, join_type): + text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN " + suggestion = suggest_type(text, text) + tbls = tuple([(None, "abc", tbl_alias or None, False)]) + assert set(suggestion) == { + FromClauseItem(schema=None, table_refs=tbls), + Schema(), + Join(tbls, None), + } + + +def test_left_join_with_comma(): + text = "select * from foo f left join bar b," + suggestions = suggest_type(text, text) + # tbls should also include (None, 'bar', 'b', False) + # but there's a bug with commas + tbls = tuple([(None, "foo", "f", False)]) + assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.", + ], +) +def test_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", "a", False), (None, "def", "d", False)) + assert set(suggestions) == { + Column(table_refs=((None, "abc", "a", False),)), + Table(schema="a"), + View(schema="a"), + Function(schema="a"), + JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.", + ], +) +def test_join_alias_dot_suggests_cols2(sql): + suggestion = suggest_type(sql, sql) + assert set(suggestion) == { + Column(table_refs=((None, "def", "d", False),)), + Table(schema="d"), + View(schema="d"), + Function(schema="d"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on ", + """select a.x, b.y +from abc a +join bcd b on +""", + """select a.x, b.y +from abc a +join bcd b +on """, + "select a.x, b.y from abc a join bcd b on a.id = b.id OR ", + ], +) +def test_on_suggests_aliases_and_join_conditions(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", "a", False), (None, "bcd", "b", False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("a", "b")), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ", + "select abc.x, bcd.y from abc join bcd on ", + ], +) +def test_on_suggests_tables_and_join_conditions(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", None, False), (None, "bcd", None, False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("abc", "bcd")), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on a.id = ", + "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ", + ], +) +def test_on_suggests_aliases_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == (Alias(aliases=("a", "b")),) + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ", + "select abc.x, bcd.y from abc join bcd on ", + ], +) +def test_on_suggests_tables_and_join_conditions_right_side(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", None, False), (None, "bcd", None, False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("abc", "bcd")), + } + + +@pytest.mark.parametrize( + "text", + ( + "select * from abc inner join def using (", + "select * from abc inner join def using (col1, ", + "insert into hij select * from abc inner join def using (", + """insert into hij(x, y, z) + select * from abc inner join def using (col1, """, + """insert into hij (a,b,c) + select * from abc inner join def using (col1, """, + ), +) +def test_join_using_suggests_common_columns(text): + tables = ((None, "abc", None, False), (None, "def", None, False)) + assert set(suggest_type(text, text)) == { + Column(table_refs=tables, require_last_table=True) + } + + +def test_suggest_columns_after_multiple_joins(): + sql = """select * from t1 + inner join t2 ON + t1.id = t2.t1_id + inner join t3 ON + t2.id = t3.""" + suggestions = suggest_type(sql, sql) + assert Column(table_refs=((None, "t3", None, False),)) in set(suggestions) + + +def test_2_statements_2nd_current(): + suggestions = suggest_type( + "select * from a; select * from ", "select * from a; select * from " + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type( + "select * from a; select from b", "select * from a; select " + ) + assert set(suggestions) == { + Column(table_refs=((None, "b", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + # Should work even if first statement is invalid + suggestions = suggest_type( + "select * from; select * from ", "select * from; select * from " + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +def test_2_statements_1st_current(): + suggestions = suggest_type("select * from ; select * from b", "select * from ") + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type("select from a; select * from b", "select ") + assert set(suggestions) == cols_etc("a", last_keyword="SELECT") + + +def test_3_statements_2nd_current(): + suggestions = suggest_type( + "select * from a; select * from ; select * from c", + "select * from a; select * from ", + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type( + "select * from a; select from b; select * from c", "select * from a; select " + ) + assert set(suggestions) == cols_etc("b", last_keyword="SELECT") + + +@pytest.mark.parametrize( + "text", + [ + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ +SELECT FROM foo; +SELECT 2 FROM bar; +$$ language sql; + """, + """create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 2 FROM bar; +SELECT FROM foo; +$func$ + """, + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ +SELECT 3 FROM foo; +SELECT 2 FROM bar; +$$ language sql; +create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 2 FROM bar; +SELECT FROM foo; +$func$ + """, + """ +SELECT * FROM baz; +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ +SELECT FROM foo; +SELECT 2 FROM bar; +$$ language sql; +create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 3 FROM bar; +SELECT FROM foo; +$func$ +SELECT * FROM qux; + """, + ], +) +def test_statements_in_function_body(text): + suggestions = suggest_type(text, text[: text.find(" ") + 1]) + assert set(suggestions) == { + Column(table_refs=((None, "foo", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +functions = [ + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ +SELECT 1 FROM foo; +SELECT 2 FROM bar; +$$ language sql; + """, + """ +create function func2(int, varchar) +RETURNS text +language sql AS +' +SELECT 2 FROM bar; +SELECT 1 FROM foo; +'; + """, +] + + +@pytest.mark.parametrize("text", functions) +def test_statements_with_cursor_after_function_body(text): + suggestions = suggest_type(text, text[: text.find("; ") + 1]) + assert set(suggestions) == {Keyword(), Special()} + + +@pytest.mark.parametrize("text", functions) +def test_statements_with_cursor_before_function_body(text): + suggestions = suggest_type(text, "") + assert set(suggestions) == {Keyword(), Special()} + + +def test_create_db_with_template(): + suggestions = suggest_type( + "create database foo with template ", "create database foo with template " + ) + + assert set(suggestions) == {Database()} + + +@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n")) +def test_specials_included_for_initial_completion(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert set(suggestions) == {Keyword(), Special()} + + +def test_drop_schema_qualified_table_suggests_only_tables(): + text = "DROP TABLE schema_name.table_name" + suggestions = suggest_type(text, text) + assert suggestions == (Table(schema="schema_name"),) + + +@pytest.mark.parametrize("text", (",", " ,", "sel ,")) +def test_handle_pre_completion_comma_gracefully(text): + suggestions = suggest_type(text, text) + + assert iter(suggestions) + + +def test_drop_schema_suggests_schemas(): + sql = "DROP SCHEMA " + assert suggest_type(sql, sql) == (Schema(),) + + +@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"]) +def test_cast_operator_suggests_types(text): + assert set(suggest_type(text, text)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."] +) +def test_cast_operator_suggests_schema_qualified_types(text): + assert set(suggest_type(text, text)) == { + Datatype(schema="bar"), + Table(schema="bar"), + } + + +def test_alter_column_type_suggests_types(): + q = "ALTER TABLE foo ALTER COLUMN bar TYPE " + assert set(suggest_type(q, q)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", + [ + "CREATE TABLE foo (bar ", + "CREATE TABLE foo (bar DOU", + "CREATE TABLE foo (bar INT, baz ", + "CREATE TABLE foo (bar INT, baz TEXT, qux ", + "CREATE FUNCTION foo (bar ", + "CREATE FUNCTION foo (bar INT, baz ", + "SELECT * FROM foo() AS bar (baz ", + "SELECT * FROM foo() AS bar (baz INT, qux ", + # make sure this doesn't trigger special completion + "CREATE TABLE foo (dt d", + ], +) +def test_identifier_suggests_types_in_parentheses(text): + assert set(suggest_type(text, text)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", + [ + "SELECT foo ", + "SELECT foo FROM bar ", + "SELECT foo AS bar ", + "SELECT foo bar ", + "SELECT * FROM foo AS bar ", + "SELECT * FROM foo bar ", + "SELECT foo FROM (SELECT bar ", + ], +) +def test_alias_suggests_keywords(text): + suggestions = suggest_type(text, text) + assert suggestions == (Keyword(),) + + +def test_invalid_sql(): + # issue 317 + text = "selt *" + suggestions = suggest_type(text, text) + assert suggestions == (Keyword(),) + + +@pytest.mark.parametrize( + "text", + ["SELECT * FROM foo where created > now() - ", "select * from foo where bar "], +) +def test_suggest_where_keyword(text): + # https://github.com/dbcli/mycli/issues/135 + suggestions = suggest_type(text, text) + assert set(suggestions) == cols_etc("foo", last_keyword="WHERE") + + +@pytest.mark.parametrize( + "text, before, expected", + [ + ( + "\\ns abc SELECT ", + "SELECT ", + [ + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + ], + ), + ("\\ns abc SELECT foo ", "SELECT foo ", (Keyword(),)), + ( + "\\ns abc SELECT t1. FROM tabl1 t1", + "SELECT t1.", + [ + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + ], + ), + ], +) +def test_named_query_completion(text, before, expected): + suggestions = suggest_type(text, before) + assert set(expected) == set(suggestions) + + +def test_select_suggests_fields_from_function(): + suggestions = suggest_type("SELECT FROM func()", "SELECT ") + assert set(suggestions) == cols_etc("func", is_function=True, last_keyword="SELECT") + + +@pytest.mark.parametrize("sql", ["("]) +def test_leading_parenthesis(sql): + # No assertion for now; just make sure it doesn't crash + suggest_type(sql, sql) + + +@pytest.mark.parametrize("sql", ['select * from "', 'select * from "foo']) +def test_ignore_leading_double_quotes(sql): + suggestions = suggest_type(sql, sql) + assert FromClauseItem(schema=None) in set(suggestions) + + +@pytest.mark.parametrize( + "sql", + [ + "ALTER TABLE foo ALTER COLUMN ", + "ALTER TABLE foo ALTER COLUMN bar", + "ALTER TABLE foo DROP COLUMN ", + "ALTER TABLE foo DROP COLUMN bar", + ], +) +def test_column_keyword_suggests_columns(sql): + suggestions = suggest_type(sql, sql) + assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))} + + +def test_handle_unrecognized_kw_generously(): + sql = "SELECT * FROM sessions WHERE session = 1 AND " + suggestions = suggest_type(sql, sql) + expected = Column(table_refs=((None, "sessions", None, False),), qualifiable=True) + + assert expected in set(suggestions) + + +@pytest.mark.parametrize("sql", ["ALTER ", "ALTER TABLE foo ALTER "]) +def test_keyword_after_alter(sql): + assert Keyword("ALTER") in set(suggest_type(sql, sql)) diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py new file mode 100644 index 0000000..ae865f4 --- /dev/null +++ b/tests/test_ssh_tunnel.py @@ -0,0 +1,188 @@ +import os +from unittest.mock import patch, MagicMock, ANY + +import pytest +from configobj import ConfigObj +from click.testing import CliRunner +from sshtunnel import SSHTunnelForwarder + +from pgcli.main import cli, PGCli +from pgcli.pgexecute import PGExecute + + +@pytest.fixture +def mock_ssh_tunnel_forwarder() -> MagicMock: + mock_ssh_tunnel_forwarder = MagicMock( + SSHTunnelForwarder, local_bind_ports=[1111], autospec=True + ) + with patch( + "pgcli.main.sshtunnel.SSHTunnelForwarder", + return_value=mock_ssh_tunnel_forwarder, + ) as mock: + yield mock + + +@pytest.fixture +def mock_pgexecute() -> MagicMock: + with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute: + yield mock_pgexecute + + +def test_ssh_tunnel( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + # Test with just a host + tunnel_url = "some.host" + db_params = { + "database": "dbname", + "host": "db.host", + "user": "db_user", + "passwd": "db_passwd", + } + expected_tunnel_params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (db_params["host"], 5432), + "ssh_address_or_host": (tunnel_url, 22), + "logger": ANY, + } + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with a full url and with a specific db port + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "some.other.host" + tunnel_port = 1022 + tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + db_params["port"] = 1234 + + expected_tunnel_params["remote_bind_address"] = ( + db_params["host"], + db_params["port"], + ) + expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port) + expected_tunnel_params["ssh_username"] = tunnel_user + expected_tunnel_params["ssh_password"] = tunnel_passwd + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with DSN + dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host={db_params['host']} port={db_params['port']}" + ) + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(dsn=dsn) + + expected_dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}" + ) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert expected_dsn in call_args + + +def test_cli_with_tunnel() -> None: + runner = CliRunner() + tunnel_url = "mytunnel" + with patch.object( + PGCli, "__init__", autospec=True, return_value=None + ) as mock_pgcli: + runner.invoke(cli, ["--ssh-tunnel", tunnel_url]) + mock_pgcli.assert_called_once() + call_args, call_kwargs = mock_pgcli.call_args + assert call_kwargs["ssh_tunnel_url"] == tunnel_url + + +def test_config( + tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + pgclirc = str(tmpdir.join("rcfile")) + + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "tunnel.host" + tunnel_port = 1022 + tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + + tunnel2_url = "tunnel2.host" + + config = ConfigObj() + config.filename = pgclirc + config["ssh tunnels"] = {} + config["ssh tunnels"][r"\.com$"] = tunnel_url + config["ssh tunnels"][r"^hello-"] = tunnel2_url + config.write() + + # Unmatched host + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="unmatched.host") + mock_ssh_tunnel_forwarder.assert_not_called() + + # Host matching first tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="matched.host.com") + mock_ssh_tunnel_forwarder.assert_called_once() + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching second tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22) + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching both tunnels (will use the first one matched) + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched.com") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..67d769f --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,92 @@ +import pytest +import psycopg +from pgcli.main import format_output, OutputSettings +from os import getenv + +POSTGRES_USER = getenv("PGUSER", "postgres") +POSTGRES_HOST = getenv("PGHOST", "localhost") +POSTGRES_PORT = getenv("PGPORT", 5432) +POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres") + + +def db_connection(dbname=None): + conn = psycopg.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + dbname=dbname, + ) + conn.autocommit = True + return conn + + +try: + conn = db_connection() + CAN_CONNECT_TO_DB = True + SERVER_VERSION = conn.info.parameter_status("server_version") + JSON_AVAILABLE = True + JSONB_AVAILABLE = True +except Exception as x: + CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False + SERVER_VERSION = 0 + + +dbtest = pytest.mark.skipif( + not CAN_CONNECT_TO_DB, + reason="Need a postgres instance at localhost accessible by user 'postgres'", +) + + +requires_json = pytest.mark.skipif( + not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined" +) + + +requires_jsonb = pytest.mark.skipif( + not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined" +) + + +def create_db(dbname): + with db_connection().cursor() as cur: + try: + cur.execute("""CREATE DATABASE _test_db""") + except: + pass + + +def drop_tables(conn): + with conn.cursor() as cur: + cur.execute( + """ + DROP SCHEMA public CASCADE; + CREATE SCHEMA public; + DROP SCHEMA IF EXISTS schema1 CASCADE; + DROP SCHEMA IF EXISTS schema2 CASCADE""" + ) + + +def run( + executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None +): + "Return string output for the sql to be run" + + results = executor.run(sql, pgspecial, exception_formatter) + formatted = [] + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded + ) + for title, rows, headers, status, sql, success, is_special in results: + formatted.extend(format_output(title, rows, headers, status, settings)) + if join: + formatted = "\n".join(formatted) + + return formatted + + +def completions_to_set(completions): + return { + (completion.display_text, completion.display_meta_text) + for completion in completions + } diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..7d33698 --- /dev/null +++ b/tox.ini @@ -0,0 +1,14 @@ +[tox] +envlist = py38, py39, py310, py311, py312 +[testenv] +deps = pytest>=2.7.0,<=3.0.7 + mock>=1.0.1 + behave>=1.2.4 + pexpect==3.3 + sshtunnel>=0.4.0 +commands = py.test + behave tests/features +passenv = PGHOST + PGPORT + PGUSER + PGPASSWORD -- cgit v1.2.3