summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.coveragerc3
-rw-r--r--.editorconfig15
-rw-r--r--.git-blame-ignore-revs0
-rw-r--r--.github/ISSUE_TEMPLATE.md9
-rw-r--r--.github/PULL_REQUEST_TEMPLATE.md12
-rw-r--r--.github/workflows/ci.yml102
-rw-r--r--.github/workflows/codeql.yml41
-rw-r--r--.gitignore74
-rw-r--r--.pre-commit-config.yaml5
-rw-r--r--AUTHORS136
-rw-r--r--DEVELOP.rst220
-rw-r--r--Dockerfile6
-rw-r--r--LICENSE.txt26
-rw-r--r--MANIFEST.in2
-rw-r--r--README.rst392
-rw-r--r--RELEASES.md24
-rw-r--r--TODO12
-rw-r--r--Vagrantfile138
-rw-r--r--changelog.rst1203
-rw-r--r--pgcli-completion.bash61
-rw-r--r--pgcli/__init__.py1
-rw-r--r--pgcli/__main__.py9
-rw-r--r--pgcli/auth.py60
-rw-r--r--pgcli/completion_refresher.py152
-rw-r--r--pgcli/config.py99
-rw-r--r--pgcli/explain_output_formatter.py19
-rw-r--r--pgcli/key_bindings.py133
-rw-r--r--pgcli/magic.py71
-rw-r--r--pgcli/main.py1728
-rw-r--r--pgcli/packages/__init__.py0
-rw-r--r--pgcli/packages/formatter/__init__.py1
-rw-r--r--pgcli/packages/formatter/sqlformatter.py74
-rw-r--r--pgcli/packages/parseutils/__init__.py58
-rw-r--r--pgcli/packages/parseutils/ctes.py141
-rw-r--r--pgcli/packages/parseutils/meta.py170
-rw-r--r--pgcli/packages/parseutils/tables.py165
-rw-r--r--pgcli/packages/parseutils/utils.py140
-rw-r--r--pgcli/packages/pgliterals/__init__.py0
-rw-r--r--pgcli/packages/pgliterals/main.py15
-rw-r--r--pgcli/packages/pgliterals/pgliterals.json630
-rw-r--r--pgcli/packages/prioritization.py51
-rw-r--r--pgcli/packages/prompt_utils.py37
-rw-r--r--pgcli/packages/sqlcompletion.py605
-rw-r--r--pgcli/pgbuffer.py61
-rw-r--r--pgcli/pgclirc236
-rw-r--r--pgcli/pgcompleter.py1072
-rw-r--r--pgcli/pgexecute.py868
-rw-r--r--pgcli/pgstyle.py116
-rw-r--r--pgcli/pgtoolbar.py71
-rw-r--r--pgcli/pyev.py439
-rw-r--r--post-install4
-rw-r--r--post-remove4
-rw-r--r--pylintrc2
-rw-r--r--pyproject.toml21
-rw-r--r--release.py135
-rw-r--r--requirements-dev.txt13
-rw-r--r--sanity_checks.txt37
-rw-r--r--screenshots/image01.pngbin0 -> 82111 bytes
-rw-r--r--screenshots/image02.pngbin0 -> 11767 bytes
-rw-r--r--screenshots/kharkiv-destroyed.jpgbin0 -> 187337 bytes
-rw-r--r--screenshots/pgcli.gifbin0 -> 238421 bytes
-rw-r--r--setup.py76
-rw-r--r--tests/conftest.py52
-rw-r--r--tests/features/__init__.py0
-rw-r--r--tests/features/auto_vertical.feature12
-rw-r--r--tests/features/basic_commands.feature81
-rw-r--r--tests/features/crud_database.feature17
-rw-r--r--tests/features/crud_table.feature45
-rw-r--r--tests/features/db_utils.py87
-rw-r--r--tests/features/environment.py227
-rw-r--r--tests/features/expanded.feature29
-rw-r--r--tests/features/fixture_data/help.txt25
-rw-r--r--tests/features/fixture_data/help_commands.txt64
-rw-r--r--tests/features/fixture_data/mock_pg_service.conf4
-rw-r--r--tests/features/fixture_utils.py28
-rw-r--r--tests/features/iocommands.feature17
-rw-r--r--tests/features/named_queries.feature10
-rw-r--r--tests/features/pgbouncer.feature12
-rw-r--r--tests/features/specials.feature6
-rw-r--r--tests/features/steps/__init__.py0
-rw-r--r--tests/features/steps/auto_vertical.py99
-rw-r--r--tests/features/steps/basic_commands.py231
-rw-r--r--tests/features/steps/crud_database.py93
-rw-r--r--tests/features/steps/crud_table.py185
-rw-r--r--tests/features/steps/expanded.py70
-rw-r--r--tests/features/steps/iocommands.py80
-rw-r--r--tests/features/steps/named_queries.py57
-rw-r--r--tests/features/steps/pgbouncer.py22
-rw-r--r--tests/features/steps/specials.py31
-rw-r--r--tests/features/steps/wrappers.py71
-rw-r--r--tests/features/wrappager.py16
-rw-r--r--tests/formatter/__init__.py1
-rw-r--r--tests/formatter/test_sqlformatter.py111
-rw-r--r--tests/metadata.py255
-rw-r--r--tests/parseutils/test_ctes.py137
-rw-r--r--tests/parseutils/test_function_metadata.py19
-rw-r--r--tests/parseutils/test_parseutils.py310
-rw-r--r--tests/pytest.ini2
-rw-r--r--tests/test_auth.py40
-rw-r--r--tests/test_completion_refresher.py95
-rw-r--r--tests/test_config.py43
-rw-r--r--tests/test_fuzzy_completion.py87
-rw-r--r--tests/test_main.py490
-rw-r--r--tests/test_naive_completion.py133
-rw-r--r--tests/test_pgcompleter.py76
-rw-r--r--tests/test_pgexecute.py773
-rw-r--r--tests/test_pgspecial.py78
-rw-r--r--tests/test_prioritization.py20
-rw-r--r--tests/test_prompt_utils.py17
-rw-r--r--tests/test_rowlimit.py79
-rw-r--r--tests/test_smart_completion_multiple_schemata.py757
-rw-r--r--tests/test_smart_completion_public_schema_only.py1112
-rw-r--r--tests/test_sqlcompletion.py964
-rw-r--r--tests/test_ssh_tunnel.py188
-rw-r--r--tests/utils.py92
-rw-r--r--tox.ini14
116 files changed, 17559 insertions, 0 deletions
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..b2713c7
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,3 @@
+[run]
+parallel=True
+source=pgcli
diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..bacb65c
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,15 @@
+# editorconfig.org
+# Get your text editor plugin at:
+# http://editorconfig.org/#download
+root = true
+
+[*]
+charset = utf-8
+end_of_line = lf
+indent_size = 4
+indent_style = space
+insert_final_newline = true
+trim_trailing_whitespace = true
+
+[travis.yml]
+indent_size = 2
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/.git-blame-ignore-revs
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
new file mode 100644
index 0000000..b5cdbec
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE.md
@@ -0,0 +1,9 @@
+## Description
+<!--- Describe your problem as fully as you can. -->
+
+## Your environment
+<!-- This gives us some more context to work with. -->
+
+- [ ] Please provide your OS and version information.
+- [ ] Please provide your CLI version.
+- [ ] What is the output of ``pip freeze`` command.
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000..35e8486
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,12 @@
+## Description
+<!--- Describe your changes in detail. -->
+
+
+
+## Checklist
+<!--- We appreciate your help and want to give you credit. Please take a moment to put an `x` in the boxes below as you complete them. -->
+- [ ] I've added this contribution to the `changelog.rst`.
+- [ ] I've added my name to the `AUTHORS` file (or it's already there).
+<!-- We would appreciate if you comply with our code style guidelines. -->
+- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`), and ran `black` on my code.
+- [x] Please squash merge this pull request (uncheck if you'd like us to merge as multiple commits)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..68a69ac
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,102 @@
+name: pgcli
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ paths-ignore:
+ - '**.rst'
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ strategy:
+ matrix:
+ python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+
+ services:
+ postgres:
+ image: postgres:9.6
+ env:
+ POSTGRES_USER: postgres
+ POSTGRES_PASSWORD: postgres
+ ports:
+ - 5432:5432
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install pgbouncer
+ run: |
+ sudo apt install pgbouncer -y
+
+ sudo chmod 666 /etc/pgbouncer/*.*
+
+ cat <<EOF > /etc/pgbouncer/userlist.txt
+ "postgres" "postgres"
+ EOF
+
+ cat <<EOF > /etc/pgbouncer/pgbouncer.ini
+ [databases]
+ * = host=localhost port=5432
+ [pgbouncer]
+ listen_port = 6432
+ listen_addr = localhost
+ auth_type = trust
+ auth_file = /etc/pgbouncer/userlist.txt
+ logfile = pgbouncer.log
+ pidfile = pgbouncer.pid
+ admin_users = postgres
+ EOF
+
+ sudo systemctl stop pgbouncer
+
+ pgbouncer -d /etc/pgbouncer/pgbouncer.ini
+
+ psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help'
+
+ - name: Install beta version of pendulum
+ run: pip install pendulum==3.0.0b1
+ if: matrix.python-version == '3.12'
+
+ - name: Install requirements
+ run: |
+ pip install -U pip setuptools
+ pip install --no-cache-dir ".[sshtunnel]"
+ pip install -r requirements-dev.txt
+ pip install keyrings.alt>=3.1
+
+ - name: Run unit tests
+ run: coverage run --source pgcli -m pytest
+
+ - name: Run integration tests
+ env:
+ PGUSER: postgres
+ PGPASSWORD: postgres
+
+ run: behave tests/features --no-capture
+
+ - name: Check changelog for ReST compliance
+ run: rst2html.py --halt=warning changelog.rst >/dev/null
+
+ - name: Run Black
+ run: black --check .
+ if: matrix.python-version == '3.8'
+
+ - name: Coverage
+ run: |
+ coverage combine
+ coverage report
+ codecov
diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
new file mode 100644
index 0000000..c9232c7
--- /dev/null
+++ b/.github/workflows/codeql.yml
@@ -0,0 +1,41 @@
+name: "CodeQL"
+
+on:
+ push:
+ branches: [ "main" ]
+ pull_request:
+ branches: [ "main" ]
+ schedule:
+ - cron: "29 13 * * 1"
+
+jobs:
+ analyze:
+ name: Analyze
+ runs-on: ubuntu-latest
+ permissions:
+ actions: read
+ contents: read
+ security-events: write
+
+ strategy:
+ fail-fast: false
+ matrix:
+ language: [ python ]
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+
+ - name: Initialize CodeQL
+ uses: github/codeql-action/init@v2
+ with:
+ languages: ${{ matrix.language }}
+ queries: +security-and-quality
+
+ - name: Autobuild
+ uses: github/codeql-action/autobuild@v2
+
+ - name: Perform CodeQL Analysis
+ uses: github/codeql-action/analyze@v2
+ with:
+ category: "/language:${{ matrix.language }}"
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b993cb9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,74 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+pyvenv/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+.pytest_cache
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# PyCharm
+.idea/
+*.iml
+
+# Vagrant
+.vagrant/
+
+# Generated Packages
+*.deb
+*.rpm
+
+.vscode/
+venv/
+
+.ropeproject/
+
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..8462cc2
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,5 @@
+repos:
+- repo: https://github.com/psf/black
+ rev: 23.3.0
+ hooks:
+ - id: black
diff --git a/AUTHORS b/AUTHORS
new file mode 100644
index 0000000..5eff7db
--- /dev/null
+++ b/AUTHORS
@@ -0,0 +1,136 @@
+Many thanks to the following contributors.
+
+Project Lead:
+-------------
+ * Irina Truong
+
+Core Devs:
+----------
+ * Amjith Ramanujam
+ * Darik Gamble
+ * Stuart Quin
+ * Joakim Koljonen
+ * Daniel Rocco
+ * Karl-Aksel Puulmann
+ * Dick Marinus
+
+Contributors:
+-------------
+ * Brett
+ * Étienne BERSAC (bersace)
+ * Daniel Schwarz
+ * inkn
+ * Jonathan Slenders
+ * xalley
+ * TamasNo1
+ * François Pietka
+ * Michael Kaminsky
+ * Alexander Kukushkin
+ * Ludovic Gasc (GMLudo)
+ * Marc Abramowitz
+ * Nick Hahner
+ * Jay Zeng
+ * Dimitar Roustchev
+ * Dhaivat Pandit
+ * Matheus Rosa
+ * Ali Kargın
+ * Nathan Jhaveri
+ * David Celis
+ * Sven-Hendrik Haase
+ * Çağatay Yüksel
+ * Tiago Ribeiro
+ * Vignesh Anand
+ * Charlie Arnold
+ * dwalmsley
+ * Artur Dryomov
+ * rrampage
+ * while0pass
+ * Eric Workman
+ * xa
+ * Hans Roman
+ * Guewen Baconnier
+ * Dionysis Grigoropoulos
+ * Jacob Magnusson
+ * Johannes Hoff
+ * vinotheassassin
+ * Jacek Wielemborek
+ * Fabien Meghazi
+ * Manuel Barkhau
+ * Sergii V
+ * Emanuele Gaifas
+ * Owen Stephens
+ * Russell Davies
+ * AlexTes
+ * Hraban Luyat
+ * Jackson Popkin
+ * Gustavo Castro
+ * Alexander Schmolck
+ * Donnell Muse
+ * Andrew Speed
+ * Dmitry B
+ * Isank
+ * Marcin Sztolcman
+ * Bojan Delić
+ * Chris Vaughn
+ * Frederic Aoustin
+ * Pierre Giraud
+ * Andrew Kuchling
+ * Dan Clark
+ * Catherine Devlin
+ * Jason Ribeiro
+ * Rishi Ramraj
+ * Matthieu Guilbert
+ * Alexandr Korsak
+ * Saif Hakim
+ * Artur Balabanov
+ * Kenny Do
+ * Max Rothman
+ * Daniel Egger
+ * Ignacio Campabadal
+ * Mikhail Elovskikh (wronglink)
+ * Marcin Cieślak (saper)
+ * easteregg (verfriemelt-dot-org)
+ * Scott Brenstuhl (808sAndBR)
+ * Nathan Verzemnieks
+ * raylu
+ * Zhaolong Zhu
+ * Zane C. Bowers-Hadley
+ * Telmo "Trooper" (telmotrooper)
+ * Alexander Zawadzki
+ * Pablo A. Bianchi (pabloab)
+ * Sebastian Janko (sebojanko)
+ * Pedro Ferrari (petobens)
+ * Martin Matejek (mmtj)
+ * Jonas Jelten
+ * BrownShibaDog
+ * George Thomas(thegeorgeous)
+ * Yoni Nakache(lazydba247)
+ * Gantsev Denis
+ * Stephano Paraskeva
+ * Panos Mavrogiorgos (pmav99)
+ * Igor Kim (igorkim)
+ * Anthony DeBarros (anthonydb)
+ * Seungyong Kwak (GUIEEN)
+ * Tom Caruso (tomplex)
+ * Jan Brun Rasmussen (janbrunrasmussen)
+ * Kevin Marsh (kevinmarsh)
+ * Eero Ruohola (ruohola)
+ * Miroslav Šedivý (eumiro)
+ * Eric R Young (ERYoung11)
+ * Paweł Sacawa (psacawa)
+ * Bruno Inec (sweenu)
+ * Daniele Varrazzo
+ * Daniel Kukula (dkuku)
+ * Kian-Meng Ang (kianmeng)
+ * Liu Zhao (astroshot)
+ * Rigo Neri (rigoneri)
+ * Anna Glasgall (annathyst)
+ * Andy Schoenberger (andyscho)
+ * Damien Baty (dbaty)
+ * blag
+ * Rob Berry (rob-b)
+ * Sharon Yogev (sharonyogev)
+
+Creator:
+--------
+Amjith Ramanujam
diff --git a/DEVELOP.rst b/DEVELOP.rst
new file mode 100644
index 0000000..aed2cf8
--- /dev/null
+++ b/DEVELOP.rst
@@ -0,0 +1,220 @@
+Development Guide
+-----------------
+This is a guide for developers who would like to contribute to this project.
+
+GitHub Workflow
+---------------
+
+If you're interested in contributing to pgcli, first of all my heart felt
+thanks. `Fork the project <https://github.com/dbcli/pgcli>`_ on github. Then
+clone your fork into your computer (``git clone <url-for-your-fork>``). Make
+the changes and create the commits in your local machine. Then push those
+changes to your fork. Then click on the pull request icon on github and create
+a new pull request. Add a description about the change and send it along. I
+promise to review the pull request in a reasonable window of time and get back
+to you.
+
+In order to keep your fork up to date with any changes from mainline, add a new
+git remote to your local copy called 'upstream' and point it to the main pgcli
+repo.
+
+::
+
+ $ git remote add upstream git@github.com:dbcli/pgcli.git
+
+Once the 'upstream' end point is added you can then periodically do a ``git
+pull upstream master`` to update your local copy and then do a ``git push
+origin master`` to keep your own fork up to date.
+
+Check Github's `Understanding the GitHub flow guide
+<https://guides.github.com/introduction/flow/>`_ for a more detailed
+explanation of this process.
+
+Local Setup
+-----------
+
+The installation instructions in the README file are intended for users of
+pgcli. If you're developing pgcli, you'll need to install it in a slightly
+different way so you can see the effects of your changes right away without
+having to go through the install cycle every time you change the code.
+
+It is highly recommended to use virtualenv for development. If you don't know
+what a virtualenv is, `this guide <http://docs.python-guide.org/en/latest/dev/virtualenvs/#virtual-environments>`_
+will help you get started.
+
+Create a virtualenv (let's call it pgcli-dev). Activate it:
+
+::
+
+ source ./pgcli-dev/bin/activate
+
+ or
+
+ .\pgcli-dev\scripts\activate (for Windows)
+
+Once the virtualenv is activated, `cd` into the local clone of pgcli folder
+and install pgcli using pip as follows:
+
+::
+
+ $ pip install --editable .
+
+ or
+
+ $ pip install -e .
+
+This will install the necessary dependencies as well as install pgcli from the
+working folder into the virtualenv. By installing it using `pip install -e`
+we've linked the pgcli installation with the working copy. Any changes made
+to the code are immediately available in the installed version of pgcli. This
+makes it easy to change something in the code, launch pgcli and check the
+effects of your changes.
+
+Adding PostgreSQL Special (Meta) Commands
+-----------------------------------------
+
+If you want to work on adding new meta-commands (such as `\dp`, `\ds`, `dy`),
+you need to contribute to `pgspecial <https://github.com/dbcli/pgspecial/>`_
+project.
+
+Visual Studio Code Debugging
+-----------------------------
+To set up Visual Studio Code to debug pgcli requires a launch.json file.
+
+Within the project, create a file: .vscode\\launch.json like below.
+
+::
+
+ {
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Python: Module",
+ "type": "python",
+ "request": "launch",
+ "module": "pgcli.main",
+ "justMyCode": false,
+ "console": "externalTerminal",
+ "env": {
+ "PGUSER": "postgres",
+ "PGPASS": "password",
+ "PGHOST": "localhost",
+ "PGPORT": "5432"
+ }
+ }
+ ]
+ }
+
+Building RPM and DEB packages
+-----------------------------
+
+You will need Vagrant 1.7.2 or higher. In the project root there is a
+Vagrantfile that is setup to do multi-vm provisioning. If you're setting things
+up for the first time, then do:
+
+::
+
+ $ version=x.y.z vagrant up debian
+ $ version=x.y.z vagrant up centos
+
+If you already have those VMs setup and you're merely creating a new version of
+DEB or RPM package, then you can do:
+
+::
+
+ $ version=x.y.z vagrant provision
+
+That will create a .deb file and a .rpm file.
+
+The deb package can be installed as follows:
+
+::
+
+ $ sudo dpkg -i pgcli*.deb # if dependencies are available.
+
+ or
+
+ $ sudo apt-get install -f pgcli*.deb # if dependencies are not available.
+
+
+The rpm package can be installed as follows:
+
+::
+
+ $ sudo yum install pgcli*.rpm
+
+Running the integration tests
+-----------------------------
+
+Integration tests use `behave package <https://behave.readthedocs.io/>`_ and
+pytest.
+Configuration settings for this package are provided via a ``behave.ini`` file
+in the ``tests`` directory. An example::
+
+ [behave]
+ stderr_capture = false
+
+ [behave.userdata]
+ pg_test_user = dbuser
+ pg_test_host = db.example.com
+ pg_test_port = 30000
+
+First, install the requirements for testing:
+
+::
+ $ pip install -U pip setuptools
+ $ pip install --no-cache-dir ".[sshtunnel]"
+ $ pip install -r requirements-dev.txt
+
+Ensure that the database user has permissions to create and drop test databases
+by checking your ``pg_hba.conf`` file. The default user should be ``postgres``
+at ``localhost``. Make sure the authentication method is set to ``trust``. If
+you made any changes to your ``pg_hba.conf`` make sure to restart the postgres
+service for the changes to take effect.
+
+::
+
+ # ONLY IF YOU MADE CHANGES TO YOUR pg_hba.conf FILE
+ $ sudo service postgresql restart
+
+After that, tests in the ``/pgcli/tests`` directory can be run with:
+(Note that these ``behave`` tests do not currently work when developing on Windows due to pexpect incompatibility.)
+
+::
+
+ # on directory /pgcli/tests
+ $ behave
+
+And on the ``/pgcli`` directory:
+
+::
+
+ # on directory /pgcli
+ $ py.test
+
+To see stdout/stderr, use the following command:
+
+::
+
+ $ behave --no-capture
+
+Troubleshooting the integration tests
+-------------------------------------
+
+- Make sure postgres instance on localhost is running
+- Check your ``pg_hba.conf`` file to verify local connections are enabled
+- Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information.
+- `File an issue <https://github.com/dbcli/pgcli/issues/new>`_.
+
+Coding Style
+------------
+
+``pgcli`` uses `black <https://github.com/ambv/black>`_ to format the source code. Make sure to install black.
+
+Releases
+--------
+
+If you're the person responsible for releasing `pgcli`, `this guide <https://github.com/dbcli/pgcli/blob/main/RELEASES.md>`_ is for you.
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..32d341a
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,6 @@
+FROM python:3.8
+
+COPY . /app
+RUN cd /app && pip install -e .
+
+CMD pgcli
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..83226b7
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,26 @@
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+ list of conditions and the following disclaimer in the documentation and/or
+ other materials provided with the distribution.
+
+* Neither the name of the {organization} nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..1c5f697
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include LICENSE.txt AUTHORS changelog.rst
+recursive-include tests *.py *.txt *.feature *.ini
diff --git a/README.rst b/README.rst
new file mode 100644
index 0000000..56695cc
--- /dev/null
+++ b/README.rst
@@ -0,0 +1,392 @@
+We stand with Ukraine
+---------------------
+
+Ukrainian people are fighting for their country. A lot of civilians, women and children, are suffering. Hundreds were killed and injured, and thousands were displaced.
+
+This is an image from my home town, Kharkiv. This place is right in the old city center.
+
+.. image:: screenshots/kharkiv-destroyed.jpg
+
+Picture by @fomenko_ph (Telegram).
+
+Please consider donating or volunteering.
+
+* https://bank.gov.ua/en/
+* https://savelife.in.ua/en/donate/
+* https://www.comebackalive.in.ua/donate
+* https://www.globalgiving.org/projects/ukraine-crisis-relief-fund/
+* https://www.savethechildren.org/us/where-we-work/ukraine
+* https://www.facebook.com/donate/1137971146948461/
+* https://donate.wck.org/give/393234#!/donation/checkout
+* https://atlantaforukraine.com/
+
+
+A REPL for Postgres
+-------------------
+
+|Build Status| |CodeCov| |PyPI| |netlify|
+
+This is a postgres client that does auto-completion and syntax highlighting.
+
+Home Page: http://pgcli.com
+
+MySQL Equivalent: http://mycli.net
+
+.. image:: screenshots/pgcli.gif
+.. image:: screenshots/image01.png
+
+Quick Start
+-----------
+
+If you already know how to install python packages, then you can simply do:
+
+::
+
+ $ pip install -U pgcli
+
+ or
+
+ $ sudo apt-get install pgcli # Only on Debian based Linux (e.g. Ubuntu, Mint, etc)
+ $ brew install pgcli # Only on macOS
+
+If you don't know how to install python packages, please check the
+`detailed instructions`_.
+
+.. _`detailed instructions`: https://github.com/dbcli/pgcli#detailed-installation-instructions
+
+Usage
+-----
+
+::
+
+ $ pgcli [database_name]
+
+ or
+
+ $ pgcli postgresql://[user[:password]@][netloc][:port][/dbname][?extra=value[&other=other-value]]
+
+Examples:
+
+::
+
+ $ pgcli local_database
+
+ $ pgcli postgres://amjith:pa$$w0rd@example.com:5432/app_db?sslmode=verify-ca&sslrootcert=/myrootcert
+
+For more details:
+
+::
+
+ $ pgcli --help
+
+ Usage: pgcli [OPTIONS] [DBNAME] [USERNAME]
+
+ Options:
+ -h, --host TEXT Host address of the postgres database.
+ -p, --port INTEGER Port number at which the postgres instance is
+ listening.
+ -U, --username TEXT Username to connect to the postgres database.
+ -u, --user TEXT Username to connect to the postgres database.
+ -W, --password Force password prompt.
+ -w, --no-password Never prompt for password.
+ --single-connection Do not use a separate connection for completions.
+ -v, --version Version of pgcli.
+ -d, --dbname TEXT database name to connect to.
+ --pgclirc FILE Location of pgclirc file.
+ -D, --dsn TEXT Use DSN configured into the [alias_dsn] section
+ of pgclirc file.
+ --list-dsn list of DSN configured into the [alias_dsn]
+ section of pgclirc file.
+ --row-limit INTEGER Set threshold for row limit prompt. Use 0 to
+ disable prompt.
+ --less-chatty Skip intro on startup and goodbye on exit.
+ --prompt TEXT Prompt format (Default: "\u@\h:\d> ").
+ --prompt-dsn TEXT Prompt format for connections using DSN aliases
+ (Default: "\u@\h:\d> ").
+ -l, --list list available databases, then exit.
+ --auto-vertical-output Automatically switch to vertical output mode if
+ the result is wider than the terminal width.
+ --warn [all|moderate|off] Warn before running a destructive query.
+ --help Show this message and exit.
+
+``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``).
+
+The SSL-related environment variables are also supported, so if you need to connect a postgres database via ssl connection, you can set set environment like this:
+
+::
+
+ export PGSSLMODE="verify-full"
+ export PGSSLCERT="/your-path-to-certs/client.crt"
+ export PGSSLKEY="/your-path-to-keys/client.key"
+ export PGSSLROOTCERT="/your-path-to-ca/ca.crt"
+ pgcli -h localhost -p 5432 -U username postgres
+
+.. _environment variables: https://www.postgresql.org/docs/current/libpq-envars.html
+
+Features
+--------
+
+The `pgcli` is written using prompt_toolkit_.
+
+* Auto-completes as you type for SQL keywords as well as tables and
+ columns in the database.
+* Syntax highlighting using Pygments.
+* Smart-completion (enabled by default) will suggest context-sensitive
+ completion.
+
+ - ``SELECT * FROM <tab>`` will only show table names.
+ - ``SELECT * FROM users WHERE <tab>`` will only show column names.
+
+* Primitive support for ``psql`` back-slash commands.
+* Pretty prints tabular data.
+
+.. _prompt_toolkit: https://github.com/jonathanslenders/python-prompt-toolkit
+.. _tabulate: https://pypi.python.org/pypi/tabulate
+
+Config
+------
+A config file is automatically created at ``~/.config/pgcli/config`` at first launch.
+See the file itself for a description of all available options.
+
+Contributions:
+--------------
+
+If you're interested in contributing to this project, first of all I would like
+to extend my heartfelt gratitude. I've written a small doc to describe how to
+get this running in a development setup.
+
+https://github.com/dbcli/pgcli/blob/master/DEVELOP.rst
+
+Please feel free to reach out to us if you need help.
+* Amjith, pgcli author: amjith.r@gmail.com, Twitter: `@amjithr <http://twitter.com/amjithr>`_
+* Irina, pgcli maintainer: i.chernyavska@gmail.com, Twitter: `@irinatruong <http://twitter.com/irinatruong>`_
+
+Detailed Installation Instructions:
+-----------------------------------
+
+macOS:
+======
+
+The easiest way to install pgcli is using Homebrew.
+
+::
+
+ $ brew install pgcli
+
+Done!
+
+Alternatively, you can install ``pgcli`` as a python package using a package
+manager called called ``pip``. You will need postgres installed on your system
+for this to work.
+
+In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html.
+
+::
+
+ $ which pip
+
+If it is installed then you can do:
+
+::
+
+ $ pip install pgcli
+
+If that fails due to permission issues, you might need to run the command with
+sudo permissions.
+
+::
+
+ $ sudo pip install pgcli
+
+If pip is not installed check if easy_install is available on the system.
+
+::
+
+ $ which easy_install
+
+ $ sudo easy_install pgcli
+
+Linux:
+======
+
+In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html.
+
+Check if pip is already available in your system.
+
+::
+
+ $ which pip
+
+If it doesn't exist, use your linux package manager to install `pip`. This
+might look something like:
+
+::
+
+ $ sudo apt-get install python-pip # Debian, Ubuntu, Mint etc
+
+ or
+
+ $ sudo yum install python-pip # RHEL, Centos, Fedora etc
+
+``pgcli`` requires python-dev, libpq-dev and libevent-dev packages. You can
+install these via your operating system package manager.
+
+
+::
+
+ $ sudo apt-get install python-dev libpq-dev libevent-dev
+
+ or
+
+ $ sudo yum install python-devel postgresql-devel
+
+Then you can install pgcli:
+
+::
+
+ $ sudo pip install pgcli
+
+
+Docker
+======
+
+Pgcli can be run from within Docker. This can be useful to try pgcli without
+installing it, or any dependencies, system-wide.
+
+To build the image:
+
+::
+
+ $ docker build -t pgcli .
+
+To create a container from the image:
+
+::
+
+ $ docker run --rm -ti pgcli pgcli <ARGS>
+
+To access postgresql databases listening on localhost, make sure to run the
+docker in "host net mode". E.g. to access a database called "foo" on the
+postgresql server running on localhost:5432 (the standard port):
+
+::
+
+ $ docker run --rm -ti --net host pgcli pgcli -h localhost foo
+
+To connect to a locally running instance over a unix socket, bind the socket to
+the docker container:
+
+::
+
+ $ docker run --rm -ti -v /var/run/postgres:/var/run/postgres pgcli pgcli foo
+
+
+IPython
+=======
+
+Pgcli can be run from within `IPython <https://ipython.org>`_ console. When working on a query,
+it may be useful to drop into a pgcli session without leaving the IPython console, iterate on a
+query, then quit pgcli to find the query results in your IPython workspace.
+
+Assuming you have IPython installed:
+
+::
+
+ $ pip install ipython-sql
+
+After that, run ipython and load the ``pgcli.magic`` extension:
+
+::
+
+ $ ipython
+
+ In [1]: %load_ext pgcli.magic
+
+
+Connect to a database and construct a query:
+
+::
+
+ In [2]: %pgcli postgres://someone@localhost:5432/world
+ Connected: someone@world
+ someone@localhost:world> select * from city c where countrycode = 'USA' and population > 1000000;
+ +------+--------------+---------------+--------------+--------------+
+ | id | name | countrycode | district | population |
+ |------+--------------+---------------+--------------+--------------|
+ | 3793 | New York | USA | New York | 8008278 |
+ | 3794 | Los Angeles | USA | California | 3694820 |
+ | 3795 | Chicago | USA | Illinois | 2896016 |
+ | 3796 | Houston | USA | Texas | 1953631 |
+ | 3797 | Philadelphia | USA | Pennsylvania | 1517550 |
+ | 3798 | Phoenix | USA | Arizona | 1321045 |
+ | 3799 | San Diego | USA | California | 1223400 |
+ | 3800 | Dallas | USA | Texas | 1188580 |
+ | 3801 | San Antonio | USA | Texas | 1144646 |
+ +------+--------------+---------------+--------------+--------------+
+ SELECT 9
+ Time: 0.003s
+
+
+Exit out of pgcli session with ``Ctrl + D`` and find the query results:
+
+::
+
+ someone@localhost:world>
+ Goodbye!
+ 9 rows affected.
+ Out[2]:
+ [(3793, u'New York', u'USA', u'New York', 8008278),
+ (3794, u'Los Angeles', u'USA', u'California', 3694820),
+ (3795, u'Chicago', u'USA', u'Illinois', 2896016),
+ (3796, u'Houston', u'USA', u'Texas', 1953631),
+ (3797, u'Philadelphia', u'USA', u'Pennsylvania', 1517550),
+ (3798, u'Phoenix', u'USA', u'Arizona', 1321045),
+ (3799, u'San Diego', u'USA', u'California', 1223400),
+ (3800, u'Dallas', u'USA', u'Texas', 1188580),
+ (3801, u'San Antonio', u'USA', u'Texas', 1144646)]
+
+The results are available in special local variable ``_``, and can be assigned to a variable of your
+choice:
+
+::
+
+ In [3]: my_result = _
+
+Pgcli dropped support for Python<3.8 as of 4.0.0. If you need it, install ``pgcli <= 4.0.0``.
+
+Thanks:
+-------
+
+A special thanks to `Jonathan Slenders <https://twitter.com/jonathan_s>`_ for
+creating `Python Prompt Toolkit <http://github.com/jonathanslenders/python-prompt-toolkit>`_,
+which is quite literally the backbone library, that made this app possible.
+Jonathan has also provided valuable feedback and support during the development
+of this app.
+
+`Click <http://click.pocoo.org/>`_ is used for command line option parsing
+and printing error messages.
+
+Thanks to `psycopg <https://www.psycopg.org/>`_ for providing a rock solid
+interface to Postgres database.
+
+Thanks to all the beta testers and contributors for your time and patience. :)
+
+
+.. |Build Status| image:: https://github.com/dbcli/pgcli/actions/workflows/ci.yml/badge.svg?branch=main
+ :target: https://github.com/dbcli/pgcli/actions/workflows/ci.yml
+
+.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg
+ :target: https://codecov.io/gh/dbcli/pgcli
+ :alt: Code coverage report
+
+.. |Landscape| image:: https://landscape.io/github/dbcli/pgcli/master/landscape.svg?style=flat
+ :target: https://landscape.io/github/dbcli/pgcli/master
+ :alt: Code Health
+
+.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg
+ :target: https://pypi.python.org/pypi/pgcli/
+ :alt: Latest Version
+
+.. |netlify| image:: https://api.netlify.com/api/v1/badges/3a0a14dd-776d-445d-804c-3dd74fe31c4e/deploy-status
+ :target: https://app.netlify.com/sites/pgcli/deploys
+ :alt: Netlify
diff --git a/RELEASES.md b/RELEASES.md
new file mode 100644
index 0000000..526c260
--- /dev/null
+++ b/RELEASES.md
@@ -0,0 +1,24 @@
+Releasing pgcli
+---------------
+
+You have been made the maintainer of `pgcli`? Congratulations! We have a release script to help you:
+
+```sh
+> python release.py --help
+Usage: release.py [options]
+
+Options:
+ -h, --help show this help message and exit
+ -c, --confirm-steps Confirm every step. If the step is not confirmed, it
+ will be skipped.
+ -d, --dry-run Print out, but not actually run any steps.
+```
+
+The script can be run with `-c` to confirm or skip steps. There's also a `--dry-run` option that only prints out the steps.
+
+To release a new version of the package:
+
+* Create and merge a PR to bump the version in the changelog ([example PR](https://github.com/dbcli/pgcli/pull/1325)).
+* Pull `main` and bump the version number inside `pgcli/__init__.py`. Do not check in - the release script will do that.
+* Make sure you have the dev requirements installed: `pip install -r requirements-dev.txt -U --upgrade-strategy only-if-needed`.
+* Finally, run the release script: `python release.py`.
diff --git a/TODO b/TODO
new file mode 100644
index 0000000..2173545
--- /dev/null
+++ b/TODO
@@ -0,0 +1,12 @@
+# vi: ft=vimwiki
+* [ ] Add coverage.
+* [ ] Refactor to sqlcompletion to consume the text from left to right and use a state machine to suggest cols or tables instead of relying on hacks.
+* [ ] Add a few more special commands. (\l pattern, \dp, \ds, \dy, \z etc)
+* [ ] Refactor pgspecial.py to a class.
+* [ ] Show/hide docs for a statement using a keybinding.
+* [ ] Check how to add the name of the table before printing the table.
+* [ ] Add a new trigger for M-/ that does naive completion.
+* [ ] New Feature List - Write the current version to config file. At launch if the version has changed, display the changelog between the two versions.
+* [ ] Add a test for 'select * from custom.abc where custom.abc.' should suggest columns from abc.
+* [ ] pgexecute columns(), tables() etc can be just cursors instead of fetchall()
+* [ ] Add colorschemes in config file.
diff --git a/Vagrantfile b/Vagrantfile
new file mode 100644
index 0000000..297e70a
--- /dev/null
+++ b/Vagrantfile
@@ -0,0 +1,138 @@
+# -*- mode: ruby -*-
+# vi: set ft=ruby :
+#
+#
+
+Vagrant.configure(2) do |config|
+
+ config.vm.synced_folder ".", "/pgcli"
+
+ pgcli_version = ENV['version']
+ pgcli_description = "Postgres CLI with autocompletion and syntax highlighting"
+
+ config.vm.define "debian" do |debian|
+ debian.vm.box = "bento/debian-10.8"
+ debian.vm.provision "shell", inline: <<-SHELL
+ echo "-> Building DEB on `lsb_release -d`"
+ sudo apt-get update
+ sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems
+ sudo apt install -y python3-pip
+ sudo pip3 install --no-cache-dir virtualenv virtualenv-tools3
+ sudo apt-get install -y ruby-dev
+ sudo apt-get install -y git
+ sudo apt-get install -y rpm librpmbuild8
+
+ sudo gem install fpm
+
+ echo "-> Cleaning up old workspace"
+ sudo rm -rf build
+ mkdir -p build/usr/share
+ virtualenv build/usr/share/pgcli
+ build/usr/share/pgcli/bin/pip install /pgcli
+
+ echo "-> Cleaning Virtualenv"
+ cd build/usr/share/pgcli
+ virtualenv-tools --update-path /usr/share/pgcli > /dev/null
+ cd /home/vagrant/
+
+ echo "-> Removing compiled files"
+ find build -iname '*.pyc' -delete
+ find build -iname '*.pyo' -delete
+
+ echo "-> Creating PgCLI deb"
+ sudo fpm -t deb -s dir -C build -n pgcli -v #{pgcli_version} \
+ -a all \
+ -d libpq-dev \
+ -d python-dev \
+ -p /pgcli/ \
+ --after-install /pgcli/post-install \
+ --after-remove /pgcli/post-remove \
+ --url https://github.com/dbcli/pgcli \
+ --description "#{pgcli_description}" \
+ --license 'BSD'
+
+ SHELL
+ end
+
+
+# This is considerably more messy than the debian section. I had to go off-standard to update
+# some packages to get this to work.
+
+ config.vm.define "centos" do |centos|
+
+ centos.vm.box = "bento/centos-7.9"
+ centos.vm.box_version = "202012.21.0"
+ centos.vm.provision "shell", inline: <<-SHELL
+ #!/bin/bash
+ echo "-> Building RPM on `hostnamectl | grep "Operating System"`"
+ export PATH=/usr/local/rvm/gems/ruby-2.6.3/bin:/usr/local/rvm/gems/ruby-2.6.3@global/bin:/usr/local/rvm/rubies/ruby-2.6.3/bin:/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/local/rvm/bin:/root/bin
+ echo "PATH -> " $PATH
+
+#####
+### get base updates
+
+ sudo yum install -y rpm-build gcc postgresql-devel python-devel python3-pip git python3-devel
+
+######
+### install FPM, which we need to install to get an up-to-date version of ruby, which we need for git
+
+ echo "-> Get FPM installed"
+ # import the necessary GPG keys
+ gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB
+ sudo gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB
+ # install RVM
+ sudo curl -sSL https://get.rvm.io | sudo bash -s stable
+ sudo usermod -aG rvm vagrant
+ sudo usermod -aG rvm root
+ sudo /usr/local/rvm/bin/rvm alias create default 2.6.3
+ source /etc/profile.d/rvm.sh
+
+ # install a newer version of ruby. centos7 only comes with ruby2.0.0, which isn't good enough for git.
+ sudo yum install -y ruby-devel
+ sudo /usr/local/rvm/bin/rvm install 2.6.3
+
+ #
+ # yes,this gives an error about generating doc but we don't need the doc.
+
+ /usr/local/rvm/gems/ruby-2.6.3/wrappers/gem install fpm
+
+######
+
+ sudo pip3 install virtualenv virtualenv-tools3
+ echo "-> Cleaning up old workspace"
+ rm -rf build
+ mkdir -p build/usr/share
+ virtualenv build/usr/share/pgcli
+ build/usr/share/pgcli/bin/pip install /pgcli
+
+ echo "-> Cleaning Virtualenv"
+ cd build/usr/share/pgcli
+ virtualenv-tools --update-path /usr/share/pgcli > /dev/null
+ cd /home/vagrant/
+
+ echo "-> Removing compiled files"
+ find build -iname '*.pyc' -delete
+ find build -iname '*.pyo' -delete
+
+ cd /home/vagrant
+ echo "-> Creating PgCLI RPM"
+ /usr/local/rvm/gems/ruby-2.6.3/gems/fpm-1.12.0/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \
+ -a all \
+ -d postgresql-devel \
+ -d python-devel \
+ -p /pgcli/ \
+ --after-install /pgcli/post-install \
+ --after-remove /pgcli/post-remove \
+ --url https://github.com/dbcli/pgcli \
+ --description "#{pgcli_description}" \
+ --license 'BSD'
+
+
+ SHELL
+
+
+ end
+
+
+end
+
diff --git a/changelog.rst b/changelog.rst
new file mode 100644
index 0000000..7d08839
--- /dev/null
+++ b/changelog.rst
@@ -0,0 +1,1203 @@
+==================
+4.0.1 (2023-11-30)
+==================
+
+Internal:
+---------
+* Allow stable version of pendulum.
+
+==================
+4.0.0 (2023-11-27)
+==================
+
+Features:
+---------
+
+* Ask for confirmation when quitting cli while a transaction is ongoing.
+* New `destructive_statements_require_transaction` config option to refuse to execute a
+ destructive SQL statement if outside a transaction. This option is off by default.
+* Changed the `destructive_warning` config to be a list of commands that are considered
+ destructive. This would allow you to be warned on `create`, `grant`, or `insert` queries.
+* Destructive warnings will now include the alias dsn connection string name if provided (-D option).
+* pgcli.magic will now work with connection URLs that use TLS client certificates for authentication
+* Have config option to retry queries on operational errors like connections being lost.
+ Also prevents getting stuck in a retry loop.
+* Config option to not restart connection when cancelling a `destructive_warning` query. By default,
+ it will now not restart.
+* Config option to always run with a single connection.
+* Add comment explaining default LESS environment variable behavior and change example pager setting.
+* Added `\echo` & `\qecho` special commands. ([issue 1335](https://github.com/dbcli/pgcli/issues/1335)).
+
+Bug fixes:
+----------
+
+* Fix `\ev` not producing a correctly quoted "schema"."view"
+* Fix 'invalid connection option "dsn"' ([issue 1373](https://github.com/dbcli/pgcli/issues/1373)).
+* Fix explain mode when used with `expand`, `auto_expand`, or `--explain-vertical-output` ([issue 1393](https://github.com/dbcli/pgcli/issues/1393)).
+* Fix sql-insert format emits NULL as 'None' ([issue 1408](https://github.com/dbcli/pgcli/issues/1408)).
+* Improve check for prompt-toolkit 3.0.6 ([issue 1416](https://github.com/dbcli/pgcli/issues/1416)).
+* Allow specifying an `alias_map_file` in the config that will use
+ predetermined table aliases instead of generating aliases programmatically on
+ the fly
+* Fixed SQL error when there is a comment on the first line: ([issue 1403](https://github.com/dbcli/pgcli/issues/1403))
+* Fix wrong usage of prompt instead of confirm when confirm execution of destructive query
+
+Internal:
+---------
+
+* Drop support for Python 3.7 and add 3.12.
+
+3.5.0 (2022/09/15):
+===================
+
+Features:
+---------
+
+* New formatter is added to export query result to sql format (such as sql-insert, sql-update) like mycli.
+
+Bug fixes:
+----------
+
+* Fix exception when retrieving password from keyring ([issue 1338](https://github.com/dbcli/pgcli/issues/1338)).
+* Fix using comments with special commands ([issue 1362](https://github.com/dbcli/pgcli/issues/1362)).
+* Small improvements to the Windows developer experience
+* Fix submitting queries in safe multiline mode ([1360](https://github.com/dbcli/pgcli/issues/1360)).
+
+Internal:
+---------
+
+* Port to psycopg3 (https://github.com/psycopg/psycopg).
+* Fix typos
+
+3.4.1 (2022/03/19)
+==================
+
+Bug fixes:
+----------
+
+* Fix the bug with Redshift not displaying word count in status ([related issue](https://github.com/dbcli/pgcli/issues/1320)).
+* Show the error status for CSV output format.
+
+
+3.4.0 (2022/02/21)
+==================
+
+Features:
+---------
+
+* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)).
+
+3.3.1 (2022/01/18)
+==================
+
+Bug fixes:
+----------
+
+* Prompt for password when -W is provided even if there is a password in keychain. Fixes #1307.
+* Upgrade cli_helpers to 2.2.1
+
+3.3.0 (2022/01/11)
+==================
+
+Features:
+---------
+
+* Add `max_field_width` setting to config, to enable more control over field truncation ([related issue](https://github.com/dbcli/pgcli/issues/1250)).
+* Re-run last query via bare `\watch`. (Thanks: `Saif Hakim`_)
+
+Bug fixes:
+----------
+
+* Pin the version of pygments to prevent breaking change
+
+3.2.0
+=====
+
+Release date: 2021/08/23
+
+Features:
+---------
+
+* Consider `update` queries destructive and issue a warning. Change
+ `destructive_warning` setting to `all|moderate|off`, vs `true|false`. (#1239)
+* Skip initial comment in .pg_session even if it doesn't start with '#'
+* Include functions from schemas in search_path. (`Amjith Ramanujam`_)
+* Easy way to show explain output under F5
+
+Bug fixes:
+----------
+
+* Fix issue where `syntax_style` config value would not have any effect. (#1212)
+* Fix crash because of not found `InputMode.REPLACE_SINGLE` with prompt-toolkit < 3.0.6
+* Fix comments being lost in config when saving a named query. (#1240)
+* Fix IPython magic for ipython-sql >= 0.4.0
+* Fix pager not being used when output format is set to csv. (#1238)
+* Add function literals random, generate_series, generate_subscripts
+* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly
+* Fix pgcli crashing with virtual `pgbouncer` database. (#1093)
+
+3.1.0
+=====
+
+Features:
+---------
+
+* Make the output more compact by removing the empty newline. (Thanks: `laixintao`_)
+* Add support for using [pspg](https://github.com/okbob/pspg) as a pager (#1102)
+* Update python version in Dockerfile
+* Support setting color for null, string, number, keyword value
+* Support Prompt Toolkit 2
+* Support sqlparse 0.4.x
+* Update functions, datatypes literals for auto-suggestion field
+* Add suggestion for schema in function auto-complete
+
+Bug fixes:
+----------
+
+* Minor typo fixes in `pgclirc`. (Thanks: `anthonydb`_)
+* Fix for list index out of range when executing commands from a file (#1193). (Thanks: `Irina Truong`_)
+* Move from `humanize` to `pendulum` for displaying query durations (#1015)
+* More explicit error message when connecting using DSN alias and it is not found.
+
+3.0.0
+=====
+
+Features:
+---------
+
+* Add `__main__.py` file to execute pgcli as a package directly (#1123).
+* Add support for ANSI escape sequences for coloring the prompt (#1122).
+* Add support for partitioned tables (relkind "p").
+* Add support for `pg_service.conf` files
+* Add config option show_bottom_toolbar.
+
+Bug fixes:
+----------
+
+* Fix warning raised for using `is not` to compare string literal
+* Close open connection in completion_refresher thread
+
+Internal:
+---------
+
+* Drop Python2.7, 3.4, 3.5 support. (Thanks: `laixintao`_)
+* Support Python3.8. (Thanks: `laixintao`_)
+* Fix dead link in development guide. (Thanks: `BrownShibaDog`_)
+* Upgrade python-prompt-toolkit to v3.0. (Thanks: `laixintao`_)
+
+
+2.2.0:
+======
+
+Features:
+---------
+
+* Add `\\G` as a terminator to sql statements that will show the results in expanded mode. This feature is copied from mycli. (Thanks: `Amjith Ramanujam`_)
+* Removed limit prompt and added automatic row limit on queries with no LIMIT clause (#1079) (Thanks: `Sebastian Janko`_)
+* Function argument completions now take account of table aliases (#1048). (Thanks: `Owen Stephens`_)
+
+Bug fixes:
+----------
+
+* Error connecting to PostgreSQL 12beta1 (#1058). (Thanks: `Irina Truong`_ and `Amjith Ramanujam`_)
+* Empty query caused error message (#1019) (Thanks: `Sebastian Janko`_)
+* History navigation bindings in multiline queries (#1004) (Thanks: `Pedro Ferrari`_)
+* Can't connect to pgbouncer database (#1093). (Thanks: `Irina Truong`_)
+* Fix broken multi-line history search (#1031). (Thanks: `Owen Stephens`_)
+* Fix slow typing/movement when multi-line query ends in a semicolon (#994). (Thanks: `Owen Stephens`_)
+* Fix for PQconninfo not available in libpq < 9.3 (#1110). (Thanks: `Irina Truong`_)
+
+Internal:
+---------
+
+* Add optional but default squash merge request to PULL_REQUEST_TEMPLATE
+
+2.1.1
+=====
+
+Bug fixes:
+----------
+* Escape switches to VI navigation mode when not canceling completion popup. (Thanks: `Nathan Verzemnieks`_)
+* Allow application_name to be overridden. (Thanks: `raylu`_)
+* Fix for "no attribute KeyringLocked" (#1040). (Thanks: `Irina Truong`_)
+* Pgcli no longer works with password containing spaces (#1043). (Thanks: `Irina Truong`_)
+* Load keyring only when keyring is enabled in the config file (#1041). (Thanks: `Zhaolong Zhu`_)
+* No longer depend on sqlparse as being less than 0.3.0 with the release of sqlparse 0.3.0. (Thanks: `VVelox`_)
+* Fix the broken support for pgservice . (Thanks: `Xavier Francisco`_)
+* Connecting using socket is broken in current master. (#1053). (Thanks: `Irina Truong`_)
+* Allow usage of newer versions of psycopg2 (Thanks: `Telmo "Trooper"`_)
+* Update README in alignment with the usage of newer versions of psycopg2 (Thanks: `Alexander Zawadzki`_)
+
+Internal:
+---------
+
+* Add python 3.7 to travis build matrix. (Thanks: `Irina Truong`_)
+* Apply `black` to code. (Thanks: `Irina Truong`_)
+
+2.1.0
+=====
+
+Features:
+---------
+
+* Keybindings for closing the autocomplete list. (Thanks: `easteregg`_)
+* Reconnect automatically when server closes connection. (Thanks: `Scott Brenstuhl`_)
+
+Bug fixes:
+----------
+* Avoid error message on the server side if hstore extension is not installed in the current database (#991). (Thanks: `Marcin Cieślak`_)
+* All pexpect submodules have been moved into the pexpect package as of version 3.0. Use pexpect.TIMEOUT (Thanks: `Marcin Cieślak`_)
+* Resizing pgcli terminal kills the connection to postgres in python 2.7 (Thanks: `Amjith Ramanujam`_)
+* Fix crash retrieving server version with ``--single-connection``. (Thanks: `Irina Truong`_)
+* Cannot quit application without reconnecting to database (#1014). (Thanks: `Irina Truong`_)
+* Password authentication failed for user "postgres" when using non-default password (#1020). (Thanks: `Irina Truong`_)
+
+Internal:
+---------
+
+* (Fixup) Clean up and add behave logging. (Thanks: `Marcin Cieślak`_, `Dick Marinus`_)
+* Override VISUAL environment variable for behave tests. (Thanks: `Marcin Cieślak`_)
+* Remove build dir before running sdist, remove stray files from wheel distribution. (Thanks: `Dick Marinus`_)
+* Fix unit tests, unhashable formatted text since new python prompttoolkit version. (Thanks: `Dick Marinus`_)
+
+2.0.2:
+======
+
+Features:
+---------
+
+* Allows passing the ``-u`` flag to specify a username. (Thanks: `Ignacio Campabadal`_)
+* Fix for lag in v2 (#979). (Thanks: `Irina Truong`_)
+* Support for multihost connection string that is convenient if you have postgres cluster. (Thanks: `Mikhail Elovskikh`_)
+
+Internal:
+---------
+
+* Added tests for special command completion. (Thanks: `Amjith Ramanujam`_)
+
+2.0.1:
+======
+
+Bug fixes:
+----------
+
+* Tab press on an empty line increases the indentation instead of triggering
+ the auto-complete pop-up. (Thanks: `Artur Balabanov`_)
+* Fix for loading/saving named queries from provided config file (#938). (Thanks: `Daniel Egger`_)
+* Set default port in `connect_uri` when none is given. (Thanks: `Daniel Egger`_)
+* Fix for error listing databases (#951). (Thanks: `Irina Truong`_)
+* Enable Ctrl-Z to suspend the app (Thanks: `Amjith Ramanujam`_).
+* Fix StopIteration exception raised at runtime for Python 3.7 (Thanks: `Amjith Ramanujam`_).
+
+Internal:
+---------
+
+* Clean up and add behave logging. (Thanks: `Dick Marinus`_)
+* Require prompt_toolkit>=2.0.6. (Thanks: `Dick Marinus`_)
+* Improve development guide. (Thanks: `Ignacio Campabadal`_)
+
+2.0.0:
+======
+
+* Update to ``prompt-toolkit`` 2.0. (Thanks: `Jonathan Slenders`_, `Dick Marinus`_, `Irina Truong`_)
+
+1.11.0
+======
+
+Features:
+---------
+
+* Respect `\pset pager on` and use pager when output is longer than terminal height (Thanks: `Max Rothman`_)
+
+1.10.3
+======
+
+Bug fixes:
+----------
+
+* Adapt the query used to get functions metadata to PG11 (#919). (Thanks: `Lele Gaifax`_).
+* Fix for error retrieving version in Redshift (#922). (Thanks: `Irina Truong`_)
+* Fix for keyring not disabled properly (#920). (Thanks: `Irina Truong`_)
+
+1.10.2
+======
+
+Features:
+---------
+
+* Make `keyring` optional (Thanks: `Dick Marinus`_)
+
+1.10.1
+======
+
+Bug fixes:
+----------
+
+* Fix for missing keyring. (Thanks: `Kenny Do`_)
+* Fix for "-l" Flag Throws Error (#909). (Thanks: `Irina Truong`_)
+
+1.10.0
+======
+
+Features:
+---------
+* Add quit commands to the completion menu. (Thanks: `Jason Ribeiro`_)
+* Add table formats to ``\T`` completion. (Thanks: `Jason Ribeiro`_)
+* Support `\\ev``, ``\ef`` (#754). (Thanks: `Catherine Devlin`_)
+* Add ``application_name`` to help identify pgcli connection to database (issue #868) (Thanks: `François Pietka`_)
+* Add `--user` option, duplicate of `--username`, the same cli option like `psql` (Thanks: `Alexandr Korsak`_)
+
+Internal changes:
+-----------------
+
+* Mark tests requiring a running database server as dbtest (Thanks: `Dick Marinus`_)
+* Add an is_special command flag to MetaQuery (Thanks: `Rishi Ramraj`_)
+* Ported Destructive Warning from mycli.
+* Refactor Destructive Warning behave tests (Thanks: `Dick Marinus`_)
+
+Bug Fixes:
+----------
+* Disable pager when using \watch (#837). (Thanks: `Jason Ribeiro`_)
+* Don't offer to reconnect when we can't change a param in realtime (#807). (Thanks: `Amjith Ramanujam`_ and `Saif Hakim`_)
+* Make keyring optional. (Thanks: `Dick Marinus`_)
+* Fix ipython magic connection (#891). (Thanks: `Irina Truong`_)
+* Fix not enough values to unpack. (Thanks: `Matthieu Guilbert`_)
+* Fix unbound local error when destructive_warning is false. (Thanks: `Matthieu Guilbert`_)
+* Render tab characters as 4 spaces instead of `^I`. (Thanks: `Artur Balabanov`_)
+
+1.9.1:
+======
+
+Features:
+---------
+
+* Change ``\h`` format string in prompt to only return the first part of the hostname,
+ up to the first '.' character. Add ``\H`` that returns the entire hostname (#858).
+ (Thanks: `Andrew Kuchling`_)
+* Add Color of table by parameter. The color of table is function of syntax style
+
+Internal changes:
+-----------------
+
+* Add tests, AUTHORS and changelog.rst to release. (Thanks: `Dick Marinus`_)
+
+Bug Fixes:
+----------
+* Fix broken pgcli --list command line option (#850). (Thanks: `Dmitry B`_)
+
+1.9.0
+=====
+
+Features:
+---------
+
+* manage pager by \pset pager and add enable_pager to the config file (Thanks: `Frederic Aoustin`_).
+* Add support for `\T` command to change format output. (Thanks: `Frederic Aoustin`_).
+* Add option list-dsn (Thanks: `Frederic Aoustin`_).
+
+
+Internal changes:
+-----------------
+
+* Removed support for Python 3.3. (Thanks: `Irina Truong`_)
+
+1.8.2
+=====
+
+Features:
+---------
+
+* Use other prompt (prompt_dsn) when connecting using --dsn parameter. (Thanks: `Marcin Sztolcman`_)
+* Include username into password prompt. (Thanks: `Bojan Delić`_)
+
+Internal changes:
+-----------------
+* Use temporary dir as config location in tests. (Thanks: `Dmitry B`_)
+* Fix errors in the ``tee`` test (#795 and #797). (Thanks: `Irina Truong`_)
+* Increase timeout for quitting pgcli. (Thanks: `Dick Marinus`_)
+
+Bug Fixes:
+----------
+* Do NOT quote the database names in the completion menu (Thanks: `Amjith Ramanujam`_)
+* Fix error in ``unix_socket_directories`` (#805). (Thanks: `Irina Truong`_)
+* Fix the --list command line option tries to connect to 'personal' DB (#816). (Thanks: `Isank`_)
+
+1.8.1
+=====
+
+Internal changes:
+-----------------
+* Remove shebang and git execute permission from pgcli/main.py. (Thanks: `Dick Marinus`_)
+* Require cli_helpers 0.2.3 (fix #791). (Thanks: `Dick Marinus`_)
+
+1.8.0
+=====
+
+Features:
+---------
+
+* Add fish-style auto-suggestion from history. (Thanks: `Amjith Ramanujam`_)
+* Improved formatting of arrays in output (Thanks: `Joakim Koljonen`_)
+* Don't quote identifiers that are non-reserved keywords. (Thanks: `Joakim Koljonen`_)
+* Remove the ``...`` in the continuation prompt and use empty space instead. (Thanks: `Amjith Ramanujam`_)
+* Add \conninfo and handle more parameters with \c (issue #716) (Thanks: `François Pietka`_)
+
+Internal changes:
+-----------------
+* Preliminary work for a future change in outputting results that uses less memory. (Thanks: `Dick Marinus`_)
+* Remove import workaround for OrderedDict, required for python < 2.7. (Thanks: `Andrew Speed`_)
+* Use less memory when formatting results for display (Thanks: `Dick Marinus`_).
+* Port auto_vertical feature test from mycli to pgcli. (Thanks: `Dick Marinus`_)
+* Drop wcwidth dependency (Thanks: `Dick Marinus`_)
+
+Bug Fixes:
+----------
+
+* Fix the way we get host when using DSN (issue #765) (Thanks: `François Pietka`_)
+* Add missing keyword COLUMN after DROP (issue #769) (Thanks: `François Pietka`_)
+* Don't include arguments in function suggestions for backslash commands (Thanks: `Joakim Koljonen`_)
+* Optionally use POSTGRES_USER, POSTGRES_HOST POSTGRES_PASSWORD from environment (Thanks: `Dick Marinus`_)
+
+1.7.0
+=====
+
+* Refresh completions after `COMMIT` or `ROLLBACK`. (Thanks: `Irina Truong`_)
+* Fixed DSN aliases not being read from custom pgclirc (issue #717). (Thanks: `Irina Truong`_).
+* Use dbcli's Homebrew tap for installing pgcli on macOS (issue #718) (Thanks: `Thomas Roten`_).
+* Only set `LESS` environment variable if it's unset. (Thanks: `Irina Truong`_)
+* Quote schema in `SET SCHEMA` statement (issue #469) (Thanks: `Irina Truong`_)
+* Include arguments in function suggestions (Thanks: `Joakim Koljonen`_)
+* Use CLI Helpers for pretty printing query results (Thanks: `Thomas Roten`_).
+* Skip serial columns when expanding * for `INSERT INTO foo(*` (Thanks: `Joakim Koljonen`_).
+* Command line option to list databases (issue #206) (Thanks: `François Pietka`_)
+
+1.6.0
+=====
+
+Features:
+---------
+* Add time option for prompt (Thanks: `Gustavo Castro`_)
+* Suggest objects from all schemas (not just those in search_path) (Thanks: `Joakim Koljonen`_)
+* Casing for column headers (Thanks: `Joakim Koljonen`_)
+* Allow configurable character to be used for multi-line query continuations. (Thanks: `Owen Stephens`_)
+* Completions after ORDER BY and DISTINCT now take account of table aliases. (Thanks: `Owen Stephens`_)
+* Narrow keyword candidates based on previous keyword. (Thanks: `Étienne Bersac`_)
+* Opening an external editor will edit the last-run query. (Thanks: `Thomas Roten`_)
+* Support query options in postgres URIs such as ?sslcert=foo.pem (Thanks: `Alexander Schmolck`_)
+
+Bug fixes:
+----------
+* Fixed external editor bug (issue #668). (Thanks: `Irina Truong`_).
+* Standardize command line option names. (Thanks: `Russell Davies`_)
+* Improve handling of ``lock_not_available`` error (issue #700). (Thanks: `Jackson Popkin <https://github.com/jdpopkin>`_)
+* Fixed user option precedence (issue #697). (Thanks: `Irina Truong`_).
+
+Internal changes:
+-----------------
+* Run pep8 checks in travis (Thanks: `Irina Truong`_).
+* Add pager wrapper for behave tests (Thanks: `Dick Marinus`_).
+* Behave quit pgcli nicely (Thanks: `Dick Marinus`_).
+* Behave test source command (Thanks: `Dick Marinus`_).
+* Behave fix clean up. (Thanks: `Dick Marinus`_).
+* Test using behave the tee command (Thanks: `Dick Marinus`_).
+* Behave remove boiler plate code (Thanks: `Dick Marinus`_).
+* Behave fix pgspecial update (Thanks: `Dick Marinus`_).
+* Add behave to tox (Thanks: `Dick Marinus`_).
+
+1.5.1
+=====
+
+Features:
+---------
+* Better suggestions when editing functions (Thanks: `Joakim Koljonen`_)
+* Command line option for ``--less-chatty``. (Thanks: `tk`_)
+* Added ``MATERIALIZED VIEW`` keywords. (Thanks: `Joakim Koljonen`_).
+
+Bug fixes:
+----------
+
+* Support unicode chars in expanded mode. (Thanks: `Amjith Ramanujam`_)
+* Fixed "set_session cannot be used inside a transaction" when using dsn. (Thanks: `Irina Truong`_).
+
+1.5.0
+=====
+
+Features:
+---------
+* Upgraded pgspecial to 1.7.0. (See `pgspecial changelog <https://github.com/dbcli/pgspecial/blob/master/changelog.rst>`_ for list of fixes)
+* Add a new config setting to allow expandable mode (Thanks: `Jonathan Boudreau <https://github.com/AGhost-7>`_)
+* Make pgcli prompt width short when the prompt is too long (Thanks: `Jonathan Virga <https://github.com/jnth>`_)
+* Add additional completion for ``ALTER`` keyword (Thanks: `Darik Gamble`_)
+* Make the menu size configurable. (Thanks `Darik Gamble`_)
+
+Bug Fixes:
+----------
+* Handle more connection failure cases. (Thanks: `Amjith Ramanujam`_)
+* Fix the connection failure issues with latest psycopg2. (Thanks: `Amjith Ramanujam`_)
+
+Internal Changes:
+-----------------
+
+* Add testing for Python 3.5 and 3.6. (Thanks: `Amjith Ramanujam`_)
+
+1.4.0
+=====
+
+Features:
+---------
+
+* Search table suggestions using initialisms. (Thanks: `Joakim Koljonen`_).
+* Support for table-qualifying column suggestions. (Thanks: `Joakim Koljonen`_).
+* Display transaction status in the toolbar. (Thanks: `Joakim Koljonen`_).
+* Display vi mode in the toolbar. (Thanks: `Joakim Koljonen`_).
+* Added --prompt option. (Thanks: `Irina Truong`_).
+
+Bug Fixes:
+----------
+
+* Fix scoping for columns from CTEs. (Thanks: `Joakim Koljonen`_)
+* Fix crash after `with`. (Thanks: `Joakim Koljonen`_).
+* Fix issue #603 (`\i` raises a TypeError). (Thanks: `Lele Gaifax`_).
+
+
+Internal Changes:
+-----------------
+
+* Set default data_formatting to nothing. (Thanks: `Amjith Ramanujam`_).
+* Increased minimum prompt_toolkit requirement to 1.0.9. (Thanks: `Irina Truong`_).
+
+
+1.3.1
+=====
+
+Bug Fixes:
+----------
+* Fix a crashing bug due to sqlparse upgrade. (Thanks: `Darik Gamble`_)
+
+
+1.3.0
+=====
+
+IMPORTANT: Python 2.6 is not officially supported anymore.
+
+Features:
+---------
+* Add delimiters to displayed numbers. This can be configured via the config file. (Thanks: `Sergii`_).
+* Fix broken 'SHOW ALL' in redshift. (Thanks: `Manuel Barkhau`_).
+* Support configuring keyword casing preferences. (Thanks: `Darik Gamble`_).
+* Add a new multi_line_mode option in config file. The values can be `psql` or `safe`. (Thanks: `Joakim Koljonen`_)
+ Setting ``multi_line_mode = safe`` will make sure that a query will only be executed when Alt+Enter is pressed.
+
+Bug Fixes:
+----------
+* Fix crash bug with leading parenthesis. (Thanks: `Joakim Koljonen`_).
+* Remove cumulative addition of timing data. (Thanks: `Amjith Ramanujam`_).
+* Handle unrecognized keywords gracefully. (Thanks: `Darik Gamble`_)
+* Use raw strings in regex specifiers. This preemptively fixes a crash in Python 3.6. (Thanks `Lele Gaifax`_)
+
+Internal Changes:
+-----------------
+* Set sqlparse version dependency to >0.2.0, <0.3.0. (Thanks: `Amjith Ramanujam`_).
+* XDG_CONFIG_HOME support for config file location. (Thanks: `Fabien Meghazi`_).
+* Remove Python 2.6 from travis test suite. (Thanks: `Amjith Ramanujam`_)
+
+1.2.0
+=====
+
+Features:
+---------
+
+* Add more specifiers to pgcli prompt. (Thanks: `Julien Rouhaud`_).
+ ``\p`` for port info ``\#`` for super user and ``\i`` for pid.
+* Add `\watch` command to periodically execute a command. (Thanks: `Stuart Quin`_).
+ ``> SELECT * FROM django_migrations; \watch 1 /* Runs the command every second */``
+* Add command-line option --single-connection to prevent pgcli from using multiple connections. (Thanks: `Joakim Koljonen`_).
+* Add priority to the suggestions to sort based on relevance. (Thanks: `Joakim Koljonen`_).
+* Configurable null format via the config file. (Thanks: `Adrian Dries`_).
+* Add support for CTE aware auto-completion. (Thanks: `Darik Gamble`_).
+* Add host and user information to default pgcli prompt. (Thanks: `Lim H`_).
+* Better scoping for tables in insert statements to improve suggestions. (Thanks: `Joakim Koljonen`_).
+
+Bug Fixes:
+----------
+
+* Do not install setproctitle on cygwin. (Thanks: `Janus Troelsen`_).
+* Work around sqlparse crashing after AS keyword. (Thanks: `Joakim Koljonen`_).
+* Fix a crashing bug with named queries. (Thanks: `Joakim Koljonen`_).
+* Replace timestampz alias since AWS Redshift does not support it. (Thanks: `Tahir Butt`_).
+* Prevent pgcli from hanging indefinitely when Postgres instance is not running. (Thanks: `Darik Gamble`_)
+
+Internal Changes:
+-----------------
+
+* Upgrade to sqlparse-0.2.0. (Thanks: `Tiziano Müller`_).
+* Upgrade to pgspecial 1.6.0. (Thanks: `Stuart Quin`_).
+
+
+1.1.0
+=====
+
+Features:
+---------
+
+* Add support for ``\db`` command. (Thanks: `Irina Truong`_)
+
+Bugs:
+-----
+
+* Fix the crash at startup while parsing the postgres url with port number. (Thanks: `Eric Wald`_)
+* Fix the crash with Redshift databases. (Thanks: `Darik Gamble`_)
+
+Internal Changes:
+-----------------
+
+* Upgrade pgspecial to 1.5.0 and above.
+
+1.0.0
+=====
+
+Features:
+---------
+
+* Upgrade to prompt-toolkit 1.0.0. (Thanks: `Jonathan Slenders`_).
+* Add support for `\o` command to redirect query output to a file. (Thanks: `Tim Sanders`_).
+* Add `\i` path completion. (Thanks: `Anthony Lai`_).
+* Connect to a dsn saved in config file. (Thanks: `Rodrigo Ramírez Norambuena`_).
+* Upgrade sqlparse requirement to version 0.1.19. (Thanks: `Fernando L. Canizo`_).
+* Add timestamptz to DATE custom extension. (Thanks: `Fernando Mora`_).
+* Ensure target dir exists when copying config. (Thanks: `David Szotten`_).
+* Handle dates that fall in the B.C. range. (Thanks: `Stuart Quin`_).
+* Pager is selected from config file or else from environment variable. (Thanks: `Fernando Mora`_).
+* Add support for Amazon Redshift. (Thanks: `Timothy Cleaver`_).
+* Add support for Postgres 8.x. (Thanks: `Timothy Cleaver`_ and `Darik Gamble`_)
+* Don't error when completing parameter-less functions. (Thanks: `David Szotten`_).
+* Concat and return all available notices. (Thanks: `Stuart Quin`_).
+* Handle unicode in record type. (Thanks: `Amjith Ramanujam`_).
+* Added humanized time display. Connect #396. (Thanks: `Irina Truong`_).
+* Add EXPLAIN keyword to the completion list. (Thanks: `Amjith Ramanujam`_).
+* Added sdist upload to release script. (Thanks: `Irina Truong`_).
+* Sort completions based on most recently used. (Thanks: `Darik Gamble`)
+* Expand '*' into column list during completion. This can be triggered by hitting `<tab>` after the '*' character in the sql while typing. (Thanks: `Joakim Koljonen`_)
+* Add a limit to the warning about too many rows. This is controlled by a new config value in ~/.config/pgcli/config. (Thanks: `Anže Pečar`_)
+* Improved argument list in function parameter completions. (Thanks: `Joakim Koljonen`_)
+* Column suggestions after the COLUMN keyword. (Thanks: `Darik Gamble`_)
+* Filter out trigger implemented functions from the suggestion list. (Thanks: `Daniel Rocco`_)
+* State of the art JOIN clause completions that suggest entire conditions. (Thanks: `Joakim Koljonen`_)
+* Suggest fully formed JOIN clauses based on Foreign Key relations. (Thanks: `Joakim Koljonen`_)
+* Add support for `\dx` meta command to list the installed extensions. (Thanks: `Darik Gamble`_)
+* Add support for `\copy` command. (Thanks: `Catherine Devlin`_)
+
+Bugs:
+-----
+
+* Fix bug where config writing would leave a '~' dir. (Thanks: `James Munson`_).
+* Fix auto-completion breaking for table names with caps. (Thanks: `Anthony Lai`_).
+* Fix lexical ordering bug. (Thanks: `Anthony Lai`_).
+* Use lexical order to break ties when fuzzy matching. (Thanks: `Daniel Rocco`_).
+* Fix the bug in auto-expand mode when there are no rows to display. (Thanks: `Amjith Ramanujam`_).
+* Fix broken `\i` after #395. (Thanks: `David Szotten`_).
+* Fix multi-way joins in auto-completion. (Thanks: `Darik Gamble`_)
+* Display null values as <null> in expanded output. (Thanks: `Amjith Ramanujam`_).
+* Robust support for Postgres version less than 9.x. (Thanks: `Darik Gamble`_)
+
+Internal Changes:
+-----------------
+
+* Update config file location in README. (Thanks: `Ari Summer`_).
+* Explicitly add wcwidth as a dependency. (Thanks: `Amjith Ramanujam`_).
+* Add tests for the format_output. (Thanks: `Amjith Ramanujam`_).
+* Lots of tests for pgcompleter. (Thanks: `Darik Gamble`_).
+* Update pgspecial dependency to 1.4.0.
+
+
+0.20.1
+======
+
+Bug Fixes:
+----------
+* Fixed logging in Windows by switching the location of log and history file based on OS. (Thanks: Amjith, `Darik Gamble`_, `Irina Truong`_).
+
+0.20.0
+======
+
+Features:
+---------
+* Perform auto-completion refresh in background. (Thanks: Amjith, `Darik Gamble`_, `Irina Truong`_).
+ When the auto-completion entries are refreshed, the update now happens in a
+ background thread. This means large databases with thousands of tables are
+ handled without blocking.
+* Add ``CONCURRENTLY`` to keyword completion. (Thanks: `Johannes Hoff`_).
+* Add support for ``\h`` command. (Thanks: `Stuart Quin`_).
+ This is a huge deal. Users can now get help on an SQL command by typing:
+ ``\h COMMAND_NAME`` in the pgcli prompt.
+* Add support for ``\x auto``. (Thanks: `Stuart Quin`_).
+ ``\\x auto`` will automatically switch to expanded mode if the output is wider
+ than the display window.
+* Don't hide functions from pg_catalog. (Thanks: `Darik Gamble`_).
+* Suggest set-returning functions as tables. (Thanks: `Darik Gamble`_).
+ Functions that return table like results will now be suggested in places of tables.
+* Suggest fields from functions used as tables. (Thanks: `Darik Gamble`_).
+* Using ``pgspecial`` as a separate module. (Thanks: `Irina Truong`_).
+* Make "enter" key behave as "tab" key when the completion menu is displayed. (Thanks: `Matheus Rosa`_).
+* Support different error-handling options when running multiple queries. (Thanks: `Darik Gamble`_).
+ When ``on_error = STOP`` in the config file, pgcli will abort execution if one of the queries results in an error.
+* Hide the password displayed in the process name in ``ps``. (Thanks: `Stuart Quin`_)
+
+Bug Fixes:
+----------
+* Fix the ordering bug in `\\d+` display, this bug was displaying the wrong table name in the reference. (Thanks: `Tamas Boros`_).
+* Only show expanded layout if valid list of headers provided. (Thanks: `Stuart Quin`_).
+* Fix suggestions in compound join clauses. (Thanks: `Darik Gamble`_).
+* Fix completion refresh in multiple query scenario. (Thanks: `Darik Gamble`_).
+* Fix the broken timing information.
+* Fix the removal of whitespaces in the output. (Thanks: `Jacek Wielemborek`_)
+* Fix PyPI badge. (Thanks: `Artur Dryomov`_).
+
+Improvements:
+-------------
+* Move config file to `~/.config/pgcli/config` instead of `~/.pgclirc` (Thanks: `inkn`_).
+* Move literal definitions to standalone JSON files. (Thanks: `Darik Gamble`_).
+
+Internal Changes:
+-----------------
+* Improvements to integration tests to make it more robust. (Thanks: `Irina Truong`_).
+
+0.19.2
+======
+
+Features:
+---------
+
+* Autocompletion for database name in \c and \connect. (Thanks: `Darik Gamble`_).
+* Improved multiline query support by correctly handling open quotes. (Thanks: `Darik Gamble`_).
+* Added \pager command.
+* Enhanced \i to run multiple queries and display the results for each of them
+* Added keywords to suggestions after WHERE clause.
+* Enabled autocompletion in named queries. (Thanks: `Irina Truong`_).
+* Path to .pgclirc can be specified in command line. (Thanks: `Irina Truong`_).
+* Added support for pg_service_conf file. (Thanks: `Irina Truong`_).
+* Added custom styles. (Contributor: `Darik Gamble`_).
+
+Internal Changes:
+-----------------
+
+* More completer test cases. (Thanks: `Darik Gamble`_).
+* Updated sqlparse version from 0.1.14 to 0.1.16. (Thanks: `Darik Gamble`_).
+* Upgraded to prompt_toolkit 0.46. (Thanks: `Jonathan Slenders`_).
+
+BugFixes:
+---------
+* Fixed the completer crashing on invalid SQL. (Thanks: `Darik Gamble`_).
+* Fixed unicode issues, updated tests and fixed broken tests.
+
+0.19.1
+======
+
+BugFixes:
+---------
+
+* Fix an autocompletion bug that was crashing the completion engine when unknown keyword is entered. (Thanks: `Darik Gamble`_)
+
+0.19.0
+======
+
+Features:
+---------
+
+* Wider completion menus can be enabled via the config file. (Thanks: `Jonathan Slenders`_)
+
+ Open the config file (~/.pgclirc) and check if you have
+ ``wider_completion_menu`` option available. If not add it in and set it to
+ ``True``.
+
+* Completion menu now has metadata information such as schema, table, column, view, etc., next to the suggestions. (Thanks: `Darik Gamble`_)
+* Customizable history file location via config file. (Thanks: `Çağatay Yüksel`_)
+
+ Add this line to your config file (~/.pgclirc) to customize where to store the history file.
+
+::
+
+ history_file = /path/to/history/file
+
+* Add support for running queries from a file using ``\i`` special command. (Thanks: `Michael Kaminsky`_)
+
+BugFixes:
+---------
+
+* Always use utf-8 for database encoding regardless of the default encoding used by the database.
+* Fix for None dereference on ``\d schemaname.`` with sequence. (Thanks: `Nathan Jhaveri`_)
+* Fix a crashing bug in the autocompletion engine for some ``JOIN`` queries.
+* Handle KeyboardInterrupt in pager and not quit pgcli as a consequence.
+
+Internal Changes:
+-----------------
+
+* Added more behaviorial tests (Thanks: `Irina Truong`_)
+* Added code coverage to the tests. (Thanks: `Irina Truong`_)
+* Run behaviorial tests as part of TravisCI (Thanks: `Irina Truong`_)
+* Upgraded prompt_toolkit version to 0.45 (Thanks: `Jonathan Slenders`_)
+* Update the minimum required version of click to 4.1.
+
+0.18.0
+======
+
+Features:
+---------
+
+* Add fuzzy matching for the table names and column names.
+
+ Matching very long table/column names are now easier with fuzzy matching. The
+ fuzzy match works like the fuzzy open in SublimeText or Vim's Ctrl-P plugin.
+
+ eg: Typing ``djmv`` will match `django_migration_views` since it is able to
+ match parts of the input to the full table name.
+
+* Change the timing information to seconds.
+
+ The ``Command Time`` and ``Format Time`` are now displayed in seconds instead
+ of a unitless number displayed in scientific notation.
+
+* Support for named queries (favorite queries). (Thanks: `Brett Atoms`_)
+
+ Frequently typed queries can now be saved and recalled using a name using
+ newly added special commands (``\n[+]``, ``\ns``, ``\nd``).
+
+ eg:
+
+::
+
+ # Save a query
+ pgcli> \ns simple select * from foo
+ saved
+
+ # List all saved queries
+ pgcli> \n+
+
+ # Execute a saved query
+ pgcli> \n simple
+
+ # Delete a saved query
+ pgcli> \nd simple
+
+* Pasting queries into the pgcli repl is orders of magnitude faster. (Thanks: `Jonathan Slenders`_)
+
+* Add support for PGPASSWORD environment variable to pass the password for the
+ postgres database. (Thanks: `Irina Truong`_)
+
+* Add the ability to manually refresh autocompletions by typing ``\#`` or
+ ``\refresh``. This is useful if the database was updated by an external means
+ and you'd like to refresh the auto-completions to pick up the new change.
+
+Bug Fixes:
+----------
+
+* Fix an error when running ``\d table_name`` when running on a table with rules. (Thanks: `Ali Kargın`_)
+* Fix a pgcli crash when entering non-ascii characters in Windows. (Thanks: `Darik Gamble`_, `Jonathan Slenders`_)
+* Faster rendering of expanded mode output by making the horizontal separator a fixed length string.
+* Completion suggestions for the ``\c`` command are not auto-escaped by default.
+
+Internal Changes:
+-----------------
+
+* Complete refactor of handling the back-slash commands.
+* Upgrade prompt_toolkit to 0.42. (Thanks: `Jonathan Slenders`_)
+* Change the config file management to use ConfigObj.(Thanks: `Brett Atoms`_)
+* Add integration tests using ``behave``. (Thanks: `Irina Truong`_)
+
+0.17.0
+======
+
+Features:
+---------
+
+* Add support for auto-completing view names. (Thanks: `Darik Gamble`_)
+* Add support for building RPM and DEB packages. (Thanks: dp_)
+* Add subsequence matching for completion. (Thanks: `Daniel Rocco`_)
+ Previously completions only matched a table name if it started with the
+ partially typed word. Now completions will match even if the partially typed
+ word is in the middle of a suggestion.
+ eg: When you type 'mig', 'django_migrations' will be suggested.
+* Completion for built-in tables and temporary tables are suggested after entering a prefix of ``pg_``. (Thanks: `Darik Gamble`_)
+* Add place holder doc strings for special commands that are planned for implementation. (Thanks: `Irina Truong`_)
+* Updated version of prompt_toolkit, now matching braces are highlighted. (Thanks: `Jonathan Slenders`_)
+* Added support of ``\\e`` command. Queries can be edited in an external editor. (Thanks: `Irina Truong`_)
+ eg: When you type ``SELECT * FROM \e`` it will be opened in an external editor.
+* Add special command ``\dT`` to show datatypes. (Thanks: `Darik Gamble`_)
+* Add auto-completion support for datatypes in CREATE, SELECT etc. (Thanks: `Darik Gamble`_)
+* Improve the auto-completion in WHERE clause with logical operators. (Thanks: `Darik Gamble`_)
+*
+
+Bug Fixes:
+----------
+
+* Fix the table formatting while printing multi-byte characters (Chinese, Japanese etc). (Thanks: `蔡佳男`_)
+* Fix a crash when pg_catalog was present in search path. (Thanks: `Darik Gamble`_)
+* Fixed a bug that broke `\\e` when prompt_tookit was updated. (Thanks: `François Pietka`_)
+* Fix the display of triggers as shown in the ``\d`` output. (Thanks: `Dimitar Roustchev`_)
+* Fix broken auto-completion for INNER JOIN, LEFT JOIN etc. (Thanks: `Darik Gamble`_)
+* Fix incorrect super() calls in pgbuffer, pgtoolbar and pgcompleter. No change in functionality but protects against future problems. (Thanks: `Daniel Rocco`_)
+* Add missing schema completion for CREATE and DROP statements. (Thanks: `Darik Gamble`_)
+* Minor fixes around cursor cleanup.
+
+0.16.3
+======
+
+Bug Fixes:
+----------
+* Add more SQL keywords for auto-complete suggestion.
+* Messages raised as part of stored procedures are no longer ignored.
+* Use postgres flavored syntax highlighting instead of generic ANSI SQL.
+
+0.16.2
+======
+
+Bug Fixes:
+----------
+* Fix a bug where the schema qualifier was ignored by the auto-completion.
+ As a result the suggestions for tables vs functions are cleaner. (Thanks: `Darik Gamble`_)
+* Remove scientific notation when formatting large numbers. (Thanks: `Daniel Rocco`_)
+* Add the FUNCTION keyword to auto-completion.
+* Display NULL values as <null> instead of empty strings.
+* Fix the completion refresh when ``\connect`` is executed.
+
+0.16.1
+======
+
+Bug Fixes:
+----------
+* Fix unicode issues with hstore.
+* Fix a silent error when database is changed using \\c.
+
+0.16.0
+======
+
+Features:
+---------
+* Add \ds special command to show sequences.
+* Add Vi mode for keybindings. This can be enabled by adding 'vi = True' in ~/.pgclirc. (Thanks: `Jay Zeng`_)
+* Add a -v/--version flag to pgcli.
+* Add completion for TEMPLATE keyword and smart-completion for
+ 'CREATE DATABASE blah WITH TEMPLATE <tab>'. (Thanks: `Daniel Rocco`_)
+* Add custom decoders to json/jsonb to emulate the behavior of psql. This
+ removes the unicode prefix (eg: u'Éowyn') in the output. (Thanks: `Daniel Rocco`_)
+* Add \df special command to show functions. (Thanks: `Darik Gamble`_)
+* Make suggestions for special commands smarter. eg: \dn - only suggests schemas. (Thanks: `Darik Gamble`_)
+* Print out the version and other meta info about pgcli at startup.
+
+Bug Fixes:
+----------
+* Fix a rare crash caused by adding new schemas to a database. (Thanks: `Darik Gamble`_)
+* Make \dt command honor the explicit schema specified in the arg. (Thanks: `Darik Gamble`_)
+* Print BIGSERIAL type as Integer instead of Float.
+* Show completions for special commands at the beginning of a statement. (Thanks: `Daniel Rocco`_)
+* Allow special commands to work in a multi-statement case where multiple sql
+ statements are separated by semi-colon in the same line.
+
+0.15.4
+======
+* Dummy version to replace accidental PyPI entry deletion.
+
+0.15.3
+======
+* Override the LESS options completely instead of appending to it.
+
+0.15.2
+======
+* Revert back to using psycopg2 as the postgres adapter. psycopg2cffi fails for some tests in Python 3.
+
+0.15.0
+======
+
+Features:
+---------
+* Add syntax color styles to config.
+* Add auto-completion for COPY statements.
+* Change Postgres adapter to psycopg2cffi, to make it PyPy compatible.
+ Now pgcli can be run by PyPy.
+
+Bug Fixes:
+----------
+* Treat boolean values as strings instead of ints.
+* Make \di, \dv and \dt to be schema aware. (Thanks: `Darik Gamble`_)
+* Make column name display unicode compatible.
+
+0.14.0
+======
+
+Features:
+---------
+* Add alias completion support to ON keyword. (Thanks: `Irina Truong`_)
+* Add LIMIT keyword to completion.
+* Auto-completion for Postgres schemas. (Thanks: `Darik Gamble`_)
+* Better unicode handling for datatypes, dbname and roles.
+* Add \timing command to time the sql commands.
+ This can be set via config file (~/.pgclirc) using `timing = True`.
+* Add different table styles for displaying output.
+ This can be changed via config file (~/.pgclirc) using `table_format = fancy_grid`.
+* Add confirmation before printing results that have more than 1000 rows.
+
+Bug Fixes:
+----------
+
+* Performance improvements to expanded view display (\x).
+* Cast bytea files to text while displaying. (Thanks: `Daniel Rocco`_)
+* Added a list of reserved words that should be auto-escaped.
+* Auto-completion is now case-insensitive.
+* Fix the broken completion for multiple sql statements. (Thanks: `Darik Gamble`_)
+
+0.13.0
+======
+
+Features:
+---------
+
+* Add -d/--dbname option to the commandline.
+ eg: pgcli -d database
+* Add the username as an argument after the database.
+ eg: pgcli dbname user
+
+Bug Fixes:
+----------
+* Fix the crash when \c fails.
+* Fix the error thrown by \d when triggers are present.
+* Fix broken behavior on \?. (Thanks: `Darik Gamble`_)
+
+0.12.0
+======
+
+Features:
+---------
+
+* Upgrade to prompt_toolkit version 0.26 (Thanks: https://github.com/macobo)
+ * Adds Ctrl-left/right to move the cursor one word left/right respectively.
+ * Internal API changes.
+* IPython integration through `ipython-sql`_ (Thanks: `Darik Gamble`_)
+ * Add an ipython magic extension to embed pgcli inside ipython.
+ * Results from a pgcli query are sent back to ipython.
+* Multiple sql statements in the same line separated by semi-colon. (Thanks: https://github.com/macobo)
+
+.. _`ipython-sql`: https://github.com/catherinedevlin/ipython-sql
+
+Bug Fixes:
+----------
+
+* Fix 'message' attribute not found exception in Python 3. (Thanks: https://github.com/GMLudo)
+* Use the database username as the database name instead of defaulting to OS username. (Thanks: https://github.com/fpietka)
+* Auto-completion for auto-escaped column/table names.
+* Fix i-reverse-search to work in prompt_toolkit version 0.26.
+
+0.11.0
+======
+
+Features:
+---------
+
+* Add \dn command. (Thanks: https://github.com/CyberDem0n)
+* Add \x command. (Thanks: https://github.com/stuartquin)
+* Auto-escape special column/table names. (Thanks: https://github.com/qwesda)
+* Cancel a command using Ctrl+C. (Thanks: https://github.com/macobo)
+* Faster startup by reading all columns and tables in a single query. (Thanks: https://github.com/macobo)
+* Improved psql compliance with env vars and password prompting. (Thanks: `Darik Gamble`_)
+* Pressing Alt-Enter will introduce a line break. This is a way to break up the query into multiple lines without switching to multi-line mode. (Thanks: https://github.com/pabloab).
+
+Bug Fixes:
+----------
+* Fix the broken behavior of \d+. (Thanks: https://github.com/macobo)
+* Fix a crash during auto-completion. (Thanks: https://github.com/Erethon)
+* Avoid losing pre_run_callables on error in editing. (Thanks: https://github.com/catherinedevlin)
+
+Improvements:
+-------------
+* Faster test runs on TravisCI. (Thanks: https://github.com/macobo)
+* Integration tests with Postgres!! (Thanks: https://github.com/macobo)
+
+.. _`Amjith Ramanujam`: https://blog.amjith.com
+.. _`Andrew Kuchling`: https://github.com/akuchling
+.. _`Darik Gamble`: https://github.com/darikg
+.. _`Daniel Rocco`: https://github.com/drocco007
+.. _`Jay Zeng`: https://github.com/jayzeng
+.. _`蔡佳男`: https://github.com/xalley
+.. _dp: https://github.com/ceocoder
+.. _`Jonathan Slenders`: https://github.com/jonathanslenders
+.. _`Dimitar Roustchev`: https://github.com/droustchev
+.. _`François Pietka`: https://github.com/fpietka
+.. _`Ali Kargın`: https://github.com/sancopanco
+.. _`Brett Atoms`: https://github.com/brettatoms
+.. _`Nathan Jhaveri`: https://github.com/nathanjhaveri
+.. _`Çağatay Yüksel`: https://github.com/cagatay
+.. _`Michael Kaminsky`: https://github.com/mikekaminsky
+.. _`inkn`: inkn
+.. _`Johannes Hoff`: Johannes Hoff
+.. _`Matheus Rosa`: Matheus Rosa
+.. _`Artur Dryomov`: https://github.com/ming13
+.. _`Stuart Quin`: https://github.com/stuartquin
+.. _`Tamas Boros`: https://github.com/TamasNo1
+.. _`Jacek Wielemborek`: https://github.com/d33tah
+.. _`Rodrigo Ramírez Norambuena`: https://github.com/roramirez
+.. _`Anthony Lai`: https://github.com/ajlai
+.. _`Ari Summer`: Ari Summer
+.. _`David Szotten`: David Szotten
+.. _`Fernando L. Canizo`: Fernando L. Canizo
+.. _`Tim Sanders`: https://github.com/Gollum999
+.. _`Irina Truong`: https://github.com/j-bennet
+.. _`James Munson`: https://github.com/jmunson
+.. _`Fernando Mora`: https://github.com/fernandomora
+.. _`Timothy Cleaver`: Timothy Cleaver
+.. _`gtxx`: gtxx
+.. _`Joakim Koljonen`: https://github.com/koljonen
+.. _`Anže Pečar`: https://github.com/Smotko
+.. _`Catherine Devlin`: https://github.com/catherinedevlin
+.. _`Eric Wald`: https://github.com/eswald
+.. _`avdd`: https://github.com/avdd
+.. _`Adrian Dries`: Adrian Dries
+.. _`Julien Rouhaud`: https://github.com/rjuju
+.. _`Lim H`: Lim H
+.. _`Tahir Butt`: Tahir Butt
+.. _`Tiziano Müller`: https://github.com/dev-zero
+.. _`Janus Troelsen`: https://github.com/ysangkok
+.. _`Fabien Meghazi`: https://github.com/amigrave
+.. _`Manuel Barkhau`: https://github.com/mbarkhau
+.. _`Sergii`: https://github.com/foxyterkel
+.. _`Lele Gaifax`: https://github.com/lelit
+.. _`tk`: https://github.com/kanet77
+.. _`Owen Stephens`: https://github.com/owst
+.. _`Russell Davies`: https://github.com/russelldavies
+.. _`Dick Marinus`: https://github.com/meeuw
+.. _`Étienne Bersac`: https://github.com/bersace
+.. _`Thomas Roten`: https://github.com/tsroten
+.. _`Gustavo Castro`: https://github.com/gustavo-castro
+.. _`Alexander Schmolck`: https://github.com/aschmolck
+.. _`Andrew Speed`: https://github.com/AndrewSpeed
+.. _`Dmitry B`: https://github.com/oxitnik
+.. _`Marcin Sztolcman`: https://github.com/msztolcman
+.. _`Isank`: https://github.com/isank
+.. _`Bojan Delić`: https://github.com/delicb
+.. _`Frederic Aoustin`: https://github.com/fraoustin
+.. _`Jason Ribeiro`: https://github.com/jrib
+.. _`Rishi Ramraj`: https://github.com/RishiRamraj
+.. _`Matthieu Guilbert`: https://github.com/gma2th
+.. _`Alexandr Korsak`: https://github.com/oivoodoo
+.. _`Saif Hakim`: https://github.com/saifelse
+.. _`Artur Balabanov`: https://github.com/arturbalabanov
+.. _`Kenny Do`: https://github.com/kennydo
+.. _`Max Rothman`: https://github.com/maxrothman
+.. _`Daniel Egger`: https://github.com/DanEEStar
+.. _`Ignacio Campabadal`: https://github.com/igncampa
+.. _`Mikhail Elovskikh`: https://github.com/wronglink
+.. _`Marcin Cieślak`: https://github.com/saper
+.. _`Scott Brenstuhl`: https://github.com/808sAndBR
+.. _`easteregg`: https://github.com/verfriemelt-dot-org
+.. _`Nathan Verzemnieks`: https://github.com/njvrzm
+.. _`raylu`: https://github.com/raylu
+.. _`Zhaolong Zhu`: https://github.com/zzl0
+.. _`Xavier Francisco`: https://github.com/Qu4tro
+.. _`VVelox`: https://github.com/VVelox
+.. _`Telmo "Trooper"`: https://github.com/telmotrooper
+.. _`Alexander Zawadzki`: https://github.com/zadacka
+.. _`Sebastian Janko`: https://github.com/sebojanko
+.. _`Pedro Ferrari`: https://github.com/petobens
+.. _`BrownShibaDog`: https://github.com/BrownShibaDog
+.. _`thegeorgeous`: https://github.com/thegeorgeous
+.. _`laixintao`: https://github.com/laixintao
+.. _`anthonydb`: https://github.com/anthonydb
+.. _`Daniel Kukula`: https://github.com/dkuku
diff --git a/pgcli-completion.bash b/pgcli-completion.bash
new file mode 100644
index 0000000..3549b56
--- /dev/null
+++ b/pgcli-completion.bash
@@ -0,0 +1,61 @@
+_pg_databases()
+{
+ # -w was introduced in 8.4, https://launchpad.net/bugs/164772
+ # "Access privileges" in output may contain linefeeds, hence the NF > 1
+ COMPREPLY=( $( compgen -W "$( psql -AtqwlF $'\t' 2>/dev/null | \
+ awk 'NF > 1 { print $1 }' )" -- "$cur" ) )
+}
+
+_pg_users()
+{
+ # -w was introduced in 8.4, https://launchpad.net/bugs/164772
+ COMPREPLY=( $( compgen -W "$( psql -Atqwc 'select usename from pg_user' \
+ template1 2>/dev/null )" -- "$cur" ) )
+ [[ ${#COMPREPLY[@]} -eq 0 ]] && COMPREPLY=( $( compgen -u -- "$cur" ) )
+}
+
+_pgcli()
+{
+ local cur prev words cword
+ _init_completion -s || return
+
+ case $prev in
+ -h|--host)
+ _known_hosts_real "$cur"
+ return 0
+ ;;
+ -U|--user)
+ _pg_users
+ return 0
+ ;;
+ -d|--dbname)
+ _pg_databases
+ return 0
+ ;;
+ --help|-v|--version|-p|--port|-R|--row-limit)
+ # all other arguments are noop with these
+ return 0
+ ;;
+ esac
+
+ case "$cur" in
+ --*)
+ # return list of available options
+ COMPREPLY=( $( compgen -W '--host --port --user --password --no-password
+ --single-connection --version --dbname --pgclirc --dsn
+ --row-limit --help' -- "$cur" ) )
+ [[ $COMPREPLY == *= ]] && compopt -o nospace
+ return 0
+ ;;
+ -)
+ # only complete long options
+ compopt -o nospace
+ COMPREPLY=( -- )
+ return 0
+ ;;
+ *)
+ # return list of available databases
+ _pg_databases
+ esac
+} &&
+complete -F _pgcli pgcli
diff --git a/pgcli/__init__.py b/pgcli/__init__.py
new file mode 100644
index 0000000..76ad18b
--- /dev/null
+++ b/pgcli/__init__.py
@@ -0,0 +1 @@
+__version__ = "4.0.1"
diff --git a/pgcli/__main__.py b/pgcli/__main__.py
new file mode 100644
index 0000000..ddf1662
--- /dev/null
+++ b/pgcli/__main__.py
@@ -0,0 +1,9 @@
+"""
+pgcli package main entry point
+"""
+
+from .main import cli
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/pgcli/auth.py b/pgcli/auth.py
new file mode 100644
index 0000000..2f1e552
--- /dev/null
+++ b/pgcli/auth.py
@@ -0,0 +1,60 @@
+import click
+from textwrap import dedent
+
+
+keyring = None # keyring will be loaded later
+
+
+keyring_error_message = dedent(
+ """\
+ {}
+ {}
+ To remove this message do one of the following:
+ - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
+ - uninstall keyring: pip uninstall keyring
+ - disable keyring in our configuration: add keyring = False to [main]"""
+)
+
+
+def keyring_initialize(keyring_enabled, *, logger):
+ """Initialize keyring only if explicitly enabled"""
+ global keyring
+
+ if keyring_enabled:
+ # Try best to load keyring (issue #1041).
+ import importlib
+
+ try:
+ keyring = importlib.import_module("keyring")
+ except (
+ ModuleNotFoundError
+ ) as e: # ImportError for Python 2, ModuleNotFoundError for Python 3
+ logger.warning("import keyring failed: %r.", e)
+
+
+def keyring_get_password(key):
+ """Attempt to get password from keyring"""
+ # Find password from store
+ passwd = ""
+ try:
+ passwd = keyring.get_password("pgcli", key) or ""
+ except Exception as e:
+ click.secho(
+ keyring_error_message.format(
+ "Load your password from keyring returned:", str(e)
+ ),
+ err=True,
+ fg="red",
+ )
+ return passwd
+
+
+def keyring_set_password(key, passwd):
+ try:
+ keyring.set_password("pgcli", key, passwd)
+ except Exception as e:
+ click.secho(
+ keyring_error_message.format("Set password in keyring returned:", str(e)),
+ err=True,
+ fg="red",
+ )
diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py
new file mode 100644
index 0000000..c887cb6
--- /dev/null
+++ b/pgcli/completion_refresher.py
@@ -0,0 +1,152 @@
+import threading
+import os
+from collections import OrderedDict
+
+from .pgcompleter import PGCompleter
+
+
+class CompletionRefresher:
+ refreshers = OrderedDict()
+
+ def __init__(self):
+ self._completer_thread = None
+ self._restart_refresh = threading.Event()
+
+ def refresh(self, executor, special, callbacks, history=None, settings=None):
+ """
+ Creates a PGCompleter object and populates it with the relevant
+ completion suggestions in a background thread.
+
+ executor - PGExecute object, used to extract the credentials to connect
+ to the database.
+ special - PGSpecial object used for creating a new completion object.
+ settings - dict of settings for completer object
+ callbacks - A function or a list of functions to call after the thread
+ has completed the refresh. The newly created completion
+ object will be passed in as an argument to each callback.
+ """
+ if executor.is_virtual_database():
+ # do nothing
+ return [(None, None, None, "Auto-completion refresh can't be started.")]
+
+ if self.is_refreshing():
+ self._restart_refresh.set()
+ return [(None, None, None, "Auto-completion refresh restarted.")]
+ else:
+ self._completer_thread = threading.Thread(
+ target=self._bg_refresh,
+ args=(executor, special, callbacks, history, settings),
+ name="completion_refresh",
+ )
+ self._completer_thread.daemon = True
+ self._completer_thread.start()
+ return [
+ (None, None, None, "Auto-completion refresh started in the background.")
+ ]
+
+ def is_refreshing(self):
+ return self._completer_thread and self._completer_thread.is_alive()
+
+ def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None):
+ settings = settings or {}
+ completer = PGCompleter(
+ smart_completion=True, pgspecial=special, settings=settings
+ )
+
+ if settings.get("single_connection"):
+ executor = pgexecute
+ else:
+ # Create a new pgexecute method to populate the completions.
+ executor = pgexecute.copy()
+ # If callbacks is a single function then push it into a list.
+ if callable(callbacks):
+ callbacks = [callbacks]
+
+ while 1:
+ for refresher in self.refreshers.values():
+ refresher(completer, executor)
+ if self._restart_refresh.is_set():
+ self._restart_refresh.clear()
+ break
+ else:
+ # Break out of while loop if the for loop finishes natually
+ # without hitting the break statement.
+ break
+
+ # Start over the refresh from the beginning if the for loop hit the
+ # break statement.
+ continue
+
+ # Load history into pgcompleter so it can learn user preferences
+ n_recent = 100
+ if history:
+ for recent in history.get_strings()[-n_recent:]:
+ completer.extend_query_history(recent, is_init=True)
+
+ for callback in callbacks:
+ callback(completer)
+
+ if not settings.get("single_connection") and executor.conn:
+ # close connection established with pgexecute.copy()
+ executor.conn.close()
+
+
+def refresher(name, refreshers=CompletionRefresher.refreshers):
+ """Decorator to populate the dictionary of refreshers with the current
+ function.
+ """
+
+ def wrapper(wrapped):
+ refreshers[name] = wrapped
+ return wrapped
+
+ return wrapper
+
+
+@refresher("schemata")
+def refresh_schemata(completer, executor):
+ completer.set_search_path(executor.search_path())
+ completer.extend_schemata(executor.schemata())
+
+
+@refresher("tables")
+def refresh_tables(completer, executor):
+ completer.extend_relations(executor.tables(), kind="tables")
+ completer.extend_columns(executor.table_columns(), kind="tables")
+ completer.extend_foreignkeys(executor.foreignkeys())
+
+
+@refresher("views")
+def refresh_views(completer, executor):
+ completer.extend_relations(executor.views(), kind="views")
+ completer.extend_columns(executor.view_columns(), kind="views")
+
+
+@refresher("types")
+def refresh_types(completer, executor):
+ completer.extend_datatypes(executor.datatypes())
+
+
+@refresher("databases")
+def refresh_databases(completer, executor):
+ completer.extend_database_names(executor.databases())
+
+
+@refresher("casing")
+def refresh_casing(completer, executor):
+ casing_file = completer.casing_file
+ if not casing_file:
+ return
+ generate_casing_file = completer.generate_casing_file
+ if generate_casing_file and not os.path.isfile(casing_file):
+ casing_prefs = "\n".join(executor.casing())
+ with open(casing_file, "w") as f:
+ f.write(casing_prefs)
+ if os.path.isfile(casing_file):
+ with open(casing_file) as f:
+ completer.extend_casing([line.strip() for line in f])
+
+
+@refresher("functions")
+def refresh_functions(completer, executor):
+ completer.extend_functions(executor.functions())
diff --git a/pgcli/config.py b/pgcli/config.py
new file mode 100644
index 0000000..22f08dc
--- /dev/null
+++ b/pgcli/config.py
@@ -0,0 +1,99 @@
+import errno
+import shutil
+import os
+import platform
+from os.path import expanduser, exists, dirname
+import re
+from typing import TextIO
+from configobj import ConfigObj
+
+
+def config_location():
+ if "XDG_CONFIG_HOME" in os.environ:
+ return "%s/pgcli/" % expanduser(os.environ["XDG_CONFIG_HOME"])
+ elif platform.system() == "Windows":
+ return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\pgcli\\"
+ else:
+ return expanduser("~/.config/pgcli/")
+
+
+def load_config(usr_cfg, def_cfg=None):
+ # avoid config merges when possible. For writing, we need an umerged config instance.
+ # see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171
+ if def_cfg:
+ cfg = ConfigObj()
+ cfg.merge(ConfigObj(def_cfg, interpolation=False))
+ cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+ else:
+ cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")
+ cfg.filename = expanduser(usr_cfg)
+ return cfg
+
+
+def ensure_dir_exists(path):
+ parent_dir = expanduser(dirname(path))
+ os.makedirs(parent_dir, exist_ok=True)
+
+
+def write_default_config(source, destination, overwrite=False):
+ destination = expanduser(destination)
+ if not overwrite and exists(destination):
+ return
+
+ ensure_dir_exists(destination)
+
+ shutil.copyfile(source, destination)
+
+
+def upgrade_config(config, def_config):
+ cfg = load_config(config, def_config)
+ cfg.write()
+
+
+def get_config_filename(pgclirc_file=None):
+ return pgclirc_file or "%sconfig" % config_location()
+
+
+def get_config(pgclirc_file=None):
+ from pgcli import __file__ as package_root
+
+ package_root = os.path.dirname(package_root)
+
+ pgclirc_file = get_config_filename(pgclirc_file)
+
+ default_config = os.path.join(package_root, "pgclirc")
+ write_default_config(default_config, pgclirc_file)
+
+ return load_config(pgclirc_file, default_config)
+
+
+def get_casing_file(config):
+ casing_file = config["main"]["casing_file"]
+ if casing_file == "default":
+ casing_file = config_location() + "casing"
+ return casing_file
+
+
+def skip_initial_comment(f_stream: TextIO) -> int:
+ """
+ Initial comment in ~/.pg_service.conf is not always marked with '#'
+ which crashes the parser. This function takes a file object and
+ "rewinds" it to the beginning of the first section,
+ from where on it can be parsed safely
+
+ :return: number of skipped lines
+ """
+ section_regex = r"\s*\["
+ pos = f_stream.tell()
+ lines_skipped = 0
+ while True:
+ line = f_stream.readline()
+ if line == "":
+ break
+ if re.match(section_regex, line) is not None:
+ f_stream.seek(pos)
+ break
+ else:
+ pos += len(line)
+ lines_skipped += 1
+ return lines_skipped
diff --git a/pgcli/explain_output_formatter.py b/pgcli/explain_output_formatter.py
new file mode 100644
index 0000000..ce45b4f
--- /dev/null
+++ b/pgcli/explain_output_formatter.py
@@ -0,0 +1,19 @@
+from pgcli.pyev import Visualizer
+import json
+
+
+"""Explain response output adapter"""
+
+
+class ExplainOutputFormatter:
+ def __init__(self, max_width):
+ self.max_width = max_width
+
+ def format_output(self, cur, headers, **output_kwargs):
+ # explain query results should always contain 1 row each
+ [(data,)] = list(cur)
+ explain_list = json.loads(data)
+ visualizer = Visualizer(self.max_width)
+ for explain in explain_list:
+ visualizer.load(explain)
+ yield visualizer.get_list()
diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py
new file mode 100644
index 0000000..9c016f7
--- /dev/null
+++ b/pgcli/key_bindings.py
@@ -0,0 +1,133 @@
+import logging
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.key_binding import KeyBindings
+from prompt_toolkit.filters import (
+ completion_is_selected,
+ is_searching,
+ has_completions,
+ has_selection,
+ vi_mode,
+)
+
+from .pgbuffer import buffer_should_be_handled, safe_multi_line_mode
+
+_logger = logging.getLogger(__name__)
+
+
+def pgcli_bindings(pgcli):
+ """Custom key bindings for pgcli."""
+ kb = KeyBindings()
+
+ tab_insert_text = " " * 4
+
+ @kb.add("f2")
+ def _(event):
+ """Enable/Disable SmartCompletion Mode."""
+ _logger.debug("Detected F2 key.")
+ pgcli.completer.smart_completion = not pgcli.completer.smart_completion
+
+ @kb.add("f3")
+ def _(event):
+ """Enable/Disable Multiline Mode."""
+ _logger.debug("Detected F3 key.")
+ pgcli.multi_line = not pgcli.multi_line
+
+ @kb.add("f4")
+ def _(event):
+ """Toggle between Vi and Emacs mode."""
+ _logger.debug("Detected F4 key.")
+ pgcli.vi_mode = not pgcli.vi_mode
+ event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS
+
+ @kb.add("f5")
+ def _(event):
+ """Toggle between Vi and Emacs mode."""
+ _logger.debug("Detected F5 key.")
+ pgcli.explain_mode = not pgcli.explain_mode
+
+ @kb.add("tab")
+ def _(event):
+ """Force autocompletion at cursor on non-empty lines."""
+
+ _logger.debug("Detected <Tab> key.")
+
+ buff = event.app.current_buffer
+ doc = buff.document
+
+ if doc.on_first_line or doc.current_line.strip():
+ if buff.complete_state:
+ buff.complete_next()
+ else:
+ buff.start_completion(select_first=True)
+ else:
+ buff.insert_text(tab_insert_text, fire_event=False)
+
+ @kb.add("escape", filter=has_completions)
+ def _(event):
+ """Force closing of autocompletion."""
+ _logger.debug("Detected <Esc> key.")
+
+ event.current_buffer.complete_state = None
+ event.app.current_buffer.complete_state = None
+
+ @kb.add("c-space")
+ def _(event):
+ """
+ Initialize autocompletion at cursor.
+
+ If the autocompletion menu is not showing, display it with the
+ appropriate completions for the context.
+
+ If the menu is showing, select the next completion.
+ """
+ _logger.debug("Detected <C-Space> key.")
+
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=False)
+
+ @kb.add("enter", filter=completion_is_selected)
+ def _(event):
+ """Makes the enter key work as the tab key only when showing the menu.
+
+ In other words, don't execute query when enter is pressed in
+ the completion dropdown menu, instead close the dropdown menu
+ (accept current selection).
+
+ """
+ _logger.debug("Detected enter key during completion selection.")
+
+ event.current_buffer.complete_state = None
+ event.app.current_buffer.complete_state = None
+
+ # When using multi_line input mode the buffer is not handled on Enter (a new line is
+ # inserted instead), so we force the handling if we're not in a completion or
+ # history search, and one of several conditions are True
+ @kb.add(
+ "enter",
+ filter=~(completion_is_selected | is_searching)
+ & buffer_should_be_handled(pgcli),
+ )
+ def _(event):
+ _logger.debug("Detected enter key.")
+ event.current_buffer.validate_and_handle()
+
+ @kb.add("escape", "enter", filter=~vi_mode & ~safe_multi_line_mode(pgcli))
+ def _(event):
+ """Introduces a line break regardless of multi-line mode or not."""
+ _logger.debug("Detected alt-enter key.")
+ event.app.current_buffer.insert_text("\n")
+
+ @kb.add("c-p", filter=~has_selection)
+ def _(event):
+ """Move up in history."""
+ event.current_buffer.history_backward(count=event.arg)
+
+ @kb.add("c-n", filter=~has_selection)
+ def _(event):
+ """Move down in history."""
+ event.current_buffer.history_forward(count=event.arg)
+
+ return kb
diff --git a/pgcli/magic.py b/pgcli/magic.py
new file mode 100644
index 0000000..09902a2
--- /dev/null
+++ b/pgcli/magic.py
@@ -0,0 +1,71 @@
+from .main import PGCli
+import sql.parse
+import sql.connection
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+def load_ipython_extension(ipython):
+ """This is called via the ipython command '%load_ext pgcli.magic'"""
+
+ # first, load the sql magic if it isn't already loaded
+ if not ipython.find_line_magic("sql"):
+ ipython.run_line_magic("load_ext", "sql")
+
+ # register our own magic
+ ipython.register_magic_function(pgcli_line_magic, "line", "pgcli")
+
+
+def pgcli_line_magic(line):
+ _logger.debug("pgcli magic called: %r", line)
+ parsed = sql.parse.parse(line, {})
+ # "get" was renamed to "set" in ipython-sql:
+ # https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43
+ if hasattr(sql.connection.Connection, "get"):
+ conn = sql.connection.Connection.get(parsed["connection"])
+ else:
+ try:
+ conn = sql.connection.Connection.set(parsed["connection"])
+ # a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql
+ except TypeError:
+ conn = sql.connection.Connection.set(parsed["connection"], False)
+
+ try:
+ # A corresponding pgcli object already exists
+ pgcli = conn._pgcli
+ _logger.debug("Reusing existing pgcli")
+ except AttributeError:
+ # I can't figure out how to get the underylying psycopg2 connection
+ # from the sqlalchemy connection, so just grab the url and make a
+ # new connection
+ pgcli = PGCli()
+ u = conn.session.engine.url
+ _logger.debug("New pgcli: %r", str(u))
+
+ pgcli.connect_uri(str(u._replace(drivername="postgres")))
+ conn._pgcli = pgcli
+
+ # For convenience, print the connection alias
+ print(f"Connected: {conn.name}")
+
+ try:
+ pgcli.run_cli()
+ except SystemExit:
+ pass
+
+ if not pgcli.query_history:
+ return
+
+ q = pgcli.query_history[-1]
+
+ if not q.successful:
+ _logger.debug("Unsuccessful query - ignoring")
+ return
+
+ if q.meta_changed or q.db_changed or q.path_changed:
+ _logger.debug("Dangerous query detected -- ignoring")
+ return
+
+ ipython = get_ipython()
+ return ipython.run_cell_magic("sql", line, q.query)
diff --git a/pgcli/main.py b/pgcli/main.py
new file mode 100644
index 0000000..f95c800
--- /dev/null
+++ b/pgcli/main.py
@@ -0,0 +1,1728 @@
+from configobj import ConfigObj, ParseError
+from pgspecial.namedqueries import NamedQueries
+from .config import skip_initial_comment
+
+import atexit
+import os
+import re
+import sys
+import traceback
+import logging
+import threading
+import shutil
+import functools
+import pendulum
+import datetime as dt
+import itertools
+import platform
+from time import time, sleep
+from typing import Optional
+
+from cli_helpers.tabular_output import TabularOutputFormatter
+from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
+from cli_helpers.utils import strip_ansi
+from .explain_output_formatter import ExplainOutputFormatter
+import click
+
+try:
+ import setproctitle
+except ImportError:
+ setproctitle = None
+from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter
+from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
+from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
+from prompt_toolkit.document import Document
+from prompt_toolkit.filters import HasFocus, IsDone
+from prompt_toolkit.formatted_text import ANSI
+from prompt_toolkit.lexers import PygmentsLexer
+from prompt_toolkit.layout.processors import (
+ ConditionalProcessor,
+ HighlightMatchingBracketProcessor,
+ TabsProcessor,
+)
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from pygments.lexers.sql import PostgresLexer
+
+from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
+import pgspecial as special
+
+from . import auth
+from .pgcompleter import PGCompleter
+from .pgtoolbar import create_toolbar_tokens_func
+from .pgstyle import style_factory, style_factory_output
+from .pgexecute import PGExecute
+from .completion_refresher import CompletionRefresher
+from .config import (
+ get_casing_file,
+ load_config,
+ config_location,
+ ensure_dir_exists,
+ get_config,
+ get_config_filename,
+)
+from .key_bindings import pgcli_bindings
+from .packages.formatter.sqlformatter import register_new_formatter
+from .packages.prompt_utils import confirm, confirm_destructive_query
+from .packages.parseutils import is_destructive
+from .packages.parseutils import parse_destructive_warning
+from .__init__ import __version__
+
+click.disable_unicode_literals_warning = True
+
+from urllib.parse import urlparse
+
+from getpass import getuser
+
+from psycopg import OperationalError, InterfaceError
+from psycopg.conninfo import make_conninfo, conninfo_to_dict
+
+from collections import namedtuple
+
+try:
+ import sshtunnel
+
+ SSH_TUNNEL_SUPPORT = True
+except ImportError:
+ SSH_TUNNEL_SUPPORT = False
+
+
+# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
+COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
+DEFAULT_MAX_FIELD_WIDTH = 500
+
+# Query tuples are used for maintaining history
+MetaQuery = namedtuple(
+ "Query",
+ [
+ "query", # The entire text of the command
+ "successful", # True If all subqueries were successful
+ "total_time", # Time elapsed executing the query and formatting results
+ "execution_time", # Time elapsed executing the query
+ "meta_changed", # True if any subquery executed create/alter/drop
+ "db_changed", # True if any subquery changed the database
+ "path_changed", # True if any subquery changed the search path
+ "mutated", # True if any subquery executed insert/update/delete
+ "is_special", # True if the query is a special command
+ ],
+)
+MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False)
+
+OutputSettings = namedtuple(
+ "OutputSettings",
+ "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output max_field_width",
+)
+OutputSettings.__new__.__defaults__ = (
+ None,
+ None,
+ None,
+ "<null>",
+ False,
+ None,
+ lambda x: x,
+ None,
+ DEFAULT_MAX_FIELD_WIDTH,
+)
+
+
+class PgCliQuitError(Exception):
+ pass
+
+
+class PGCli:
+ default_prompt = "\\u@\\h:\\d> "
+ max_len_prompt = 30
+
+ def set_default_pager(self, config):
+ configured_pager = config["main"].get("pager")
+ os_environ_pager = os.environ.get("PAGER")
+
+ if configured_pager:
+ self.logger.info(
+ 'Default pager found in config file: "%s"', configured_pager
+ )
+ os.environ["PAGER"] = configured_pager
+ elif os_environ_pager:
+ self.logger.info(
+ 'Default pager found in PAGER environment variable: "%s"',
+ os_environ_pager,
+ )
+ os.environ["PAGER"] = os_environ_pager
+ else:
+ self.logger.info(
+ "No default pager found in environment. Using os default pager"
+ )
+
+ # Set default set of less recommended options, if they are not already set.
+ # They are ignored if pager is different than less.
+ if not os.environ.get("LESS"):
+ os.environ["LESS"] = "-SRXF"
+
+ def __init__(
+ self,
+ force_passwd_prompt=False,
+ never_passwd_prompt=False,
+ pgexecute=None,
+ pgclirc_file=None,
+ row_limit=None,
+ single_connection=False,
+ less_chatty=None,
+ prompt=None,
+ prompt_dsn=None,
+ auto_vertical_output=False,
+ warn=None,
+ ssh_tunnel_url: Optional[str] = None,
+ ):
+ self.force_passwd_prompt = force_passwd_prompt
+ self.never_passwd_prompt = never_passwd_prompt
+ self.pgexecute = pgexecute
+ self.dsn_alias = None
+ self.watch_command = None
+
+ # Load config.
+ c = self.config = get_config(pgclirc_file)
+
+ # at this point, config should be written to pgclirc_file if it did not exist. Read it.
+ self.config_writer = load_config(get_config_filename(pgclirc_file))
+
+ # make sure to use self.config_writer, not self.config
+ NamedQueries.instance = NamedQueries.from_config(self.config_writer)
+
+ self.logger = logging.getLogger(__name__)
+ self.initialize_logging()
+
+ self.set_default_pager(c)
+ self.output_file = None
+ self.pgspecial = PGSpecial()
+
+ self.explain_mode = False
+ self.multi_line = c["main"].as_bool("multi_line")
+ self.multiline_mode = c["main"].get("multi_line_mode", "psql")
+ self.vi_mode = c["main"].as_bool("vi")
+ self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
+ self.auto_retry_closed_connection = c["main"].as_bool(
+ "auto_retry_closed_connection"
+ )
+ self.expanded_output = c["main"].as_bool("expand")
+ self.pgspecial.timing_enabled = c["main"].as_bool("timing")
+ if row_limit is not None:
+ self.row_limit = row_limit
+ else:
+ self.row_limit = c["main"].as_int("row_limit")
+
+ # if not specified, set to DEFAULT_MAX_FIELD_WIDTH
+ # if specified but empty, set to None to disable truncation
+ # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0
+ max_field_width = c["main"].get("max_field_width", DEFAULT_MAX_FIELD_WIDTH)
+ if max_field_width and max_field_width.lower() != "none":
+ max_field_width = max(3, abs(int(max_field_width)))
+ else:
+ max_field_width = None
+ self.max_field_width = max_field_width
+
+ self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines")
+ self.multiline_continuation_char = c["main"]["multiline_continuation_char"]
+ self.table_format = c["main"]["table_format"]
+ self.syntax_style = c["main"]["syntax_style"]
+ self.cli_style = c["colors"]
+ self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
+ self.destructive_warning = parse_destructive_warning(
+ warn or c["main"].as_list("destructive_warning")
+ )
+ self.destructive_warning_restarts_connection = c["main"].as_bool(
+ "destructive_warning_restarts_connection"
+ )
+ self.destructive_statements_require_transaction = c["main"].as_bool(
+ "destructive_statements_require_transaction"
+ )
+
+ self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
+ self.null_string = c["main"].get("null_string", "<null>")
+ self.prompt_format = (
+ prompt
+ if prompt is not None
+ else c["main"].get("prompt", self.default_prompt)
+ )
+ self.prompt_dsn_format = prompt_dsn
+ self.on_error = c["main"]["on_error"].upper()
+ self.decimal_format = c["data_formats"]["decimal"]
+ self.float_format = c["data_formats"]["float"]
+ auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger)
+ self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
+
+ self.pgspecial.pset_pager(
+ self.config["main"].as_bool("enable_pager") and "on" or "off"
+ )
+
+ self.style_output = style_factory_output(self.syntax_style, c["colors"])
+
+ self.now = dt.datetime.today()
+
+ self.completion_refresher = CompletionRefresher()
+
+ self.query_history = []
+
+ # Initialize completer
+ smart_completion = c["main"].as_bool("smart_completion")
+ keyword_casing = c["main"]["keyword_casing"]
+ single_connection = single_connection or c["main"].as_bool(
+ "always_use_single_connection"
+ )
+ self.settings = {
+ "casing_file": get_casing_file(c),
+ "generate_casing_file": c["main"].as_bool("generate_casing_file"),
+ "generate_aliases": c["main"].as_bool("generate_aliases"),
+ "asterisk_column_order": c["main"]["asterisk_column_order"],
+ "qualify_columns": c["main"]["qualify_columns"],
+ "case_column_headers": c["main"].as_bool("case_column_headers"),
+ "search_path_filter": c["main"].as_bool("search_path_filter"),
+ "single_connection": single_connection,
+ "less_chatty": less_chatty,
+ "keyword_casing": keyword_casing,
+ "alias_map_file": c["main"]["alias_map_file"] or None,
+ }
+
+ completer = PGCompleter(
+ smart_completion, pgspecial=self.pgspecial, settings=self.settings
+ )
+ self.completer = completer
+ self._completer_lock = threading.Lock()
+ self.register_special_commands()
+
+ self.prompt_app = None
+
+ self.ssh_tunnel_config = c.get("ssh tunnels")
+ self.ssh_tunnel_url = ssh_tunnel_url
+ self.ssh_tunnel = None
+
+ # formatter setup
+ self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
+ register_new_formatter(self.formatter)
+
+ def quit(self):
+ raise PgCliQuitError
+
+ def register_special_commands(self):
+ self.pgspecial.register(
+ self.change_db,
+ "\\c",
+ "\\c[onnect] database_name",
+ "Change to a new database.",
+ aliases=("use", "\\connect", "USE"),
+ )
+
+ refresh_callback = lambda: self.refresh_completions(persist_priorities="all")
+
+ self.pgspecial.register(
+ self.quit,
+ "\\q",
+ "\\q",
+ "Quit pgcli.",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+ aliases=(":q",),
+ )
+ self.pgspecial.register(
+ self.quit,
+ "quit",
+ "quit",
+ "Quit pgcli.",
+ arg_type=NO_QUERY,
+ case_sensitive=False,
+ aliases=("exit",),
+ )
+ self.pgspecial.register(
+ refresh_callback,
+ "\\#",
+ "\\#",
+ "Refresh auto-completions.",
+ arg_type=NO_QUERY,
+ )
+ self.pgspecial.register(
+ refresh_callback,
+ "\\refresh",
+ "\\refresh",
+ "Refresh auto-completions.",
+ arg_type=NO_QUERY,
+ )
+ self.pgspecial.register(
+ self.execute_from_file, "\\i", "\\i filename", "Execute commands from file."
+ )
+ self.pgspecial.register(
+ self.write_to_file,
+ "\\o",
+ "\\o [filename]",
+ "Send all query results to file.",
+ )
+ self.pgspecial.register(
+ self.info_connection, "\\conninfo", "\\conninfo", "Get connection details"
+ )
+ self.pgspecial.register(
+ self.change_table_format,
+ "\\T",
+ "\\T [format]",
+ "Change the table format used to output results",
+ )
+
+ self.pgspecial.register(
+ self.echo,
+ "\\echo",
+ "\\echo [string]",
+ "Echo a string to stdout",
+ )
+
+ self.pgspecial.register(
+ self.echo,
+ "\\qecho",
+ "\\qecho [string]",
+ "Echo a string to the query output channel.",
+ )
+
+ def echo(self, pattern, **_):
+ return [(None, None, None, pattern)]
+
+ def change_table_format(self, pattern, **_):
+ try:
+ if pattern not in TabularOutputFormatter().supported_formats:
+ raise ValueError()
+ self.table_format = pattern
+ yield (None, None, None, f"Changed table format to {pattern}")
+ except ValueError:
+ msg = f"Table format {pattern} not recognized. Allowed formats:"
+ for table_type in TabularOutputFormatter().supported_formats:
+ msg += f"\n\t{table_type}"
+ msg += "\nCurrently set to: %s" % self.table_format
+ yield (None, None, None, msg)
+
+ def info_connection(self, **_):
+ if self.pgexecute.host.startswith("/"):
+ host = 'socket "%s"' % self.pgexecute.host
+ else:
+ host = 'host "%s"' % self.pgexecute.host
+
+ yield (
+ None,
+ None,
+ None,
+ 'You are connected to database "%s" as user '
+ '"%s" on %s at port "%s".'
+ % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
+ )
+
+ def change_db(self, pattern, **_):
+ if pattern:
+ # Get all the parameters in pattern, handling double quotes if any.
+ infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
+ # Now removing quotes.
+ list(map(lambda s: s.strip('"'), infos))
+
+ infos.extend([None] * (4 - len(infos)))
+ db, user, host, port = infos
+ try:
+ self.pgexecute.connect(
+ database=db,
+ user=user,
+ host=host,
+ port=port,
+ **self.pgexecute.extra_args,
+ )
+ except OperationalError as e:
+ click.secho(str(e), err=True, fg="red")
+ click.echo("Previous connection kept")
+ else:
+ self.pgexecute.connect()
+
+ yield (
+ None,
+ None,
+ None,
+ 'You are now connected to database "%s" as '
+ 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
+ )
+
+ def execute_from_file(self, pattern, **_):
+ if not pattern:
+ message = "\\i: missing required argument"
+ return [(None, None, None, message, "", False, True)]
+ try:
+ with open(os.path.expanduser(pattern), encoding="utf-8") as f:
+ query = f.read()
+ except OSError as e:
+ return [(None, None, None, str(e), "", False, True)]
+
+ if self.destructive_warning:
+ if (
+ self.destructive_statements_require_transaction
+ and not self.pgexecute.valid_transaction()
+ and is_destructive(query, self.destructive_warning)
+ ):
+ message = "Destructive statements must be run within a transaction. Command execution stopped."
+ return [(None, None, None, message)]
+ destroy = confirm_destructive_query(
+ query, self.destructive_warning, self.dsn_alias
+ )
+ if destroy is False:
+ message = "Wise choice. Command execution stopped."
+ return [(None, None, None, message)]
+
+ on_error_resume = self.on_error == "RESUME"
+ return self.pgexecute.run(
+ query,
+ self.pgspecial,
+ on_error_resume=on_error_resume,
+ explain_mode=self.explain_mode,
+ )
+
+ def write_to_file(self, pattern, **_):
+ if not pattern:
+ self.output_file = None
+ message = "File output disabled"
+ return [(None, None, None, message, "", True, True)]
+ filename = os.path.abspath(os.path.expanduser(pattern))
+ if not os.path.isfile(filename):
+ try:
+ open(filename, "w").close()
+ except OSError as e:
+ self.output_file = None
+ message = str(e) + "\nFile output disabled"
+ return [(None, None, None, message, "", False, True)]
+ self.output_file = filename
+ message = 'Writing to file "%s"' % self.output_file
+ return [(None, None, None, message, "", True, True)]
+
+ def initialize_logging(self):
+ log_file = self.config["main"]["log_file"]
+ if log_file == "default":
+ log_file = config_location() + "log"
+ ensure_dir_exists(log_file)
+ log_level = self.config["main"]["log_level"]
+
+ # Disable logging if value is NONE by switching to a no-op handler.
+ # Set log level to a high value so it doesn't even waste cycles getting called.
+ if log_level.upper() == "NONE":
+ handler = logging.NullHandler()
+ else:
+ handler = logging.FileHandler(os.path.expanduser(log_file))
+
+ level_map = {
+ "CRITICAL": logging.CRITICAL,
+ "ERROR": logging.ERROR,
+ "WARNING": logging.WARNING,
+ "INFO": logging.INFO,
+ "DEBUG": logging.DEBUG,
+ "NONE": logging.CRITICAL,
+ }
+
+ log_level = level_map[log_level.upper()]
+
+ formatter = logging.Formatter(
+ "%(asctime)s (%(process)d/%(threadName)s) "
+ "%(name)s %(levelname)s - %(message)s"
+ )
+
+ handler.setFormatter(formatter)
+
+ root_logger = logging.getLogger("pgcli")
+ root_logger.addHandler(handler)
+ root_logger.setLevel(log_level)
+
+ root_logger.debug("Initializing pgcli logging.")
+ root_logger.debug("Log file %r.", log_file)
+
+ pgspecial_logger = logging.getLogger("pgspecial")
+ pgspecial_logger.addHandler(handler)
+ pgspecial_logger.setLevel(log_level)
+
+ def connect_dsn(self, dsn, **kwargs):
+ self.connect(dsn=dsn, **kwargs)
+
+ def connect_service(self, service, user):
+ service_config, file = parse_service_info(service)
+ if service_config is None:
+ click.secho(
+ f"service '{service}' was not found in {file}", err=True, fg="red"
+ )
+ exit(1)
+ self.connect(
+ database=service_config.get("dbname"),
+ host=service_config.get("host"),
+ user=user or service_config.get("user"),
+ port=service_config.get("port"),
+ passwd=service_config.get("password"),
+ )
+
+ def connect_uri(self, uri):
+ kwargs = conninfo_to_dict(uri)
+ remap = {"dbname": "database", "password": "passwd"}
+ kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
+ self.connect(**kwargs)
+
+ def connect(
+ self, database="", host="", user="", port="", passwd="", dsn="", **kwargs
+ ):
+ # Connect to the database.
+
+ if not user:
+ user = getuser()
+
+ if not database:
+ database = user
+
+ kwargs.setdefault("application_name", "pgcli")
+
+ # If password prompt is not forced but no password is provided, try
+ # getting it from environment variable.
+ if not self.force_passwd_prompt and not passwd:
+ passwd = os.environ.get("PGPASSWORD", "")
+
+ # Prompt for a password immediately if requested via the -W flag. This
+ # avoids wasting time trying to connect to the database and catching a
+ # no-password exception.
+ # If we successfully parsed a password from a URI, there's no need to
+ # prompt for it, even with the -W flag
+ if self.force_passwd_prompt and not passwd:
+ passwd = click.prompt(
+ "Password for %s" % user, hide_input=True, show_default=False, type=str
+ )
+
+ key = f"{user}@{host}"
+
+ if not passwd and auth.keyring:
+ passwd = auth.keyring_get_password(key)
+
+ def should_ask_for_password(exc):
+ # Prompt for a password after 1st attempt to connect
+ # fails. Don't prompt if the -w flag is supplied
+ if self.never_passwd_prompt:
+ return False
+ error_msg = exc.args[0]
+ if "no password supplied" in error_msg:
+ return True
+ if "password authentication failed" in error_msg:
+ return True
+ return False
+
+ if dsn:
+ parsed_dsn = conninfo_to_dict(dsn)
+ if "host" in parsed_dsn:
+ host = parsed_dsn["host"]
+ if "port" in parsed_dsn:
+ port = parsed_dsn["port"]
+
+ if self.ssh_tunnel_config and not self.ssh_tunnel_url:
+ for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
+ if re.search(db_host_regex, host):
+ self.ssh_tunnel_url = tunnel_url
+ break
+
+ if self.ssh_tunnel_url:
+ # We add the protocol as urlparse doesn't find it by itself
+ if "://" not in self.ssh_tunnel_url:
+ self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
+
+ tunnel_info = urlparse(self.ssh_tunnel_url)
+ params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (host, int(port or 5432)),
+ "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22),
+ "logger": self.logger,
+ }
+ if tunnel_info.username:
+ params["ssh_username"] = tunnel_info.username
+ if tunnel_info.password:
+ params["ssh_password"] = tunnel_info.password
+
+ # Hack: sshtunnel adds a console handler to the logger, so we revert handlers.
+ # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged.
+ logger_handlers = self.logger.handlers.copy()
+ try:
+ self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params)
+ self.ssh_tunnel.start()
+ except Exception as e:
+ self.logger.handlers = logger_handlers
+ self.logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+ self.logger.handlers = logger_handlers
+
+ atexit.register(self.ssh_tunnel.stop)
+ host = "127.0.0.1"
+ port = self.ssh_tunnel.local_bind_ports[0]
+
+ if dsn:
+ dsn = make_conninfo(dsn, host=host, port=port)
+
+ # Attempt to connect to the database.
+ # Note that passwd may be empty on the first attempt. If connection
+ # fails because of a missing or incorrect password, but we're allowed to
+ # prompt for a password (no -w flag), prompt for a passwd and try again.
+ try:
+ try:
+ pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
+ except (OperationalError, InterfaceError) as e:
+ if should_ask_for_password(e):
+ passwd = click.prompt(
+ "Password for %s" % user,
+ hide_input=True,
+ show_default=False,
+ type=str,
+ )
+ pgexecute = PGExecute(
+ database, user, passwd, host, port, dsn, **kwargs
+ )
+ else:
+ raise e
+ if passwd and auth.keyring:
+ auth.keyring_set_password(key, passwd)
+
+ except Exception as e: # Connecting to a database could fail.
+ self.logger.debug("Database connection failed: %r.", e)
+ self.logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+
+ self.pgexecute = pgexecute
+
+ def handle_editor_command(self, text):
+ r"""
+ Editor command is any query that is prefixed or suffixed
+ by a '\e'. The reason for a while loop is because a user
+ might edit a query multiple times.
+ For eg:
+ "select * from \e"<enter> to edit it in vim, then come
+ back to the prompt with the edited query "select * from
+ blah where q = 'abc'\e" to edit it again.
+ :param text: Document
+ :return: Document
+ """
+ editor_command = special.editor_command(text)
+ while editor_command:
+ if editor_command == "\\e":
+ filename = special.get_filename(text)
+ query = special.get_editor_query(text) or self.get_last_query()
+ else: # \ev or \ef
+ filename = None
+ spec = text.split()[1]
+ if editor_command == "\\ev":
+ query = self.pgexecute.view_definition(spec)
+ elif editor_command == "\\ef":
+ query = self.pgexecute.function_definition(spec)
+ sql, message = special.open_external_editor(filename, sql=query)
+ if message:
+ # Something went wrong. Raise an exception and bail.
+ raise RuntimeError(message)
+ while True:
+ try:
+ text = self.prompt_app.prompt(default=sql)
+ break
+ except KeyboardInterrupt:
+ sql = ""
+
+ editor_command = special.editor_command(text)
+ return text
+
+ def execute_command(self, text, handle_closed_connection=True):
+ logger = self.logger
+
+ query = MetaQuery(query=text, successful=False)
+
+ try:
+ if self.destructive_warning:
+ if (
+ self.destructive_statements_require_transaction
+ and not self.pgexecute.valid_transaction()
+ and is_destructive(text, self.destructive_warning)
+ ):
+ click.secho(
+ "Destructive statements must be run within a transaction."
+ )
+ raise KeyboardInterrupt
+ destroy = confirm_destructive_query(
+ text, self.destructive_warning, self.dsn_alias
+ )
+ if destroy is False:
+ click.secho("Wise choice!")
+ raise KeyboardInterrupt
+ elif destroy:
+ click.secho("Your call!")
+
+ output, query = self._evaluate_command(text)
+ except KeyboardInterrupt:
+ if self.destructive_warning_restarts_connection:
+ # Restart connection to the database
+ self.pgexecute.connect()
+ logger.debug("cancelled query and restarted connection, sql: %r", text)
+ click.secho(
+ "cancelled query and restarted connection", err=True, fg="red"
+ )
+ else:
+ logger.debug("cancelled query, sql: %r", text)
+ click.secho("cancelled query", err=True, fg="red")
+ except NotImplementedError:
+ click.secho("Not Yet Implemented.", fg="yellow")
+ except OperationalError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ if handle_closed_connection:
+ self._handle_server_closed_connection(text)
+ except (PgCliQuitError, EOFError):
+ raise
+ except Exception as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ else:
+ try:
+ if self.output_file and not text.startswith(
+ ("\\o ", "\\? ", "\\echo ")
+ ):
+ try:
+ with open(self.output_file, "a", encoding="utf-8") as f:
+ click.echo(text, file=f)
+ click.echo("\n".join(output), file=f)
+ click.echo("", file=f) # extra newline
+ except OSError as e:
+ click.secho(str(e), err=True, fg="red")
+ else:
+ if output:
+ self.echo_via_pager("\n".join(output))
+ except KeyboardInterrupt:
+ pass
+
+ if self.pgspecial.timing_enabled:
+ # Only add humanized time display if > 1 second
+ if query.total_time > 1:
+ print(
+ "Time: %0.03fs (%s), executed in: %0.03fs (%s)"
+ % (
+ query.total_time,
+ pendulum.Duration(seconds=query.total_time).in_words(),
+ query.execution_time,
+ pendulum.Duration(seconds=query.execution_time).in_words(),
+ )
+ )
+ else:
+ print("Time: %0.03fs" % query.total_time)
+
+ # Check if we need to update completions, in order of most
+ # to least drastic changes
+ if query.db_changed:
+ with self._completer_lock:
+ self.completer.reset_completions()
+ self.refresh_completions(persist_priorities="keywords")
+ elif query.meta_changed:
+ self.refresh_completions(persist_priorities="all")
+ elif query.path_changed:
+ logger.debug("Refreshing search path")
+ with self._completer_lock:
+ self.completer.set_search_path(self.pgexecute.search_path())
+ logger.debug("Search path: %r", self.completer.search_path)
+ return query
+
+ def _check_ongoing_transaction_and_allow_quitting(self):
+ """Return whether we can really quit, possibly by asking the
+ user to confirm so if there is an ongoing transaction.
+ """
+ if not self.pgexecute.valid_transaction():
+ return True
+ while 1:
+ try:
+ choice = click.prompt(
+ "A transaction is ongoing. Choose `c` to COMMIT, `r` to ROLLBACK, `a` to abort exit.",
+ default="a",
+ )
+ except click.Abort:
+ # Print newline if user aborts with `^C`, otherwise
+ # pgcli's prompt will be printed on the same line
+ # (just after the confirmation prompt).
+ click.echo(None, err=False)
+ choice = "a"
+ choice = choice.lower()
+ if choice == "a":
+ return False # do not quit
+ if choice == "c":
+ query = self.execute_command("commit")
+ return query.successful # quit only if query is successful
+ if choice == "r":
+ query = self.execute_command("rollback")
+ return query.successful # quit only if query is successful
+
+ def run_cli(self):
+ logger = self.logger
+
+ history_file = self.config["main"]["history_file"]
+ if history_file == "default":
+ history_file = config_location() + "history"
+ history = FileHistory(os.path.expanduser(history_file))
+ self.refresh_completions(history=history, persist_priorities="none")
+
+ self.prompt_app = self._build_cli(history)
+
+ if not self.less_chatty:
+ print("Server: PostgreSQL", self.pgexecute.server_version)
+ print("Version:", __version__)
+ print("Home: http://pgcli.com")
+
+ try:
+ while True:
+ try:
+ text = self.prompt_app.prompt()
+ except KeyboardInterrupt:
+ continue
+ except EOFError:
+ if not self._check_ongoing_transaction_and_allow_quitting():
+ continue
+ raise
+
+ try:
+ text = self.handle_editor_command(text)
+ except RuntimeError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ continue
+
+ try:
+ self.handle_watch_command(text)
+ except PgCliQuitError:
+ if not self._check_ongoing_transaction_and_allow_quitting():
+ continue
+ raise
+
+ self.now = dt.datetime.today()
+
+ # Allow PGCompleter to learn user's preferred keywords, etc.
+ with self._completer_lock:
+ self.completer.extend_query_history(text)
+
+ except (PgCliQuitError, EOFError):
+ if not self.less_chatty:
+ print("Goodbye!")
+
+ def handle_watch_command(self, text):
+ # Initialize default metaquery in case execution fails
+ self.watch_command, timing = special.get_watch_command(text)
+
+ # If we run \watch without a command, apply it to the last query run.
+ if self.watch_command is not None and not self.watch_command.strip():
+ try:
+ self.watch_command = self.query_history[-1].query
+ except IndexError:
+ click.secho(
+ "\\watch cannot be used with an empty query", err=True, fg="red"
+ )
+ self.watch_command = None
+
+ # If there's a command to \watch, run it in a loop.
+ if self.watch_command:
+ while self.watch_command:
+ try:
+ query = self.execute_command(self.watch_command)
+ click.echo(f"Waiting for {timing} seconds before repeating")
+ sleep(timing)
+ except KeyboardInterrupt:
+ self.watch_command = None
+
+ # Otherwise, execute it as a regular command.
+ else:
+ query = self.execute_command(text)
+
+ self.query_history.append(query)
+
+ def _build_cli(self, history):
+ key_bindings = pgcli_bindings(self)
+
+ def get_message():
+ if self.dsn_alias and self.prompt_dsn_format is not None:
+ prompt_format = self.prompt_dsn_format
+ else:
+ prompt_format = self.prompt_format
+
+ prompt = self.get_prompt(prompt_format)
+
+ if (
+ prompt_format == self.default_prompt
+ and len(prompt) > self.max_len_prompt
+ ):
+ prompt = self.get_prompt("\\d> ")
+
+ prompt = prompt.replace("\\x1b", "\x1b")
+ return ANSI(prompt)
+
+ def get_continuation(width, line_number, is_soft_wrap):
+ continuation = self.multiline_continuation_char * (width - 1) + " "
+ return [("class:continuation", continuation)]
+
+ get_toolbar_tokens = create_toolbar_tokens_func(self)
+
+ if self.wider_completion_menu:
+ complete_style = CompleteStyle.MULTI_COLUMN
+ else:
+ complete_style = CompleteStyle.COLUMN
+
+ with self._completer_lock:
+ prompt_app = PromptSession(
+ lexer=PygmentsLexer(PostgresLexer),
+ reserve_space_for_menu=self.min_num_menu_lines,
+ message=get_message,
+ prompt_continuation=get_continuation,
+ bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None,
+ complete_style=complete_style,
+ input_processors=[
+ # Highlight matching brackets while editing.
+ ConditionalProcessor(
+ processor=HighlightMatchingBracketProcessor(chars="[](){}"),
+ filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
+ ),
+ # Render \t as 4 spaces instead of "^I"
+ TabsProcessor(char1=" ", char2=" "),
+ ],
+ auto_suggest=AutoSuggestFromHistory(),
+ tempfile_suffix=".sql",
+ # N.b. pgcli's multi-line mode controls submit-on-Enter (which
+ # overrides the default behaviour of prompt_toolkit) and is
+ # distinct from prompt_toolkit's multiline mode here, which
+ # controls layout/display of the prompt/buffer
+ multiline=True,
+ history=history,
+ completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)),
+ complete_while_typing=True,
+ style=style_factory(self.syntax_style, self.cli_style),
+ include_default_pygments_style=False,
+ key_bindings=key_bindings,
+ enable_open_in_editor=True,
+ enable_system_prompt=True,
+ enable_suspend=True,
+ editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS,
+ search_ignore_case=True,
+ )
+
+ return prompt_app
+
+ def _should_limit_output(self, sql, cur):
+ """returns True if the output should be truncated, False otherwise."""
+ if self.explain_mode:
+ return False
+ if not is_select(sql):
+ return False
+
+ return (
+ not self._has_limit(sql)
+ and self.row_limit != 0
+ and cur
+ and cur.rowcount > self.row_limit
+ )
+
+ def _has_limit(self, sql):
+ if not sql:
+ return False
+ return "limit " in sql.lower()
+
+ def _limit_output(self, cur):
+ limit = min(self.row_limit, cur.rowcount)
+ new_cur = itertools.islice(cur, limit)
+ new_status = "SELECT " + str(limit)
+ click.secho("The result was limited to %s rows" % limit, fg="red")
+
+ return new_cur, new_status
+
+ def _evaluate_command(self, text):
+ """Used to run a command entered by the user during CLI operation
+ (Puts the E in REPL)
+
+ returns (results, MetaQuery)
+ """
+ logger = self.logger
+ logger.debug("sql: %r", text)
+
+ # set query to formatter in order to parse table name
+ self.formatter.query = text
+ all_success = True
+ meta_changed = False # CREATE, ALTER, DROP, etc
+ mutated = False # INSERT, DELETE, etc
+ db_changed = False
+ path_changed = False
+ output = []
+ total = 0
+ execution = 0
+
+ # Run the query.
+ start = time()
+ on_error_resume = self.on_error == "RESUME"
+ res = self.pgexecute.run(
+ text,
+ self.pgspecial,
+ exception_formatter,
+ on_error_resume,
+ explain_mode=self.explain_mode,
+ )
+
+ is_special = None
+
+ for title, cur, headers, status, sql, success, is_special in res:
+ logger.debug("headers: %r", headers)
+ logger.debug("rows: %r", cur)
+ logger.debug("status: %r", status)
+
+ if self._should_limit_output(sql, cur):
+ cur, status = self._limit_output(cur)
+
+ if self.pgspecial.auto_expand or self.auto_expand:
+ max_width = self.prompt_app.output.get_size().columns
+ else:
+ max_width = None
+
+ expanded = self.pgspecial.expanded_output or self.expanded_output
+ settings = OutputSettings(
+ table_format=self.table_format,
+ dcmlfmt=self.decimal_format,
+ floatfmt=self.float_format,
+ missingval=self.null_string,
+ expanded=expanded,
+ max_width=max_width,
+ case_function=(
+ self.completer.case
+ if self.settings["case_column_headers"]
+ else lambda x: x
+ ),
+ style_output=self.style_output,
+ max_field_width=self.max_field_width,
+ )
+ execution = time() - start
+ formatted = format_output(
+ title, cur, headers, status, settings, self.explain_mode
+ )
+
+ output.extend(formatted)
+ total = time() - start
+
+ # Keep track of whether any of the queries are mutating or changing
+ # the database
+ if success:
+ mutated = mutated or is_mutating(status)
+ db_changed = db_changed or has_change_db_cmd(sql)
+ meta_changed = meta_changed or has_meta_cmd(sql)
+ path_changed = path_changed or has_change_path_cmd(sql)
+ else:
+ all_success = False
+
+ meta_query = MetaQuery(
+ text,
+ all_success,
+ total,
+ execution,
+ meta_changed,
+ db_changed,
+ path_changed,
+ mutated,
+ is_special,
+ )
+
+ return output, meta_query
+
+ def _handle_server_closed_connection(self, text):
+ """Used during CLI execution."""
+ try:
+ click.secho("Reconnecting...", fg="green")
+ self.pgexecute.connect()
+ click.secho("Reconnected!", fg="green")
+ except OperationalError as e:
+ click.secho("Reconnect Failed", fg="red")
+ click.secho(str(e), err=True, fg="red")
+ else:
+ retry = self.auto_retry_closed_connection or confirm(
+ "Run the query from before reconnecting?"
+ )
+ if retry:
+ click.secho("Running query...", fg="green")
+ # Don't get stuck in a retry loop
+ self.execute_command(text, handle_closed_connection=False)
+
+ def refresh_completions(self, history=None, persist_priorities="all"):
+ """Refresh outdated completions
+
+ :param history: A prompt_toolkit.history.FileHistory object. Used to
+ load keyword and identifier preferences
+
+ :param persist_priorities: 'all' or 'keywords'
+ """
+
+ callback = functools.partial(
+ self._on_completions_refreshed, persist_priorities=persist_priorities
+ )
+ return self.completion_refresher.refresh(
+ self.pgexecute,
+ self.pgspecial,
+ callback,
+ history=history,
+ settings=self.settings,
+ )
+
+ def _on_completions_refreshed(self, new_completer, persist_priorities):
+ self._swap_completer_objects(new_completer, persist_priorities)
+
+ if self.prompt_app:
+ # After refreshing, redraw the CLI to clear the statusbar
+ # "Refreshing completions..." indicator
+ self.prompt_app.app.invalidate()
+
+ def _swap_completer_objects(self, new_completer, persist_priorities):
+ """Swap the completer object with the newly created completer.
+
+ persist_priorities is a string specifying how the old completer's
+ learned prioritizer should be transferred to the new completer.
+
+ 'none' - The new prioritizer is left in a new/clean state
+
+ 'all' - The new prioritizer is updated to exactly reflect
+ the old one
+
+ 'keywords' - The new prioritizer is updated with old keyword
+ priorities, but not any other.
+
+ """
+ with self._completer_lock:
+ old_completer = self.completer
+ self.completer = new_completer
+
+ if persist_priorities == "all":
+ # Just swap over the entire prioritizer
+ new_completer.prioritizer = old_completer.prioritizer
+ elif persist_priorities == "keywords":
+ # Swap over the entire prioritizer, but clear name priorities,
+ # leaving learned keyword priorities alone
+ new_completer.prioritizer = old_completer.prioritizer
+ new_completer.prioritizer.clear_names()
+ elif persist_priorities == "none":
+ # Leave the new prioritizer as is
+ pass
+ self.completer = new_completer
+
+ def get_completions(self, text, cursor_positition):
+ with self._completer_lock:
+ return self.completer.get_completions(
+ Document(text=text, cursor_position=cursor_positition), None
+ )
+
+ def get_prompt(self, string):
+ # should be before replacing \\d
+ string = string.replace("\\dsn_alias", self.dsn_alias or "")
+ string = string.replace("\\t", self.now.strftime("%x %X"))
+ string = string.replace("\\u", self.pgexecute.user or "(none)")
+ string = string.replace("\\H", self.pgexecute.host or "(none)")
+ string = string.replace("\\h", self.pgexecute.short_host or "(none)")
+ string = string.replace("\\d", self.pgexecute.dbname or "(none)")
+ string = string.replace(
+ "\\p",
+ str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
+ )
+ string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
+ string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
+ string = string.replace("\\n", "\n")
+ return string
+
+ def get_last_query(self):
+ """Get the last query executed or None."""
+ return self.query_history[-1][0] if self.query_history else None
+
+ def is_too_wide(self, line):
+ """Will this line be too wide to fit into terminal?"""
+ if not self.prompt_app:
+ return False
+ return (
+ len(COLOR_CODE_REGEX.sub("", line))
+ > self.prompt_app.output.get_size().columns
+ )
+
+ def is_too_tall(self, lines):
+ """Are there too many lines to fit into terminal?"""
+ if not self.prompt_app:
+ return False
+ return len(lines) >= (self.prompt_app.output.get_size().rows - 4)
+
+ def echo_via_pager(self, text, color=None):
+ if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
+ click.echo(text, color=color)
+ elif (
+ self.pgspecial.pager_config == PAGER_LONG_OUTPUT
+ and self.table_format != "csv"
+ ):
+ lines = text.split("\n")
+
+ # The last 4 lines are reserved for the pgcli menu and padding
+ if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines):
+ click.echo_via_pager(text, color=color)
+ else:
+ click.echo(text, color=color)
+ else:
+ click.echo_via_pager(text, color)
+
+
+@click.command()
+# Default host is '' so psycopg can default to either localhost or unix socket
+@click.option(
+ "-h",
+ "--host",
+ default="",
+ envvar="PGHOST",
+ help="Host address of the postgres database.",
+)
+@click.option(
+ "-p",
+ "--port",
+ default=5432,
+ help="Port number at which the " "postgres instance is listening.",
+ envvar="PGPORT",
+ type=click.INT,
+)
+@click.option(
+ "-U",
+ "--username",
+ "username_opt",
+ help="Username to connect to the postgres database.",
+)
+@click.option(
+ "-u", "--user", "username_opt", help="Username to connect to the postgres database."
+)
+@click.option(
+ "-W",
+ "--password",
+ "prompt_passwd",
+ is_flag=True,
+ default=False,
+ help="Force password prompt.",
+)
+@click.option(
+ "-w",
+ "--no-password",
+ "never_prompt",
+ is_flag=True,
+ default=False,
+ help="Never prompt for password.",
+)
+@click.option(
+ "--single-connection",
+ "single_connection",
+ is_flag=True,
+ default=False,
+ help="Do not use a separate connection for completions.",
+)
+@click.option("-v", "--version", is_flag=True, help="Version of pgcli.")
+@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.")
+@click.option(
+ "--pgclirc",
+ default=config_location() + "config",
+ envvar="PGCLIRC",
+ help="Location of pgclirc file.",
+ type=click.Path(dir_okay=False),
+)
+@click.option(
+ "-D",
+ "--dsn",
+ default="",
+ envvar="DSN",
+ help="Use DSN configured into the [alias_dsn] section of pgclirc file.",
+)
+@click.option(
+ "--list-dsn",
+ "list_dsn",
+ is_flag=True,
+ help="list of DSN configured into the [alias_dsn] section of pgclirc file.",
+)
+@click.option(
+ "--row-limit",
+ default=None,
+ envvar="PGROWLIMIT",
+ type=click.INT,
+ help="Set threshold for row limit prompt. Use 0 to disable prompt.",
+)
+@click.option(
+ "--less-chatty",
+ "less_chatty",
+ is_flag=True,
+ default=False,
+ help="Skip intro on startup and goodbye on exit.",
+)
+@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").')
+@click.option(
+ "--prompt-dsn",
+ help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").',
+)
+@click.option(
+ "-l",
+ "--list",
+ "list_databases",
+ is_flag=True,
+ help="list available databases, then exit.",
+)
+@click.option(
+ "--auto-vertical-output",
+ is_flag=True,
+ help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
+)
+@click.option(
+ "--warn",
+ default=None,
+ help="Warn before running a destructive query.",
+)
+@click.option(
+ "--ssh-tunnel",
+ default=None,
+ help="Open an SSH tunnel to the given address and connect to the database from it.",
+)
+@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
+@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
+def cli(
+ dbname,
+ username_opt,
+ host,
+ port,
+ prompt_passwd,
+ never_prompt,
+ single_connection,
+ dbname_opt,
+ username,
+ version,
+ pgclirc,
+ dsn,
+ row_limit,
+ less_chatty,
+ prompt,
+ prompt_dsn,
+ list_databases,
+ auto_vertical_output,
+ list_dsn,
+ warn,
+ ssh_tunnel: str,
+):
+ if version:
+ print("Version:", __version__)
+ sys.exit(0)
+
+ config_dir = os.path.dirname(config_location())
+ if not os.path.exists(config_dir):
+ os.makedirs(config_dir)
+
+ # Migrate the config file from old location.
+ config_full_path = config_location() + "config"
+ if os.path.exists(os.path.expanduser("~/.pgclirc")):
+ if not os.path.exists(config_full_path):
+ shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path)
+ print("Config file (~/.pgclirc) moved to new location", config_full_path)
+ else:
+ print("Config file is now located at", config_full_path)
+ print(
+ "Please move the existing config file ~/.pgclirc to",
+ config_full_path,
+ )
+ if list_dsn:
+ try:
+ cfg = load_config(pgclirc, config_full_path)
+ for alias in cfg["alias_dsn"]:
+ click.secho(alias + " : " + cfg["alias_dsn"][alias])
+ sys.exit(0)
+ except Exception as err:
+ click.secho(
+ "Invalid DSNs found in the config file. "
+ 'Please check the "[alias_dsn]" section in pgclirc.',
+ err=True,
+ fg="red",
+ )
+ exit(1)
+
+ if ssh_tunnel and not SSH_TUNNEL_SUPPORT:
+ click.secho(
+ 'Cannot open SSH tunnel, "sshtunnel" package was not found. '
+ "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.",
+ err=True,
+ fg="red",
+ )
+ exit(1)
+
+ pgcli = PGCli(
+ prompt_passwd,
+ never_prompt,
+ pgclirc_file=pgclirc,
+ row_limit=row_limit,
+ single_connection=single_connection,
+ less_chatty=less_chatty,
+ prompt=prompt,
+ prompt_dsn=prompt_dsn,
+ auto_vertical_output=auto_vertical_output,
+ warn=warn,
+ ssh_tunnel_url=ssh_tunnel,
+ )
+
+ # Choose which ever one has a valid value.
+ if dbname_opt and dbname:
+ # work as psql: when database is given as option and argument use the argument as user
+ username = dbname
+ database = dbname_opt or dbname or ""
+ user = username_opt or username
+ service = None
+ if database.startswith("service="):
+ service = database[8:]
+ elif os.getenv("PGSERVICE") is not None:
+ service = os.getenv("PGSERVICE")
+ # because option --list or -l are not supposed to have a db name
+ if list_databases:
+ database = "postgres"
+
+ if dsn != "":
+ try:
+ cfg = load_config(pgclirc, config_full_path)
+ dsn_config = cfg["alias_dsn"][dsn]
+ except KeyError:
+ click.secho(
+ f"Could not find a DSN with alias {dsn}. "
+ 'Please check the "[alias_dsn]" section in pgclirc.',
+ err=True,
+ fg="red",
+ )
+ exit(1)
+ except Exception:
+ click.secho(
+ "Invalid DSNs found in the config file. "
+ 'Please check the "[alias_dsn]" section in pgclirc.',
+ err=True,
+ fg="red",
+ )
+ exit(1)
+ pgcli.connect_uri(dsn_config)
+ pgcli.dsn_alias = dsn
+ elif "://" in database:
+ pgcli.connect_uri(database)
+ elif "=" in database and service is None:
+ pgcli.connect_dsn(database, user=user)
+ elif service is not None:
+ pgcli.connect_service(service, user)
+ else:
+ pgcli.connect(database, host, user, port)
+
+ if list_databases:
+ cur, headers, status = pgcli.pgexecute.full_databases()
+
+ title = "List of databases"
+ settings = OutputSettings(table_format="ascii", missingval="<null>")
+ formatted = format_output(title, cur, headers, status, settings)
+ pgcli.echo_via_pager("\n".join(formatted))
+
+ sys.exit(0)
+
+ pgcli.logger.debug(
+ "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r",
+ database,
+ user,
+ host,
+ port,
+ )
+
+ if setproctitle:
+ obfuscate_process_password()
+
+ pgcli.run_cli()
+
+
+def obfuscate_process_password():
+ process_title = setproctitle.getproctitle()
+ if "://" in process_title:
+ process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title)
+ elif "=" in process_title:
+ process_title = re.sub(
+ r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title
+ )
+
+ setproctitle.setproctitle(process_title)
+
+
+def has_meta_cmd(query):
+ """Determines if the completion needs a refresh by checking if the sql
+ statement is an alter, create, drop, commit or rollback."""
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ("alter", "create", "drop", "commit", "rollback"):
+ return True
+ except Exception:
+ return False
+
+ return False
+
+
+def has_change_db_cmd(query):
+ """Determines if the statement is a database switch such as 'use' or '\\c'"""
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ("use", "\\c", "\\connect"):
+ return True
+ except Exception:
+ return False
+
+ return False
+
+
+def has_change_path_cmd(sql):
+ """Determines if the search_path should be refreshed by checking if the
+ sql has 'set search_path'."""
+ return "set search_path" in sql.lower()
+
+
+def is_mutating(status):
+ """Determines if the statement is mutating based on the status."""
+ if not status:
+ return False
+
+ mutating = {"insert", "update", "delete"}
+ return status.split(None, 1)[0].lower() in mutating
+
+
+def is_select(status):
+ """Returns true if the first word in status is 'select'."""
+ if not status:
+ return False
+ return status.split(None, 1)[0].lower() == "select"
+
+
+def exception_formatter(e):
+ return click.style(str(e), fg="red")
+
+
+def format_output(title, cur, headers, status, settings, explain_mode=False):
+ output = []
+ expanded = settings.expanded or settings.table_format == "vertical"
+ table_format = "vertical" if settings.expanded else settings.table_format
+ max_width = settings.max_width
+ case_function = settings.case_function
+ if explain_mode:
+ formatter = ExplainOutputFormatter(max_width or 100)
+ else:
+ formatter = TabularOutputFormatter(format_name=table_format)
+
+ def format_array(val):
+ if val is None:
+ return settings.missingval
+ if not isinstance(val, list):
+ return val
+ return "{" + ",".join(str(format_array(e)) for e in val) + "}"
+
+ def format_arrays(data, headers, **_):
+ data = list(data)
+ for row in data:
+ row[:] = [
+ format_array(val) if isinstance(val, list) else val for val in row
+ ]
+
+ return data, headers
+
+ def format_status(cur, status):
+ # redshift does not return rowcount as part of status.
+ # See https://github.com/dbcli/pgcli/issues/1320
+ if cur and hasattr(cur, "rowcount") and cur.rowcount is not None:
+ if status and not status.endswith(str(cur.rowcount)):
+ status += " %s" % cur.rowcount
+ return status
+
+ output_kwargs = {
+ "sep_title": "RECORD {n}",
+ "sep_character": "-",
+ "sep_length": (1, 25),
+ "missing_value": settings.missingval,
+ "integer_format": settings.dcmlfmt,
+ "float_format": settings.floatfmt,
+ "preprocessors": (format_numbers, format_arrays),
+ "disable_numparse": True,
+ "preserve_whitespace": True,
+ "style": settings.style_output,
+ "max_field_width": settings.max_field_width,
+ }
+ if not settings.floatfmt:
+ output_kwargs["preprocessors"] = (align_decimals,)
+
+ if table_format == "csv":
+ # The default CSV dialect is "excel" which is not handling newline values correctly
+ # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n'
+ # as the line terminator
+ # https://github.com/dbcli/pgcli/issues/1102
+ dialect = "excel" if platform.system() == "Windows" else "unix"
+ output_kwargs["dialect"] = dialect
+
+ if title: # Only print the title if it's not None.
+ output.append(title)
+
+ if cur:
+ headers = [case_function(x) for x in headers]
+ if max_width is not None:
+ cur = list(cur)
+ column_types = None
+ if hasattr(cur, "description"):
+ column_types = []
+ for d in cur.description:
+ col_type = cur.adapters.types.get(d.type_code)
+ type_name = col_type.name if col_type else None
+ if type_name in ("numeric", "float4", "float8"):
+ column_types.append(float)
+ if type_name in ("int2", "int4", "int8"):
+ column_types.append(int)
+ else:
+ column_types.append(str)
+
+ formatted = formatter.format_output(cur, headers, **output_kwargs)
+ if isinstance(formatted, str):
+ formatted = iter(formatted.splitlines())
+ first_line = next(formatted)
+ formatted = itertools.chain([first_line], formatted)
+ if (
+ not explain_mode
+ and not expanded
+ and max_width
+ and len(strip_ansi(first_line)) > max_width
+ and headers
+ ):
+ formatted = formatter.format_output(
+ cur,
+ headers,
+ format_name="vertical",
+ column_types=column_types,
+ **output_kwargs,
+ )
+ if isinstance(formatted, str):
+ formatted = iter(formatted.splitlines())
+
+ output = itertools.chain(output, formatted)
+
+ # Only print the status if it's not None
+ if status:
+ output = itertools.chain(output, [format_status(cur, status)])
+
+ return output
+
+
+def parse_service_info(service):
+ service = service or os.getenv("PGSERVICE")
+ service_file = os.getenv("PGSERVICEFILE")
+ if not service_file:
+ # try ~/.pg_service.conf (if that exists)
+ if platform.system() == "Windows":
+ service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf"
+ elif os.getenv("PGSYSCONFDIR"):
+ service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
+ else:
+ service_file = os.path.expanduser("~/.pg_service.conf")
+ if not service or not os.path.exists(service_file):
+ # nothing to do
+ return None, service_file
+ with open(service_file, newline="") as f:
+ skipped_lines = skip_initial_comment(f)
+ try:
+ service_file_config = ConfigObj(f)
+ except ParseError as err:
+ err.line_number += skipped_lines
+ raise err
+ if service not in service_file_config:
+ return None, service_file
+ service_conf = service_file_config.get(service)
+ return service_conf, service_file
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/pgcli/packages/__init__.py b/pgcli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pgcli/packages/__init__.py
diff --git a/pgcli/packages/formatter/__init__.py b/pgcli/packages/formatter/__init__.py
new file mode 100644
index 0000000..9bad579
--- /dev/null
+++ b/pgcli/packages/formatter/__init__.py
@@ -0,0 +1 @@
+# coding=utf-8
diff --git a/pgcli/packages/formatter/sqlformatter.py b/pgcli/packages/formatter/sqlformatter.py
new file mode 100644
index 0000000..5224eff
--- /dev/null
+++ b/pgcli/packages/formatter/sqlformatter.py
@@ -0,0 +1,74 @@
+# coding=utf-8
+
+from pgcli.packages.parseutils.tables import extract_tables
+
+
+supported_formats = (
+ "sql-insert",
+ "sql-update",
+ "sql-update-1",
+ "sql-update-2",
+)
+
+preprocessors = ()
+
+
+def escape_for_sql_statement(value):
+ if value is None:
+ return "NULL"
+
+ if isinstance(value, bytes):
+ return f"X'{value.hex()}'"
+
+ return "'{}'".format(value)
+
+
+def adapter(data, headers, table_format=None, **kwargs):
+ tables = extract_tables(formatter.query)
+ if len(tables) > 0:
+ table = tables[0]
+ if table[0]:
+ table_name = "{}.{}".format(*table[:2])
+ else:
+ table_name = table[1]
+ else:
+ table_name = "DUAL"
+ if table_format == "sql-insert":
+ h = '", "'.join(headers)
+ yield 'INSERT INTO "{}" ("{}") VALUES'.format(table_name, h)
+ prefix = " "
+ for d in data:
+ values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d))
+ yield "{}({})".format(prefix, values)
+ if prefix == " ":
+ prefix = ", "
+ yield ";"
+ if table_format.startswith("sql-update"):
+ s = table_format.split("-")
+ keys = 1
+ if len(s) > 2:
+ keys = int(s[-1])
+ for d in data:
+ yield 'UPDATE "{}" SET'.format(table_name)
+ prefix = " "
+ for i, v in enumerate(d[keys:], keys):
+ yield '{}"{}" = {}'.format(
+ prefix, headers[i], escape_for_sql_statement(v)
+ )
+ if prefix == " ":
+ prefix = ", "
+ f = '"{}" = {}'
+ where = (
+ f.format(headers[i], escape_for_sql_statement(d[i]))
+ for i in range(keys)
+ )
+ yield "WHERE {};".format(" AND ".join(where))
+
+
+def register_new_formatter(TabularOutputFormatter):
+ global formatter
+ formatter = TabularOutputFormatter
+ for sql_format in supported_formats:
+ TabularOutputFormatter.register_new_formatter(
+ sql_format, adapter, preprocessors, {"table_format": sql_format}
+ )
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
new file mode 100644
index 0000000..023e13b
--- /dev/null
+++ b/pgcli/packages/parseutils/__init__.py
@@ -0,0 +1,58 @@
+import sqlparse
+
+
+BASE_KEYWORDS = [
+ "drop",
+ "shutdown",
+ "delete",
+ "truncate",
+ "alter",
+ "unconditional_update",
+]
+ALL_KEYWORDS = BASE_KEYWORDS + ["update"]
+
+
+def query_starts_with(formatted_sql, prefixes):
+ """Check if the query starts with any item from *prefixes*."""
+ prefixes = [prefix.lower() for prefix in prefixes]
+ return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def query_is_unconditional_update(formatted_sql):
+ """Check if the query starts with UPDATE and contains no WHERE."""
+ tokens = formatted_sql.split()
+ return bool(tokens) and tokens[0] == "update" and "where" not in tokens
+
+
+def is_destructive(queries, keywords):
+ """Returns if any of the queries in *queries* is destructive."""
+ for query in sqlparse.split(queries):
+ if query:
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
+ if "unconditional_update" in keywords and query_is_unconditional_update(
+ formatted_sql
+ ):
+ return True
+ if query_starts_with(formatted_sql, keywords):
+ return True
+ return False
+
+
+def parse_destructive_warning(warning_level):
+ """Converts a deprecated destructive warning option to a list of command keywords."""
+ if not warning_level:
+ return []
+
+ if not isinstance(warning_level, list):
+ if "," in warning_level:
+ return warning_level.split(",")
+ warning_level = [warning_level]
+
+ return {
+ "true": ALL_KEYWORDS,
+ "false": [],
+ "all": ALL_KEYWORDS,
+ "moderate": BASE_KEYWORDS,
+ "off": [],
+ "": [],
+ }.get(warning_level[0], warning_level)
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py
new file mode 100644
index 0000000..e1f9088
--- /dev/null
+++ b/pgcli/packages/parseutils/ctes.py
@@ -0,0 +1,141 @@
+from sqlparse import parse
+from sqlparse.tokens import Keyword, CTE, DML
+from sqlparse.sql import Identifier, IdentifierList, Parenthesis
+from collections import namedtuple
+from .meta import TableMetadata, ColumnMetadata
+
+
+# TableExpression is a namedtuple representing a CTE, used internally
+# name: cte alias assigned in the query
+# columns: list of column names
+# start: index into the original string of the left parens starting the CTE
+# stop: index into the original string of the right parens ending the CTE
+TableExpression = namedtuple("TableExpression", "name columns start stop")
+
+
+def isolate_query_ctes(full_text, text_before_cursor):
+ """Simplify a query by converting CTEs into table metadata objects"""
+
+ if not full_text or not full_text.strip():
+ return full_text, text_before_cursor, tuple()
+
+ ctes, remainder = extract_ctes(full_text)
+ if not ctes:
+ return full_text, text_before_cursor, ()
+
+ current_position = len(text_before_cursor)
+ meta = []
+
+ for cte in ctes:
+ if cte.start < current_position < cte.stop:
+ # Currently editing a cte - treat its body as the current full_text
+ text_before_cursor = full_text[cte.start : current_position]
+ full_text = full_text[cte.start : cte.stop]
+ return full_text, text_before_cursor, meta
+
+ # Append this cte to the list of available table metadata
+ cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
+ meta.append(TableMetadata(cte.name, cols))
+
+ # Editing past the last cte (ie the main body of the query)
+ full_text = full_text[ctes[-1].stop :]
+ text_before_cursor = text_before_cursor[ctes[-1].stop : current_position]
+
+ return full_text, text_before_cursor, tuple(meta)
+
+
+def extract_ctes(sql):
+ """Extract constant table expresseions from a query
+
+ Returns tuple (ctes, remainder_sql)
+
+ ctes is a list of TableExpression namedtuples
+ remainder_sql is the text from the original query after the CTEs have
+ been stripped.
+ """
+
+ p = parse(sql)[0]
+
+ # Make sure the first meaningful token is "WITH" which is necessary to
+ # define CTEs
+ idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
+ if not (tok and tok.ttype == CTE):
+ return [], sql
+
+ # Get the next (meaningful) token, which should be the first CTE
+ idx, tok = p.token_next(idx)
+ if not tok:
+ return ([], "")
+ start_pos = token_start_pos(p.tokens, idx)
+ ctes = []
+
+ if isinstance(tok, IdentifierList):
+ # Multiple ctes
+ for t in tok.get_identifiers():
+ cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
+ cte = get_cte_from_token(t, start_pos + cte_start_offset)
+ if not cte:
+ continue
+ ctes.append(cte)
+ elif isinstance(tok, Identifier):
+ # A single CTE
+ cte = get_cte_from_token(tok, start_pos)
+ if cte:
+ ctes.append(cte)
+
+ idx = p.token_index(tok) + 1
+
+ # Collapse everything after the ctes into a remainder query
+ remainder = "".join(str(tok) for tok in p.tokens[idx:])
+
+ return ctes, remainder
+
+
+def get_cte_from_token(tok, pos0):
+ cte_name = tok.get_real_name()
+ if not cte_name:
+ return None
+
+ # Find the start position of the opening parens enclosing the cte body
+ idx, parens = tok.token_next_by(Parenthesis)
+ if not parens:
+ return None
+
+ start_pos = pos0 + token_start_pos(tok.tokens, idx)
+ cte_len = len(str(parens)) # includes parens
+ stop_pos = start_pos + cte_len
+
+ column_names = extract_column_names(parens)
+
+ return TableExpression(cte_name, column_names, start_pos, stop_pos)
+
+
+def extract_column_names(parsed):
+ # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
+ idx, tok = parsed.token_next_by(t=DML)
+ tok_val = tok and tok.value.lower()
+
+ if tok_val in ("insert", "update", "delete"):
+ # Jump ahead to the RETURNING clause where the list of column names is
+ idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
+ elif not tok_val == "select":
+ # Must be invalid CTE
+ return ()
+
+ # The next token should be either a column name, or a list of column names
+ idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
+ return tuple(t.get_name() for t in _identifiers(tok))
+
+
+def token_start_pos(tokens, idx):
+ return sum(len(str(t)) for t in tokens[:idx])
+
+
+def _identifiers(tok):
+ if isinstance(tok, IdentifierList):
+ for t in tok.get_identifiers():
+ # NB: IdentifierList.get_identifiers() can return non-identifiers!
+ if isinstance(t, Identifier):
+ yield t
+ elif isinstance(tok, Identifier):
+ yield tok
diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py
new file mode 100644
index 0000000..333cab5
--- /dev/null
+++ b/pgcli/packages/parseutils/meta.py
@@ -0,0 +1,170 @@
+from collections import namedtuple
+
+_ColumnMetadata = namedtuple(
+ "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
+)
+
+
+def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
+ return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default)
+
+
+ForeignKey = namedtuple(
+ "ForeignKey",
+ [
+ "parentschema",
+ "parenttable",
+ "parentcolumn",
+ "childschema",
+ "childtable",
+ "childcolumn",
+ ],
+)
+TableMetadata = namedtuple("TableMetadata", "name columns")
+
+
+def parse_defaults(defaults_string):
+ """Yields default values for a function, given the string provided by
+ pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
+ if not defaults_string:
+ return
+ current = ""
+ in_quote = None
+ for char in defaults_string:
+ if current == "" and char == " ":
+ # Skip space after comma separating default expressions
+ continue
+ if char == '"' or char == "'":
+ if in_quote and char == in_quote:
+ # End quote
+ in_quote = None
+ elif not in_quote:
+ # Begin quote
+ in_quote = char
+ elif char == "," and not in_quote:
+ # End of expression
+ yield current
+ current = ""
+ continue
+ current += char
+ yield current
+
+
+class FunctionMetadata:
+ def __init__(
+ self,
+ schema_name,
+ func_name,
+ arg_names,
+ arg_types,
+ arg_modes,
+ return_type,
+ is_aggregate,
+ is_window,
+ is_set_returning,
+ is_extension,
+ arg_defaults,
+ ):
+ """Class for describing a postgresql function"""
+
+ self.schema_name = schema_name
+ self.func_name = func_name
+
+ self.arg_modes = tuple(arg_modes) if arg_modes else None
+ self.arg_names = tuple(arg_names) if arg_names else None
+
+ # Be flexible in not requiring arg_types -- use None as a placeholder
+ # for each arg. (Used for compatibility with old versions of postgresql
+ # where such info is hard to get.
+ if arg_types:
+ self.arg_types = tuple(arg_types)
+ elif arg_modes:
+ self.arg_types = tuple([None] * len(arg_modes))
+ elif arg_names:
+ self.arg_types = tuple([None] * len(arg_names))
+ else:
+ self.arg_types = None
+
+ self.arg_defaults = tuple(parse_defaults(arg_defaults))
+
+ self.return_type = return_type.strip()
+ self.is_aggregate = is_aggregate
+ self.is_window = is_window
+ self.is_set_returning = is_set_returning
+ self.is_extension = bool(is_extension)
+ self.is_public = self.schema_name and self.schema_name == "public"
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def _signature(self):
+ return (
+ self.schema_name,
+ self.func_name,
+ self.arg_names,
+ self.arg_types,
+ self.arg_modes,
+ self.return_type,
+ self.is_aggregate,
+ self.is_window,
+ self.is_set_returning,
+ self.is_extension,
+ self.arg_defaults,
+ )
+
+ def __hash__(self):
+ return hash(self._signature())
+
+ def __repr__(self):
+ return (
+ "%s(schema_name=%r, func_name=%r, arg_names=%r, "
+ "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
+ "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
+ ) % ((self.__class__.__name__,) + self._signature())
+
+ def has_variadic(self):
+ return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes)
+
+ def args(self):
+ """Returns a list of input-parameter ColumnMetadata namedtuples."""
+ if not self.arg_names:
+ return []
+ modes = self.arg_modes or ["i"] * len(self.arg_names)
+ args = [
+ (name, typ)
+ for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
+ if mode in ("i", "b", "v") # IN, INOUT, VARIADIC
+ ]
+
+ def arg(name, typ, num):
+ num_args = len(args)
+ num_defaults = len(self.arg_defaults)
+ has_default = num + num_defaults >= num_args
+ default = (
+ self.arg_defaults[num - num_args + num_defaults]
+ if has_default
+ else None
+ )
+ return ColumnMetadata(name, typ, [], default, has_default)
+
+ return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]
+
+ def fields(self):
+ """Returns a list of output-field ColumnMetadata namedtuples"""
+
+ if self.return_type.lower() == "void":
+ return []
+ elif not self.arg_modes:
+ # For functions without output parameters, the function name
+ # is used as the name of the output column.
+ # E.g. 'SELECT unnest FROM unnest(...);'
+ return [ColumnMetadata(self.func_name, self.return_type, [])]
+
+ return [
+ ColumnMetadata(name, typ, [])
+ for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes)
+ if mode in ("o", "b", "t")
+ ] # OUT, INOUT, TABLE
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
new file mode 100644
index 0000000..9098115
--- /dev/null
+++ b/pgcli/packages/parseutils/tables.py
@@ -0,0 +1,165 @@
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
+TableReference = namedtuple(
+ "TableReference", ["schema", "name", "alias", "is_function"]
+)
+TableReference.ref = property(
+ lambda self: self.alias
+ or (
+ self.name
+ if self.name.islower() or self.name[0] == '"'
+ else '"' + self.name + '"'
+ )
+)
+
+
+# This code is borrowed from sqlparse example script.
+# <url>
+def is_subselect(parsed):
+ if not parsed.is_group:
+ return False
+ for item in parsed.tokens:
+ if item.ttype is DML and item.value.upper() in (
+ "SELECT",
+ "INSERT",
+ "UPDATE",
+ "CREATE",
+ "DELETE",
+ ):
+ return True
+ return False
+
+
+def _identifier_is_function(identifier):
+ return any(isinstance(t, Function) for t in identifier.tokens)
+
+
+def extract_from_part(parsed, stop_at_punctuation=True):
+ tbl_prefix_seen = False
+ for item in parsed.tokens:
+ if tbl_prefix_seen:
+ if is_subselect(item):
+ yield from extract_from_part(item, stop_at_punctuation)
+ elif stop_at_punctuation and item.ttype is Punctuation:
+ return
+ # An incomplete nested select won't be recognized correctly as a
+ # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
+ # the second FROM to trigger this elif condition resulting in a
+ # `return`. So we need to ignore the keyword if the keyword
+ # FROM.
+ # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+ # condition. So we need to ignore the keyword JOIN and its variants
+ # INNER JOIN, FULL OUTER JOIN, etc.
+ elif (
+ item.ttype is Keyword
+ and (not item.value.upper() == "FROM")
+ and (not item.value.upper().endswith("JOIN"))
+ ):
+ tbl_prefix_seen = False
+ else:
+ yield item
+ elif item.ttype is Keyword or item.ttype is Keyword.DML:
+ item_val = item.value.upper()
+ if item_val in (
+ "COPY",
+ "FROM",
+ "INTO",
+ "UPDATE",
+ "TABLE",
+ ) or item_val.endswith("JOIN"):
+ tbl_prefix_seen = True
+ # 'SELECT a, FROM abc' will detect FROM as part of the column list.
+ # So this check here is necessary.
+ elif isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
+ tbl_prefix_seen = True
+ break
+
+
+def extract_table_identifiers(token_stream, allow_functions=True):
+ """yields tuples of TableReference namedtuples"""
+
+ # We need to do some massaging of the names because postgres is case-
+ # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
+ def parse_identifier(item):
+ name = item.get_real_name()
+ schema_name = item.get_parent_name()
+ alias = item.get_alias()
+ if not name:
+ schema_name = None
+ name = item.get_name()
+ alias = alias or name
+ schema_quoted = schema_name and item.value[0] == '"'
+ if schema_name and not schema_quoted:
+ schema_name = schema_name.lower()
+ quote_count = item.value.count('"')
+ name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
+ alias_quoted = alias and item.value[-1] == '"'
+ if alias_quoted or name_quoted and not alias and name.islower():
+ alias = '"' + (alias or name) + '"'
+ if name and not name_quoted and not name.islower():
+ if not alias:
+ alias = name
+ name = name.lower()
+ return schema_name, name, alias
+
+ try:
+ for item in token_stream:
+ if isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ # Sometimes Keywords (such as FROM ) are classified as
+ # identifiers which don't have the get_real_name() method.
+ try:
+ schema_name = identifier.get_parent_name()
+ real_name = identifier.get_real_name()
+ is_function = allow_functions and _identifier_is_function(
+ identifier
+ )
+ except AttributeError:
+ continue
+ if real_name:
+ yield TableReference(
+ schema_name, real_name, identifier.get_alias(), is_function
+ )
+ elif isinstance(item, Identifier):
+ schema_name, real_name, alias = parse_identifier(item)
+ is_function = allow_functions and _identifier_is_function(item)
+
+ yield TableReference(schema_name, real_name, alias, is_function)
+ elif isinstance(item, Function):
+ schema_name, real_name, alias = parse_identifier(item)
+ yield TableReference(None, real_name, alias, allow_functions)
+ except StopIteration:
+ return
+
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+ """Extract the table names from an SQL statement.
+
+ Returns a list of TableReference namedtuples
+
+ """
+ parsed = sqlparse.parse(sql)
+ if not parsed:
+ return ()
+
+ # INSERT statements must stop looking for tables at the sign of first
+ # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+ # abc is the table name, but if we don't stop at the first lparen, then
+ # we'll identify abc, col1 and col2 as table names.
+ insert_stmt = parsed[0].token_first().value.lower() == "insert"
+ stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+
+ # Kludge: sqlparse mistakenly identifies insert statements as
+ # function calls due to the parenthesized column list, e.g. interprets
+ # "insert into foo (bar, baz)" as a function call to foo with arguments
+ # (bar, baz). So don't allow any identifiers in insert statements
+ # to have is_function=True
+ identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt)
+ # In the case 'sche.<cursor>', we get an empty TableReference; remove that
+ return tuple(i for i in identifiers if i.name)
diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py
new file mode 100644
index 0000000..034c96e
--- /dev/null
+++ b/pgcli/packages/parseutils/utils.py
@@ -0,0 +1,140 @@
+import re
+import sqlparse
+from sqlparse.sql import Identifier
+from sqlparse.tokens import Token, Error
+
+cleanup_regex = {
+ # This matches only alphanumerics and underscores.
+ "alphanum_underscore": re.compile(r"(\w+)$"),
+ # This matches everything except spaces, parens, colon, and comma
+ "many_punctuations": re.compile(r"([^():,\s]+)$"),
+ # This matches everything except spaces, parens, colon, comma, and period
+ "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
+ # This matches everything except a space.
+ "all_punctuations": re.compile(r"([^\s]+)$"),
+}
+
+
+def last_word(text, include="alphanum_underscore"):
+ r"""
+ Find the last word in a sentence.
+
+ >>> last_word('abc')
+ 'abc'
+ >>> last_word(' abc')
+ 'abc'
+ >>> last_word('')
+ ''
+ >>> last_word(' ')
+ ''
+ >>> last_word('abc ')
+ ''
+ >>> last_word('abc def')
+ 'def'
+ >>> last_word('abc def ')
+ ''
+ >>> last_word('abc def;')
+ ''
+ >>> last_word('bac $def')
+ 'def'
+ >>> last_word('bac $def', include='most_punctuations')
+ '$def'
+ >>> last_word('bac \def', include='most_punctuations')
+ '\\\\def'
+ >>> last_word('bac \def;', include='most_punctuations')
+ '\\\\def;'
+ >>> last_word('bac::def', include='most_punctuations')
+ 'def'
+ >>> last_word('"foo*bar', include='most_punctuations')
+ '"foo*bar'
+ """
+
+ if not text: # Empty string
+ return ""
+
+ if text[-1].isspace():
+ return ""
+ else:
+ regex = cleanup_regex[include]
+ matches = regex.search(text)
+ if matches:
+ return matches.group(0)
+ else:
+ return ""
+
+
+def find_prev_keyword(sql, n_skip=0):
+ """Find the last sql keyword in an SQL statement
+
+ Returns the value of the last keyword, and the text of the query with
+ everything after the last keyword stripped
+ """
+ if not sql.strip():
+ return None, ""
+
+ parsed = sqlparse.parse(sql)[0]
+ flattened = list(parsed.flatten())
+ flattened = flattened[: len(flattened) - n_skip]
+
+ logical_operators = ("AND", "OR", "NOT", "BETWEEN")
+
+ for t in reversed(flattened):
+ if t.value == "(" or (
+ t.is_keyword and (t.value.upper() not in logical_operators)
+ ):
+ # Find the location of token t in the original parsed statement
+ # We can't use parsed.token_index(t) because t may be a child token
+ # inside a TokenList, in which case token_index throws an error
+ # Minimal example:
+ # p = sqlparse.parse('select * from foo where bar')
+ # t = list(p.flatten())[-3] # The "Where" token
+ # p.token_index(t) # Throws ValueError: not in list
+ idx = flattened.index(t)
+
+ # Combine the string values of all tokens in the original list
+ # up to and including the target keyword token t, to produce a
+ # query string with everything after the keyword token removed
+ text = "".join(tok.value for tok in flattened[: idx + 1])
+ return t, text
+
+ return None, ""
+
+
+# Postgresql dollar quote signs look like `$$` or `$tag$`
+dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
+
+
+def is_open_quote(sql):
+ """Returns true if the query contains an unclosed quote"""
+
+ # parsed can contain one or more semi-colon separated commands
+ parsed = sqlparse.parse(sql)
+ return any(_parsed_is_open_quote(p) for p in parsed)
+
+
+def _parsed_is_open_quote(parsed):
+ # Look for unmatched single quotes, or unmatched dollar sign quotes
+ return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten())
+
+
+def parse_partial_identifier(word):
+ """Attempt to parse a (partially typed) word as an identifier
+
+ word may include a schema qualification, like `schema_name.partial_name`
+ or `schema_name.` There may also be unclosed quotation marks, like
+ `"schema`, or `schema."partial_name`
+
+ :param word: string representing a (partially complete) identifier
+ :return: sqlparse.sql.Identifier, or None
+ """
+
+ p = sqlparse.parse(word)[0]
+ n_tok = len(p.tokens)
+ if n_tok == 1 and isinstance(p.tokens[0], Identifier):
+ return p.tokens[0]
+ elif p.token_next_by(m=(Error, '"'))[1]:
+ # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
+ # Close the double quote, then reparse
+ return parse_partial_identifier(word + '"')
+ else:
+ return None
diff --git a/pgcli/packages/pgliterals/__init__.py b/pgcli/packages/pgliterals/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pgcli/packages/pgliterals/__init__.py
diff --git a/pgcli/packages/pgliterals/main.py b/pgcli/packages/pgliterals/main.py
new file mode 100644
index 0000000..5c39296
--- /dev/null
+++ b/pgcli/packages/pgliterals/main.py
@@ -0,0 +1,15 @@
+import os
+import json
+
+root = os.path.dirname(__file__)
+literal_file = os.path.join(root, "pgliterals.json")
+
+with open(literal_file) as f:
+ literals = json.load(f)
+
+
+def get_literals(literal_type, type_=tuple):
+ # Where `literal_type` is one of 'keywords', 'functions', 'datatypes',
+ # returns a tuple of literal values of that type.
+
+ return type_(literals[literal_type])
diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json
new file mode 100644
index 0000000..df00817
--- /dev/null
+++ b/pgcli/packages/pgliterals/pgliterals.json
@@ -0,0 +1,630 @@
+{
+ "keywords": {
+ "ACCESS": [],
+ "ADD": [],
+ "ALL": [],
+ "ALTER": [
+ "AGGREGATE",
+ "COLLATION",
+ "COLUMN",
+ "CONVERSION",
+ "DATABASE",
+ "DEFAULT",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN",
+ "FUNCTION",
+ "GROUP",
+ "INDEX",
+ "LANGUAGE",
+ "LARGE OBJECT",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "SYSTEM",
+ "TABLE",
+ "TABLESPACE",
+ "TEXT SEARCH",
+ "TRIGGER",
+ "TYPE",
+ "USER",
+ "VIEW"
+ ],
+ "AND": [],
+ "ANY": [],
+ "AS": [],
+ "ASC": [],
+ "AUDIT": [],
+ "BEGIN": [],
+ "BETWEEN": [],
+ "BY": [],
+ "CASE": [],
+ "CHAR": [],
+ "CHECK": [],
+ "CLUSTER": [],
+ "COLUMN": [],
+ "COMMENT": [],
+ "COMMIT": [],
+ "COMPRESS": [],
+ "CONCURRENTLY": [],
+ "CONNECT": [],
+ "COPY": [],
+ "CREATE": [
+ "ACCESS METHOD",
+ "AGGREGATE",
+ "CAST",
+ "COLLATION",
+ "CONVERSION",
+ "DATABASE",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN DATA WRAPPER",
+ "FOREIGN EXTENSION",
+ "FUNCTION",
+ "GLOBAL",
+ "GROUP",
+ "IF NOT EXISTS",
+ "INDEX",
+ "LANGUAGE",
+ "LOCAL",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "OR REPLACE",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "TABLE",
+ "TABLESPACE",
+ "TEMPORARY",
+ "TEXT SEARCH",
+ "TRIGGER",
+ "TYPE",
+ "UNIQUE",
+ "UNLOGGED",
+ "USER",
+ "USER MAPPING",
+ "VIEW"
+ ],
+ "CURRENT": [],
+ "DATABASE": [],
+ "DATE": [],
+ "DECIMAL": [],
+ "DEFAULT": [],
+ "DELETE FROM": [],
+ "DELIMITER": [],
+ "DESC": [],
+ "DESCRIBE": [],
+ "DISTINCT": [],
+ "DROP": [
+ "ACCESS METHOD",
+ "AGGREGATE",
+ "CAST",
+ "COLLATION",
+ "COLUMN",
+ "CONVERSION",
+ "DATABASE",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN DATA WRAPPER",
+ "FOREIGN TABLE",
+ "FUNCTION",
+ "GROUP",
+ "INDEX",
+ "LANGUAGE",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "OWNED",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "TABLE",
+ "TABLESPACE",
+ "TEXT SEARCH",
+ "TRANSFORM",
+ "TRIGGER",
+ "TYPE",
+ "USER",
+ "USER MAPPING",
+ "VIEW"
+ ],
+ "EXPLAIN": [],
+ "ELSE": [],
+ "ENCODING": [],
+ "ESCAPE": [],
+ "EXCLUSIVE": [],
+ "EXISTS": [],
+ "EXTENSION": [],
+ "FILE": [],
+ "FLOAT": [],
+ "FOR": [],
+ "FORMAT": [],
+ "FORCE_QUOTE": [],
+ "FORCE_NOT_NULL": [],
+ "FREEZE": [],
+ "FROM": [],
+ "FULL": [],
+ "FUNCTION": [],
+ "GRANT": [],
+ "GROUP BY": [],
+ "HAVING": [],
+ "HEADER": [],
+ "IDENTIFIED": [],
+ "IMMEDIATE": [],
+ "IN": [],
+ "INCREMENT": [],
+ "INDEX": [],
+ "INITIAL": [],
+ "INSERT INTO": [],
+ "INTEGER": [],
+ "INTERSECT": [],
+ "INTERVAL": [],
+ "INTO": [],
+ "IS": [],
+ "JOIN": [],
+ "LANGUAGE": [],
+ "LEFT": [],
+ "LEVEL": [],
+ "LIKE": [],
+ "LIMIT": [],
+ "LOCK": [],
+ "LONG": [],
+ "MATERIALIZED VIEW": [],
+ "MAXEXTENTS": [],
+ "MINUS": [],
+ "MLSLABEL": [],
+ "MODE": [],
+ "MODIFY": [],
+ "NOT": [],
+ "NOAUDIT": [],
+ "NOTICE": [],
+ "NOCOMPRESS": [],
+ "NOWAIT": [],
+ "NULL": [],
+ "NUMBER": [],
+ "OIDS": [],
+ "OF": [],
+ "OFFLINE": [],
+ "ON": [],
+ "ONLINE": [],
+ "OPTION": [],
+ "OR": [],
+ "ORDER BY": [],
+ "OUTER": [],
+ "OWNER": [],
+ "PCTFREE": [],
+ "PRIMARY": [],
+ "PRIOR": [],
+ "PRIVILEGES": [],
+ "QUOTE": [],
+ "RAISE": [],
+ "RENAME": [],
+ "REPLACE": [],
+ "RESET": ["ALL"],
+ "RAW": [],
+ "REFRESH MATERIALIZED VIEW": [],
+ "RESOURCE": [],
+ "RETURNS": [],
+ "REVOKE": [],
+ "RIGHT": [],
+ "ROLLBACK": [],
+ "ROW": [],
+ "ROWID": [],
+ "ROWNUM": [],
+ "ROWS": [],
+ "SELECT": [],
+ "SESSION": [],
+ "SET": [],
+ "SHARE": [],
+ "SHOW": [],
+ "SIZE": [],
+ "SMALLINT": [],
+ "START": [],
+ "SUCCESSFUL": [],
+ "SYNONYM": [],
+ "SYSDATE": [],
+ "TABLE": [],
+ "TEMPLATE": [],
+ "THEN": [],
+ "TO": [],
+ "TRIGGER": [],
+ "TRUNCATE": [],
+ "UID": [],
+ "UNION": [],
+ "UNIQUE": [],
+ "UPDATE": [],
+ "USE": [],
+ "USER": [],
+ "USING": [],
+ "VALIDATE": [],
+ "VALUES": [],
+ "VARCHAR": [],
+ "VARCHAR2": [],
+ "VIEW": [],
+ "WHEN": [],
+ "WHENEVER": [],
+ "WHERE": [],
+ "WITH": []
+ },
+ "functions": [
+ "ABBREV",
+ "ABS",
+ "AGE",
+ "AREA",
+ "ARRAY_AGG",
+ "ARRAY_APPEND",
+ "ARRAY_CAT",
+ "ARRAY_DIMS",
+ "ARRAY_FILL",
+ "ARRAY_LENGTH",
+ "ARRAY_LOWER",
+ "ARRAY_NDIMS",
+ "ARRAY_POSITION",
+ "ARRAY_POSITIONS",
+ "ARRAY_PREPEND",
+ "ARRAY_REMOVE",
+ "ARRAY_REPLACE",
+ "ARRAY_TO_STRING",
+ "ARRAY_UPPER",
+ "ASCII",
+ "AVG",
+ "BIT_AND",
+ "BIT_LENGTH",
+ "BIT_OR",
+ "BOOL_AND",
+ "BOOL_OR",
+ "BOUND_BOX",
+ "BOX",
+ "BROADCAST",
+ "BTRIM",
+ "CARDINALITY",
+ "CBRT",
+ "CEIL",
+ "CEILING",
+ "CENTER",
+ "CHAR_LENGTH",
+ "CHR",
+ "CIRCLE",
+ "CLOCK_TIMESTAMP",
+ "CONCAT",
+ "CONCAT_WS",
+ "CONVERT",
+ "CONVERT_FROM",
+ "CONVERT_TO",
+ "COUNT",
+ "CUME_DIST",
+ "CURRENT_DATE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "DATE_PART",
+ "DATE_TRUNC",
+ "DECODE",
+ "DEGREES",
+ "DENSE_RANK",
+ "DIAMETER",
+ "DIV",
+ "ENCODE",
+ "ENUM_FIRST",
+ "ENUM_LAST",
+ "ENUM_RANGE",
+ "EVERY",
+ "EXP",
+ "EXTRACT",
+ "FAMILY",
+ "FIRST_VALUE",
+ "FLOOR",
+ "FORMAT",
+ "GET_BIT",
+ "GET_BYTE",
+ "HEIGHT",
+ "HOST",
+ "HOSTMASK",
+ "INET_MERGE",
+ "INET_SAME_FAMILY",
+ "INITCAP",
+ "ISCLOSED",
+ "ISFINITE",
+ "ISOPEN",
+ "JUSTIFY_DAYS",
+ "JUSTIFY_HOURS",
+ "JUSTIFY_INTERVAL",
+ "LAG",
+ "LAST_VALUE",
+ "LEAD",
+ "LEFT",
+ "LENGTH",
+ "LINE",
+ "LN",
+ "LOCALTIME",
+ "LOCALTIMESTAMP",
+ "LOG",
+ "LOG10",
+ "LOWER",
+ "LPAD",
+ "LSEG",
+ "LTRIM",
+ "MAKE_DATE",
+ "MAKE_INTERVAL",
+ "MAKE_TIME",
+ "MAKE_TIMESTAMP",
+ "MAKE_TIMESTAMPTZ",
+ "MASKLEN",
+ "MAX",
+ "MD5",
+ "MIN",
+ "MOD",
+ "NETMASK",
+ "NETWORK",
+ "NOW",
+ "NPOINTS",
+ "NTH_VALUE",
+ "NTILE",
+ "NUM_NONNULLS",
+ "NUM_NULLS",
+ "OCTET_LENGTH",
+ "OVERLAY",
+ "PARSE_IDENT",
+ "PATH",
+ "PCLOSE",
+ "PERCENT_RANK",
+ "PG_CLIENT_ENCODING",
+ "PI",
+ "POINT",
+ "POLYGON",
+ "POPEN",
+ "POSITION",
+ "POWER",
+ "QUOTE_IDENT",
+ "QUOTE_LITERAL",
+ "QUOTE_NULLABLE",
+ "RADIANS",
+ "RADIUS",
+ "RANDOM",
+ "RANK",
+ "REGEXP_MATCH",
+ "REGEXP_MATCHES",
+ "REGEXP_REPLACE",
+ "REGEXP_SPLIT_TO_ARRAY",
+ "REGEXP_SPLIT_TO_TABLE",
+ "REPEAT",
+ "REPLACE",
+ "REVERSE",
+ "RIGHT",
+ "ROUND",
+ "ROW_NUMBER",
+ "RPAD",
+ "RTRIM",
+ "SCALE",
+ "SET_BIT",
+ "SET_BYTE",
+ "SET_MASKLEN",
+ "SHA224",
+ "SHA256",
+ "SHA384",
+ "SHA512",
+ "SIGN",
+ "SPLIT_PART",
+ "SQRT",
+ "STARTS_WITH",
+ "STATEMENT_TIMESTAMP",
+ "STRING_TO_ARRAY",
+ "STRPOS",
+ "SUBSTR",
+ "SUBSTRING",
+ "SUM",
+ "TEXT",
+ "TIMEOFDAY",
+ "TO_ASCII",
+ "TO_CHAR",
+ "TO_DATE",
+ "TO_HEX",
+ "TO_NUMBER",
+ "TO_TIMESTAMP",
+ "TRANSACTION_TIMESTAMP",
+ "TRANSLATE",
+ "TRIM",
+ "TRUNC",
+ "UNNEST",
+ "UPPER",
+ "WIDTH",
+ "WIDTH_BUCKET",
+ "XMLAGG"
+ ],
+ "datatypes": [
+ "ANY",
+ "ANYARRAY",
+ "ANYELEMENT",
+ "ANYENUM",
+ "ANYNONARRAY",
+ "ANYRANGE",
+ "BIGINT",
+ "BIGSERIAL",
+ "BIT",
+ "BIT VARYING",
+ "BOOL",
+ "BOOLEAN",
+ "BOX",
+ "BYTEA",
+ "CHAR",
+ "CHARACTER",
+ "CHARACTER VARYING",
+ "CIDR",
+ "CIRCLE",
+ "CSTRING",
+ "DATE",
+ "DECIMAL",
+ "DOUBLE PRECISION",
+ "EVENT_TRIGGER",
+ "FDW_HANDLER",
+ "FLOAT4",
+ "FLOAT8",
+ "INET",
+ "INT",
+ "INT2",
+ "INT4",
+ "INT8",
+ "INTEGER",
+ "INTERNAL",
+ "INTERVAL",
+ "JSON",
+ "JSONB",
+ "LANGUAGE_HANDLER",
+ "LINE",
+ "LSEG",
+ "MACADDR",
+ "MACADDR8",
+ "MONEY",
+ "NUMERIC",
+ "OID",
+ "OPAQUE",
+ "PATH",
+ "PG_LSN",
+ "POINT",
+ "POLYGON",
+ "REAL",
+ "RECORD",
+ "REGCLASS",
+ "REGCONFIG",
+ "REGDICTIONARY",
+ "REGNAMESPACE",
+ "REGOPER",
+ "REGOPERATOR",
+ "REGPROC",
+ "REGPROCEDURE",
+ "REGROLE",
+ "REGTYPE",
+ "SERIAL",
+ "SERIAL2",
+ "SERIAL4",
+ "SERIAL8",
+ "SMALLINT",
+ "SMALLSERIAL",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "TRIGGER",
+ "TSQUERY",
+ "TSVECTOR",
+ "TXID_SNAPSHOT",
+ "UUID",
+ "VARBIT",
+ "VARCHAR",
+ "VOID",
+ "XML"
+ ],
+ "reserved": [
+ "ALL",
+ "ANALYSE",
+ "ANALYZE",
+ "AND",
+ "ANY",
+ "ARRAY",
+ "AS",
+ "ASC",
+ "ASYMMETRIC",
+ "BOTH",
+ "CASE",
+ "CAST",
+ "CHECK",
+ "COLLATE",
+ "COLUMN",
+ "CONSTRAINT",
+ "CREATE",
+ "CURRENT_CATALOG",
+ "CURRENT_DATE",
+ "CURRENT_ROLE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "CURRENT_USER",
+ "DEFAULT",
+ "DEFERRABLE",
+ "DESC",
+ "DISTINCT",
+ "DO",
+ "ELSE",
+ "END",
+ "EXCEPT",
+ "FALSE",
+ "FETCH",
+ "FOR",
+ "FOREIGN",
+ "FROM",
+ "GRANT",
+ "GROUP",
+ "HAVING",
+ "IN",
+ "INITIALLY",
+ "INTERSECT",
+ "INTO",
+ "LATERAL",
+ "LEADING",
+ "LIMIT",
+ "LOCALTIME",
+ "LOCALTIMESTAMP",
+ "NOT",
+ "NULL",
+ "OFFSET",
+ "ON",
+ "ONLY",
+ "OR",
+ "ORDER",
+ "PLACING",
+ "PRIMARY",
+ "REFERENCES",
+ "RETURNING",
+ "SELECT",
+ "SESSION_USER",
+ "SOME",
+ "SYMMETRIC",
+ "TABLE",
+ "THEN",
+ "TO",
+ "TRAILING",
+ "TRUE",
+ "UNION",
+ "UNIQUE",
+ "USER",
+ "USING",
+ "VARIADIC",
+ "WHEN",
+ "WHERE",
+ "WINDOW",
+ "WITH",
+ "AUTHORIZATION",
+ "BINARY",
+ "COLLATION",
+ "CONCURRENTLY",
+ "CROSS",
+ "CURRENT_SCHEMA",
+ "FREEZE",
+ "FULL",
+ "ILIKE",
+ "INNER",
+ "IS",
+ "ISNULL",
+ "JOIN",
+ "LEFT",
+ "LIKE",
+ "NATURAL",
+ "NOTNULL",
+ "OUTER",
+ "OVERLAPS",
+ "RIGHT",
+ "SIMILAR",
+ "TABLESAMPLE",
+ "VERBOSE"
+ ]
+}
diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py
new file mode 100644
index 0000000..f5a9cb5
--- /dev/null
+++ b/pgcli/packages/prioritization.py
@@ -0,0 +1,51 @@
+import re
+import sqlparse
+from sqlparse.tokens import Name
+from collections import defaultdict
+from .pgliterals.main import get_literals
+
+
+white_space_regex = re.compile("\\s+", re.MULTILINE)
+
+
+def _compile_regex(keyword):
+ # Surround the keyword with word boundaries and replace interior whitespace
+ # with whitespace wildcards
+ pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
+ return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
+
+
+keywords = get_literals("keywords")
+keyword_regexs = {kw: _compile_regex(kw) for kw in keywords}
+
+
+class PrevalenceCounter:
+ def __init__(self):
+ self.keyword_counts = defaultdict(int)
+ self.name_counts = defaultdict(int)
+
+ def update(self, text):
+ self.update_keywords(text)
+ self.update_names(text)
+
+ def update_names(self, text):
+ for parsed in sqlparse.parse(text):
+ for token in parsed.flatten():
+ if token.ttype in Name:
+ self.name_counts[token.value] += 1
+
+ def clear_names(self):
+ self.name_counts = defaultdict(int)
+
+ def update_keywords(self, text):
+ # Count keywords. Can't rely for sqlparse for this, because it's
+ # database agnostic
+ for keyword, regex in keyword_regexs.items():
+ for _ in regex.finditer(text):
+ self.keyword_counts[keyword] += 1
+
+ def keyword_count(self, keyword):
+ return self.keyword_counts[keyword]
+
+ def name_count(self, name):
+ return self.name_counts[name]
diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py
new file mode 100644
index 0000000..997b86e
--- /dev/null
+++ b/pgcli/packages/prompt_utils.py
@@ -0,0 +1,37 @@
+import sys
+import click
+from .parseutils import is_destructive
+
+
+def confirm_destructive_query(queries, keywords, alias):
+ """Check if the query is destructive and prompts the user to confirm.
+
+ Returns:
+ * None if the query is non-destructive or we can't prompt the user.
+ * True if the query is destructive and the user wants to proceed.
+ * False if the query is destructive and the user doesn't want to proceed.
+
+ """
+ info = "You're about to run a destructive command"
+ if alias:
+ info += f" in {click.style(alias, fg='red')}"
+
+ prompt_text = f"{info}.\nDo you want to proceed?"
+ if is_destructive(queries, keywords) and sys.stdin.isatty():
+ return confirm(prompt_text)
+
+
+def confirm(*args, **kwargs):
+ """Prompt for confirmation (yes/no) and handle any abort exceptions."""
+ try:
+ return click.confirm(*args, **kwargs)
+ except click.Abort:
+ return False
+
+
+def prompt(*args, **kwargs):
+ """Prompt the user for input and handle any abort exceptions."""
+ try:
+ return click.prompt(*args, **kwargs)
+ except click.Abort:
+ return False
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
new file mode 100644
index 0000000..b78edd6
--- /dev/null
+++ b/pgcli/packages/sqlcompletion.py
@@ -0,0 +1,605 @@
+import sys
+import re
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import Comparison, Identifier, Where
+from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier
+from .parseutils.tables import extract_tables
+from .parseutils.ctes import isolate_query_ctes
+from pgspecial.main import parse_special_command
+
+
+Special = namedtuple("Special", [])
+Database = namedtuple("Database", [])
+Schema = namedtuple("Schema", ["quoted"])
+Schema.__new__.__defaults__ = (False,)
+# FromClauseItem is a table/view/function used in the FROM clause
+# `table_refs` contains the list of tables/... already in the statement,
+# used to ensure that the alias we suggest is unique
+FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
+Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
+TableFormat = namedtuple("TableFormat", [])
+View = namedtuple("View", ["schema", "table_refs"])
+# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
+JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
+# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
+Join = namedtuple("Join", ["table_refs", "schema"])
+
+Function = namedtuple("Function", ["schema", "table_refs", "usage"])
+# For convenience, don't require the `usage` argument in Function constructor
+Function.__new__.__defaults__ = (None, tuple(), None)
+Table.__new__.__defaults__ = (None, tuple(), tuple())
+View.__new__.__defaults__ = (None, tuple())
+FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
+
+Column = namedtuple(
+ "Column",
+ ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"],
+)
+Column.__new__.__defaults__ = (None, None, tuple(), False, None)
+
+Keyword = namedtuple("Keyword", ["last_token"])
+Keyword.__new__.__defaults__ = (None,)
+NamedQuery = namedtuple("NamedQuery", [])
+Datatype = namedtuple("Datatype", ["schema"])
+Alias = namedtuple("Alias", ["aliases"])
+
+Path = namedtuple("Path", [])
+
+
+class SqlStatement:
+ def __init__(self, full_text, text_before_cursor):
+ self.identifier = None
+ self.word_before_cursor = word_before_cursor = last_word(
+ text_before_cursor, include="many_punctuations"
+ )
+ full_text = _strip_named_query(full_text)
+ text_before_cursor = _strip_named_query(text_before_cursor)
+
+ full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
+ full_text, text_before_cursor
+ )
+
+ self.text_before_cursor_including_last_word = text_before_cursor
+
+ # If we've partially typed a word then word_before_cursor won't be an
+ # empty string. In that case we want to remove the partially typed
+ # string before sending it to the sqlparser. Otherwise the last token
+ # will always be the partially typed string which renders the smart
+ # completion useless because it will always return the list of
+ # keywords as completion.
+ if self.word_before_cursor:
+ if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
+ parsed = sqlparse.parse(text_before_cursor)
+ else:
+ text_before_cursor = text_before_cursor[: -len(word_before_cursor)]
+ parsed = sqlparse.parse(text_before_cursor)
+ self.identifier = parse_partial_identifier(word_before_cursor)
+ else:
+ parsed = sqlparse.parse(text_before_cursor)
+
+ full_text, text_before_cursor, parsed = _split_multiple_statements(
+ full_text, text_before_cursor, parsed
+ )
+
+ self.full_text = full_text
+ self.text_before_cursor = text_before_cursor
+ self.parsed = parsed
+
+ self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or ""
+
+ def is_insert(self):
+ return self.parsed.token_first().value.lower() == "insert"
+
+ def get_tables(self, scope="full"):
+ """Gets the tables available in the statement.
+ param `scope:` possible values: 'full', 'insert', 'before'
+ If 'insert', only the first table is returned.
+ If 'before', only tables before the cursor are returned.
+ If not 'insert' and the stmt is an insert, the first table is skipped.
+ """
+ tables = extract_tables(
+ self.full_text if scope == "full" else self.text_before_cursor
+ )
+ if scope == "insert":
+ tables = tables[:1]
+ elif self.is_insert():
+ tables = tables[1:]
+ return tables
+
+ def get_previous_token(self, token):
+ return self.parsed.token_prev(self.parsed.token_index(token))[1]
+
+ def get_identifier_schema(self):
+ schema = (self.identifier and self.identifier.get_parent_name()) or None
+ # If schema name is unquoted, lower-case it
+ if schema and self.identifier.value[0] != '"':
+ schema = schema.lower()
+
+ return schema
+
+ def reduce_to_prev_keyword(self, n_skip=0):
+ prev_keyword, self.text_before_cursor = find_prev_keyword(
+ self.text_before_cursor, n_skip=n_skip
+ )
+ return prev_keyword
+
+
+def suggest_type(full_text, text_before_cursor):
+ """Takes the full_text that is typed so far and also the text before the
+ cursor to suggest completion type and scope.
+
+ Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
+ A scope for a column category will be a list of tables.
+ """
+
+ if full_text.startswith("\\i "):
+ return (Path(),)
+
+ # This is a temporary hack; the exception handling
+ # here should be removed once sqlparse has been fixed
+ try:
+ stmt = SqlStatement(full_text, text_before_cursor)
+ except (TypeError, AttributeError):
+ return []
+
+ # Check for special commands and handle those separately
+ if stmt.parsed:
+ # Be careful here because trivial whitespace is parsed as a
+ # statement, but the statement won't have a first token
+ tok1 = stmt.parsed.token_first()
+ if tok1 and tok1.value.startswith("\\"):
+ text = stmt.text_before_cursor + stmt.word_before_cursor
+ return suggest_special(text)
+
+ return suggest_based_on_last_token(stmt.last_token, stmt)
+
+
+named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
+
+
+def _strip_named_query(txt):
+ """
+ This will strip "save named query" command in the beginning of the line:
+ '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+ ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+ """
+
+ if named_query_regex.match(txt):
+ txt = named_query_regex.sub("", txt)
+ return txt
+
+
+function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
+
+
+def _find_function_body(text):
+ split = function_body_pattern.search(text)
+ return (split.start(2), split.end(2)) if split else (None, None)
+
+
+def _statement_from_function(full_text, text_before_cursor, statement):
+ current_pos = len(text_before_cursor)
+ body_start, body_end = _find_function_body(full_text)
+ if body_start is None:
+ return full_text, text_before_cursor, statement
+ if not body_start <= current_pos < body_end:
+ return full_text, text_before_cursor, statement
+ full_text = full_text[body_start:body_end]
+ text_before_cursor = text_before_cursor[body_start:]
+ parsed = sqlparse.parse(text_before_cursor)
+ return _split_multiple_statements(full_text, text_before_cursor, parsed)
+
+
+def _split_multiple_statements(full_text, text_before_cursor, parsed):
+ if len(parsed) > 1:
+ # Multiple statements being edited -- isolate the current one by
+ # cumulatively summing statement lengths to find the one that bounds
+ # the current position
+ current_pos = len(text_before_cursor)
+ stmt_start, stmt_end = 0, 0
+
+ for statement in parsed:
+ stmt_len = len(str(statement))
+ stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
+
+ if stmt_end >= current_pos:
+ text_before_cursor = full_text[stmt_start:current_pos]
+ full_text = full_text[stmt_start:]
+ break
+
+ elif parsed:
+ # A single statement
+ statement = parsed[0]
+ else:
+ # The empty string
+ return full_text, text_before_cursor, None
+
+ token2 = None
+ if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
+ token1 = statement.token_first()
+ if token1:
+ token1_idx = statement.token_index(token1)
+ token2 = statement.token_next(token1_idx)[1]
+ if token2 and token2.value.upper() == "FUNCTION":
+ full_text, text_before_cursor, statement = _statement_from_function(
+ full_text, text_before_cursor, statement
+ )
+ return full_text, text_before_cursor, statement
+
+
+SPECIALS_SUGGESTION = {
+ "dT": Datatype,
+ "df": Function,
+ "dt": Table,
+ "dv": View,
+ "sf": Function,
+}
+
+
+def suggest_special(text):
+ text = text.lstrip()
+ cmd, _, arg = parse_special_command(text)
+
+ if cmd == text:
+ # Trying to complete the special command itself
+ return (Special(),)
+
+ if cmd in ("\\c", "\\connect"):
+ return (Database(),)
+
+ if cmd == "\\T":
+ return (TableFormat(),)
+
+ if cmd == "\\dn":
+ return (Schema(),)
+
+ if arg:
+ # Try to distinguish "\d name" from "\d schema.name"
+ # Note that this will fail to obtain a schema name if wildcards are
+ # used, e.g. "\d schema???.name"
+ parsed = sqlparse.parse(arg)[0].tokens[0]
+ try:
+ schema = parsed.get_parent_name()
+ except AttributeError:
+ schema = None
+ else:
+ schema = None
+
+ if cmd[1:] == "d":
+ # \d can describe tables or views
+ if schema:
+ return (Table(schema=schema), View(schema=schema))
+ else:
+ return (Schema(), Table(schema=None), View(schema=None))
+ elif cmd[1:] in SPECIALS_SUGGESTION:
+ rel_type = SPECIALS_SUGGESTION[cmd[1:]]
+ if schema:
+ if rel_type == Function:
+ return (Function(schema=schema, usage="special"),)
+ return (rel_type(schema=schema),)
+ else:
+ if rel_type == Function:
+ return (Schema(), Function(schema=None, usage="special"))
+ return (Schema(), rel_type(schema=None))
+
+ if cmd in ["\\n", "\\ns", "\\nd"]:
+ return (NamedQuery(),)
+
+ return (Keyword(), Special())
+
+
+def suggest_based_on_last_token(token, stmt):
+ if isinstance(token, str):
+ token_v = token.lower()
+ elif isinstance(token, Comparison):
+ # If 'token' is a Comparison type such as
+ # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
+ # token.value on the comparison type will only return the lhs of the
+ # comparison. In this case a.id. So we need to do token.tokens to get
+ # both sides of the comparison and pick the last token out of that
+ # list.
+ token_v = token.tokens[-1].value.lower()
+ elif isinstance(token, Where):
+ # sqlparse groups all tokens from the where clause into a single token
+ # list. This means that token.value may be something like
+ # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
+ # suggestions in complicated where clauses correctly
+ prev_keyword = stmt.reduce_to_prev_keyword()
+ return suggest_based_on_last_token(prev_keyword, stmt)
+ elif isinstance(token, Identifier):
+ # If the previous token is an identifier, we can suggest datatypes if
+ # we're in a parenthesized column/field list, e.g.:
+ # CREATE TABLE foo (Identifier <CURSOR>
+ # CREATE FUNCTION foo (Identifier <CURSOR>
+ # If we're not in a parenthesized list, the most likely scenario is the
+ # user is about to specify an alias, e.g.:
+ # SELECT Identifier <CURSOR>
+ # SELECT foo FROM Identifier <CURSOR>
+ prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
+ if prev_keyword and prev_keyword.value == "(":
+ # Suggest datatypes
+ return suggest_based_on_last_token("type", stmt)
+ else:
+ return (Keyword(),)
+ else:
+ token_v = token.value.lower()
+
+ if not token:
+ return (Keyword(), Special())
+ elif token_v.endswith("("):
+ p = sqlparse.parse(stmt.text_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[-1], Where):
+ # Four possibilities:
+ # 1 - Parenthesized clause like "WHERE foo AND ("
+ # Suggest columns/functions
+ # 2 - Function call like "WHERE foo("
+ # Suggest columns/functions
+ # 3 - Subquery expression like "WHERE EXISTS ("
+ # Suggest keywords, in order to do a subquery
+ # 4 - Subquery OR array comparison like "WHERE foo = ANY("
+ # Suggest columns/functions AND keywords. (If we wanted to be
+ # really fancy, we could suggest only array-typed columns)
+
+ column_suggestions = suggest_based_on_last_token("where", stmt)
+
+ # Check for a subquery expression (cases 3 & 4)
+ where = p.tokens[-1]
+ prev_tok = where.token_prev(len(where.tokens) - 1)[1]
+
+ if isinstance(prev_tok, Comparison):
+ # e.g. "SELECT foo FROM bar WHERE foo = ANY("
+ prev_tok = prev_tok.tokens[-1]
+
+ prev_tok = prev_tok.value.lower()
+ if prev_tok == "exists":
+ return (Keyword(),)
+ else:
+ return column_suggestions
+
+ # Get the token before the parens
+ prev_tok = p.token_prev(len(p.tokens) - 1)[1]
+
+ if (
+ prev_tok
+ and prev_tok.value
+ and prev_tok.value.lower().split(" ")[-1] == "using"
+ ):
+ # tbl1 INNER JOIN tbl2 USING (col1, col2)
+ tables = stmt.get_tables("before")
+
+ # suggest columns that are present in more than one table
+ return (
+ Column(
+ table_refs=tables,
+ require_last_table=True,
+ local_tables=stmt.local_tables,
+ ),
+ )
+
+ elif p.token_first().value.lower() == "select":
+ # If the lparen is preceded by a space chances are we're about to
+ # do a sub-select.
+ if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("):
+ return (Keyword(),)
+ prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
+ if prev_prev_tok and prev_prev_tok.normalized == "INTO":
+ return (Column(table_refs=stmt.get_tables("insert"), context="insert"),)
+ # We're probably in a function argument list
+ return _suggest_expression(token_v, stmt)
+ elif token_v == "set":
+ return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),)
+ elif token_v in ("select", "where", "having", "order by", "distinct"):
+ return _suggest_expression(token_v, stmt)
+ elif token_v == "as":
+ # Don't suggest anything for aliases
+ return ()
+ elif (token_v.endswith("join") and token.is_keyword) or (
+ token_v in ("copy", "from", "update", "into", "describe", "truncate")
+ ):
+ schema = stmt.get_identifier_schema()
+ tables = extract_tables(stmt.text_before_cursor)
+ is_join = token_v.endswith("join") and token.is_keyword
+
+ # Suggest tables from either the currently-selected schema or the
+ # public schema if no schema has been specified
+ suggest = []
+
+ if not schema:
+ # Suggest schemas
+ suggest.insert(0, Schema())
+
+ if token_v == "from" or is_join:
+ suggest.append(
+ FromClauseItem(
+ schema=schema, table_refs=tables, local_tables=stmt.local_tables
+ )
+ )
+ elif token_v == "truncate":
+ suggest.append(Table(schema))
+ else:
+ suggest.extend((Table(schema), View(schema)))
+
+ if is_join and _allow_join(stmt.parsed):
+ tables = stmt.get_tables("before")
+ suggest.append(Join(table_refs=tables, schema=schema))
+
+ return tuple(suggest)
+
+ elif token_v == "function":
+ schema = stmt.get_identifier_schema()
+
+ # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:`
+ try:
+ prev = stmt.get_previous_token(token).value.lower()
+ if prev in ("drop", "alter", "create", "create or replace"):
+ # Suggest functions from either the currently-selected schema or the
+ # public schema if no schema has been specified
+ suggest = []
+
+ if not schema:
+ # Suggest schemas
+ suggest.insert(0, Schema())
+
+ suggest.append(Function(schema=schema, usage="signature"))
+ return tuple(suggest)
+
+ except ValueError:
+ pass
+ return tuple()
+
+ elif token_v in ("table", "view"):
+ # E.g. 'ALTER TABLE <tablname>'
+ rel_type = {"table": Table, "view": View, "function": Function}[token_v]
+ schema = stmt.get_identifier_schema()
+ if schema:
+ return (rel_type(schema=schema),)
+ else:
+ return (Schema(), rel_type(schema=schema))
+
+ elif token_v == "column":
+ # E.g. 'ALTER TABLE foo ALTER COLUMN bar
+ return (Column(table_refs=stmt.get_tables()),)
+
+ elif token_v == "on":
+ tables = stmt.get_tables("before")
+ parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None
+ if parent:
+ # "ON parent.<suggestion>"
+ # parent can be either a schema name or table alias
+ filteredtables = tuple(t for t in tables if identifies(parent, t))
+ sugs = [
+ Column(table_refs=filteredtables, local_tables=stmt.local_tables),
+ Table(schema=parent),
+ View(schema=parent),
+ Function(schema=parent),
+ ]
+ if filteredtables and _allow_join_condition(stmt.parsed):
+ sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1]))
+ return tuple(sugs)
+ else:
+ # ON <suggestion>
+ # Use table alias if there is one, otherwise the table name
+ aliases = tuple(t.ref for t in tables)
+ if _allow_join_condition(stmt.parsed):
+ return (
+ Alias(aliases=aliases),
+ JoinCondition(table_refs=tables, parent=None),
+ )
+ else:
+ return (Alias(aliases=aliases),)
+
+ elif token_v in ("c", "use", "database", "template"):
+ # "\c <db", "use <db>", "DROP DATABASE <db>",
+ # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
+ return (Database(),)
+ elif token_v == "schema":
+ # DROP SCHEMA schema_name, SET SCHEMA schema name
+ prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
+ quoted = prev_keyword and prev_keyword.value.lower() == "set"
+ return (Schema(quoted),)
+ elif token_v.endswith(",") or token_v in ("=", "and", "or"):
+ prev_keyword = stmt.reduce_to_prev_keyword()
+ if prev_keyword:
+ return suggest_based_on_last_token(prev_keyword, stmt)
+ else:
+ return ()
+ elif token_v in ("type", "::"):
+ # ALTER TABLE foo SET DATA TYPE bar
+ # SELECT foo::bar
+ # Note that tables are a form of composite type in postgresql, so
+ # they're suggested here as well
+ schema = stmt.get_identifier_schema()
+ suggestions = [Datatype(schema=schema), Table(schema=schema)]
+ if not schema:
+ suggestions.append(Schema())
+ return tuple(suggestions)
+ elif token_v in {"alter", "create", "drop"}:
+ return (Keyword(token_v.upper()),)
+ elif token.is_keyword:
+ # token is a keyword we haven't implemented any special handling for
+ # go backwards in the query until we find one we do recognize
+ prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
+ if prev_keyword:
+ return suggest_based_on_last_token(prev_keyword, stmt)
+ else:
+ return (Keyword(token_v.upper()),)
+ else:
+ return (Keyword(),)
+
+
+def _suggest_expression(token_v, stmt):
+ """
+ Return suggestions for an expression, taking account of any partially-typed
+ identifier's parent, which may be a table alias or schema name.
+ """
+ parent = stmt.identifier.get_parent_name() if stmt.identifier else []
+ tables = stmt.get_tables()
+
+ if parent:
+ tables = tuple(t for t in tables if identifies(parent, t))
+ return (
+ Column(table_refs=tables, local_tables=stmt.local_tables),
+ Table(schema=parent),
+ View(schema=parent),
+ Function(schema=parent),
+ )
+
+ return (
+ Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True),
+ Function(schema=None),
+ Keyword(token_v.upper()),
+ )
+
+
+def identifies(id, ref):
+ """Returns true if string `id` matches TableReference `ref`"""
+
+ return (
+ id == ref.alias
+ or id == ref.name
+ or (ref.schema and (id == ref.schema + "." + ref.name))
+ )
+
+
+def _allow_join_condition(statement):
+ """
+ Tests if a join condition should be suggested
+
+ We need this to avoid bad suggestions when entering e.g.
+ select * from tbl1 a join tbl2 b on a.id = <cursor>
+ So check that the preceding token is a ON, AND, or OR keyword, instead of
+ e.g. an equals sign.
+
+ :param statement: an sqlparse.sql.Statement
+ :return: boolean
+ """
+
+ if not statement or not statement.tokens:
+ return False
+
+ last_tok = statement.token_prev(len(statement.tokens))[1]
+ return last_tok.value.lower() in ("on", "and", "or")
+
+
+def _allow_join(statement):
+ """
+ Tests if a join should be suggested
+
+ We need this to avoid bad suggestions when entering e.g.
+ select * from tbl1 a join tbl2 b <cursor>
+ So check that the preceding token is a JOIN keyword
+
+ :param statement: an sqlparse.sql.Statement
+ :return: boolean
+ """
+
+ if not statement or not statement.tokens:
+ return False
+
+ last_tok = statement.token_prev(len(statement.tokens))[1]
+ return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in (
+ "cross join",
+ "natural join",
+ )
diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py
new file mode 100644
index 0000000..c236c13
--- /dev/null
+++ b/pgcli/pgbuffer.py
@@ -0,0 +1,61 @@
+import logging
+
+from prompt_toolkit.enums import DEFAULT_BUFFER
+from prompt_toolkit.filters import Condition
+from prompt_toolkit.application import get_app
+from .packages.parseutils.utils import is_open_quote
+
+_logger = logging.getLogger(__name__)
+
+
+def _is_complete(sql):
+ # A complete command is an sql statement that ends with a semicolon, unless
+ # there's an open quote surrounding it, as is common when writing a
+ # CREATE FUNCTION command
+ return sql.endswith(";") and not is_open_quote(sql)
+
+
+"""
+Returns True if the buffer contents should be handled (i.e. the query/command
+executed) immediately. This is necessary as we use prompt_toolkit in multiline
+mode, which by default will insert new lines on Enter.
+"""
+
+
+def safe_multi_line_mode(pgcli):
+ @Condition
+ def cond():
+ _logger.debug(
+ 'Multi-line mode state: "%s" / "%s"', pgcli.multi_line, pgcli.multiline_mode
+ )
+ return pgcli.multi_line and (pgcli.multiline_mode == "safe")
+
+ return cond
+
+
+def buffer_should_be_handled(pgcli):
+ @Condition
+ def cond():
+ if not pgcli.multi_line:
+ _logger.debug("Not in multi-line mode. Handle the buffer.")
+ return True
+
+ if pgcli.multiline_mode == "safe":
+ _logger.debug("Multi-line mode is set to 'safe'. Do NOT handle the buffer.")
+ return False
+
+ doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document
+ text = doc.text.strip()
+
+ return (
+ text.startswith("\\") # Special Command
+ or text.endswith(r"\e") # Special Command
+ or text.endswith(r"\G") # Ended with \e which should launch the editor
+ or _is_complete(text) # A complete SQL command
+ or (text == "exit") # Exit doesn't need semi-colon
+ or (text == "quit") # Quit doesn't need semi-colon
+ or (text == ":q") # To all the vim fans out there
+ or (text == "") # Just a plain enter without any text
+ )
+
+ return cond
diff --git a/pgcli/pgclirc b/pgcli/pgclirc
new file mode 100644
index 0000000..51f7eae
--- /dev/null
+++ b/pgcli/pgclirc
@@ -0,0 +1,236 @@
+# vi: ft=dosini
+[main]
+
+# Enables context sensitive auto-completion. If this is disabled, all
+# possible completions will be listed.
+smart_completion = True
+
+# Display the completions in several columns. (More completions will be
+# visible.)
+wider_completion_menu = False
+
+# Do not create new connections for refreshing completions; Equivalent to
+# always running with the --single-connection flag.
+always_use_single_connection = False
+
+# Multi-line mode allows breaking up the sql statements into multiple lines. If
+# this is set to True, then the end of the statements must have a semi-colon.
+# If this is set to False then sql statements can't be split into multiple
+# lines. End of line (return) is considered as the end of the statement.
+multi_line = False
+
+# If multi_line_mode is set to "psql", in multi-line mode, [Enter] will execute
+# the current input if the input ends in a semicolon.
+# If multi_line_mode is set to "safe", in multi-line mode, [Enter] will always
+# insert a newline, and [Esc] [Enter] or [Alt]-[Enter] must be used to execute
+# a command.
+multi_line_mode = psql
+
+# Destructive warning will alert you before executing a sql statement
+# that may cause harm to the database such as "drop table", "drop database",
+# "shutdown", "delete", or "update".
+# You can pass a list of destructive commands or leave it empty if you want to skip all warnings.
+# "unconditional_update" will warn you of update statements that don't have a where clause
+destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update
+
+# Destructive warning can restart the connection if this is enabled and the
+# user declines. This means that any current uncommitted transaction can be
+# aborted if the user doesn't want to proceed with a destructive_warning
+# statement.
+destructive_warning_restarts_connection = False
+
+# When this option is on (and if `destructive_warning` is not empty),
+# destructive statements are not executed when outside of a transaction.
+destructive_statements_require_transaction = False
+
+# Enables expand mode, which is similar to `\x` in psql.
+expand = False
+
+# Enables auto expand mode, which is similar to `\x auto` in psql.
+auto_expand = False
+
+# Auto-retry queries on connection failures and other operational errors. If
+# False, will prompt to rerun the failed query instead of auto-retrying.
+auto_retry_closed_connection = True
+
+# If set to True, table suggestions will include a table alias
+generate_aliases = False
+
+# Path to a json file that specifies specific table aliases to use when generate_aliases is set to True
+# the format for this file should be:
+# {
+# "some_table_name": "desired_alias",
+# "some_other_table_name": "another_alias"
+# }
+alias_map_file =
+
+# log_file location.
+# In Unix/Linux: ~/.config/pgcli/log
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log
+# %USERPROFILE% is typically C:\Users\{username}
+log_file = default
+
+# keyword casing preference. Possible values: "lower", "upper", "auto"
+keyword_casing = auto
+
+# casing_file location.
+# In Unix/Linux: ~/.config/pgcli/casing
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\casing
+# %USERPROFILE% is typically C:\Users\{username}
+casing_file = default
+
+# If generate_casing_file is set to True and there is no file in the above
+# location, one will be generated based on usage in SQL/PLPGSQL functions.
+generate_casing_file = False
+
+# Casing of column headers based on the casing_file described above
+case_column_headers = True
+
+# history_file location.
+# In Unix/Linux: ~/.config/pgcli/history
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\history
+# %USERPROFILE% is typically C:\Users\{username}
+history_file = default
+
+# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
+# and "DEBUG". "NONE" disables logging.
+log_level = INFO
+
+# Order of columns when expanding * to column list
+# Possible values: "table_order" and "alphabetic"
+asterisk_column_order = table_order
+
+# Whether to qualify with table alias/name when suggesting columns
+# Possible values: "always", "never" and "if_more_than_one_table"
+qualify_columns = if_more_than_one_table
+
+# When no schema is entered, only suggest objects in search_path
+search_path_filter = False
+
+# Default pager. See https://www.pgcli.com/pager for more information on settings.
+# By default 'PAGER' environment variable is used. If the pager is less, and the 'LESS'
+# environment variable is not set, then LESS='-SRXF' will be automatically set.
+# pager = less
+
+# Timing of sql statements and table rendering.
+timing = True
+
+# Show/hide the informational toolbar with function keymap at the footer.
+show_bottom_toolbar = True
+
+# Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe,
+# ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs,
+# textile, moinmoin, jira, vertical, tsv, csv, sql-insert, sql-update,
+# sql-update-1, sql-update-2 (formatter with sql-* prefix can format query
+# output to executable insertion or updating sql).
+# Recommended: psql, fancy_grid and grid.
+table_format = psql
+
+# Syntax Style. Possible values: manni, igor, xcode, vim, autumn, vs, rrt,
+# native, perldoc, borland, tango, emacs, friendly, monokai, paraiso-dark,
+# colorful, murphy, bw, pastie, paraiso-light, trac, default, fruity
+syntax_style = default
+
+# Keybindings:
+# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
+# When Vi mode is disabled emacs keybindings such as Ctrl-A for home and Ctrl-E
+# for end are available in the REPL.
+vi = False
+
+# Error handling
+# When one of multiple SQL statements causes an error, choose to either
+# continue executing the remaining statements, or stopping
+# Possible values "STOP" or "RESUME"
+on_error = STOP
+
+# Set threshold for row limit. Use 0 to disable limiting.
+row_limit = 1000
+
+# Truncate long text fields to this value for tabular display (does not apply to csv).
+# Leave unset to disable truncation. Example: "max_field_width = "
+# Be aware that formatting might get slow with values larger than 500 and tables with
+# lots of records.
+max_field_width = 500
+
+# Skip intro on startup and goodbye on exit
+less_chatty = False
+
+# Postgres prompt
+# \t - Current date and time
+# \u - Username
+# \h - Short hostname of the server (up to first '.')
+# \H - Hostname of the server
+# \d - Database name
+# \p - Database port
+# \i - Postgres PID
+# \# - "@" sign if logged in as superuser, '>' in other case
+# \n - Newline
+# \dsn_alias - name of dsn connection string alias if -D option is used (empty otherwise)
+# \x1b[...m - insert ANSI escape sequence
+# eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>'
+prompt = '\u@\h:\d> '
+
+# Number of lines to reserve for the suggestion menu
+min_num_menu_lines = 4
+
+# Character used to left pad multi-line queries to match the prompt size.
+multiline_continuation_char = ''
+
+# The string used in place of a null value.
+null_string = '<null>'
+
+# manage pager on startup
+enable_pager = True
+
+# Use keyring to automatically save and load password in a secure manner
+keyring = True
+
+# Custom colors for the completion menu, toolbar, etc.
+[colors]
+completion-menu.completion.current = 'bg:#ffffff #000000'
+completion-menu.completion = 'bg:#008888 #ffffff'
+completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
+completion-menu.meta.completion = 'bg:#448888 #ffffff'
+completion-menu.multi-column-meta = 'bg:#aaffff #000000'
+scrollbar.arrow = 'bg:#003333'
+scrollbar = 'bg:#00aaaa'
+selected = '#ffffff bg:#6666aa'
+search = '#ffffff bg:#4444aa'
+search.current = '#ffffff bg:#44aa44'
+bottom-toolbar = 'bg:#222222 #aaaaaa'
+bottom-toolbar.off = 'bg:#222222 #888888'
+bottom-toolbar.on = 'bg:#222222 #ffffff'
+search-toolbar = 'noinherit bold'
+search-toolbar.text = 'nobold'
+system-toolbar = 'noinherit bold'
+arg-toolbar = 'noinherit bold'
+arg-toolbar.text = 'nobold'
+bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
+bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
+# These three values can be used to further refine the syntax highlighting.
+# They are commented out by default, since they have priority over the theme set
+# with the `syntax_style` setting and overriding its behavior can be confusing.
+# literal.string = '#ba2121'
+# literal.number = '#666666'
+# keyword = 'bold #008000'
+
+# style classes for colored table output
+output.header = "#00ff5f bold"
+output.odd-row = ""
+output.even-row = ""
+output.null = "#808080"
+
+# Named queries are queries you can execute by name.
+[named queries]
+
+# Here's where you can provide a list of connection string aliases.
+# You can use it by passing the -D option. `pgcli -D example_dsn`
+[alias_dsn]
+# example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname]
+
+# Format for number representation
+# for decimal "d" - 12345678, ",d" - 12,345,678
+# for float "g" - 123456.78, ",g" - 123,456.78
+[data_formats]
+decimal = ""
+float = ""
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
new file mode 100644
index 0000000..17fc540
--- /dev/null
+++ b/pgcli/pgcompleter.py
@@ -0,0 +1,1072 @@
+import json
+import logging
+import re
+from itertools import count, repeat, chain
+import operator
+from collections import namedtuple, defaultdict, OrderedDict
+from cli_helpers.tabular_output import TabularOutputFormatter
+from pgspecial.namedqueries import NamedQueries
+from prompt_toolkit.completion import Completer, Completion, PathCompleter
+from prompt_toolkit.document import Document
+from .packages.sqlcompletion import (
+ FromClauseItem,
+ suggest_type,
+ Special,
+ Database,
+ Schema,
+ Table,
+ TableFormat,
+ Function,
+ Column,
+ View,
+ Keyword,
+ NamedQuery,
+ Datatype,
+ Alias,
+ Path,
+ JoinCondition,
+ Join,
+)
+from .packages.parseutils.meta import ColumnMetadata, ForeignKey
+from .packages.parseutils.utils import last_word
+from .packages.parseutils.tables import TableReference
+from .packages.pgliterals.main import get_literals
+from .packages.prioritization import PrevalenceCounter
+from .config import load_config, config_location
+
+_logger = logging.getLogger(__name__)
+
+Match = namedtuple("Match", ["completion", "priority"])
+
+_SchemaObject = namedtuple("SchemaObject", "name schema meta")
+
+
+def SchemaObject(name, schema=None, meta=None):
+ return _SchemaObject(name, schema, meta)
+
+
+_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display")
+
+
+def Candidate(
+ completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
+):
+ return _Candidate(
+ completion, prio, meta, synonyms or [completion], prio2, display or completion
+ )
+
+
+# Used to strip trailing '::some_type' from default-value expressions
+arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
+
+normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
+
+
+def generate_alias(tbl, alias_map=None):
+ """Generate a table alias, consisting of all upper-case letters in
+ the table name, or, if there are no upper-case letters, the first letter +
+ all letters preceded by _
+ param tbl - unescaped name of the table to alias
+ """
+ if alias_map and tbl in alias_map:
+ return alias_map[tbl]
+ return "".join(
+ [l for l in tbl if l.isupper()]
+ or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"]
+ )
+
+
+class InvalidMapFile(ValueError):
+ pass
+
+
+def load_alias_map_file(path):
+ try:
+ with open(path) as fo:
+ alias_map = json.load(fo)
+ except FileNotFoundError as err:
+ raise InvalidMapFile(
+ f"Cannot read alias_map_file - {err.filename} does not exist"
+ )
+ except json.JSONDecodeError:
+ raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json")
+ else:
+ return alias_map
+
+
+class PGCompleter(Completer):
+ # keywords_tree: A dict mapping keywords to well known following keywords.
+ # e.g. 'CREATE': ['TABLE', 'USER', ...],
+ keywords_tree = get_literals("keywords", type_=dict)
+ keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values())))
+ functions = get_literals("functions")
+ datatypes = get_literals("datatypes")
+ reserved_words = set(get_literals("reserved"))
+
+ def __init__(self, smart_completion=True, pgspecial=None, settings=None):
+ super().__init__()
+ self.smart_completion = smart_completion
+ self.pgspecial = pgspecial
+ self.prioritizer = PrevalenceCounter()
+ settings = settings or {}
+ self.signature_arg_style = settings.get(
+ "signature_arg_style", "{arg_name} {arg_type}"
+ )
+ self.call_arg_style = settings.get(
+ "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}"
+ )
+ self.call_arg_display_style = settings.get(
+ "call_arg_display_style", "{arg_name}"
+ )
+ self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2)
+ self.search_path_filter = settings.get("search_path_filter")
+ self.generate_aliases = settings.get("generate_aliases")
+ alias_map_file = settings.get("alias_map_file")
+ if alias_map_file is not None:
+ self.alias_map = load_alias_map_file(alias_map_file)
+ else:
+ self.alias_map = None
+ self.casing_file = settings.get("casing_file")
+ self.insert_col_skip_patterns = [
+ re.compile(pattern)
+ for pattern in settings.get(
+ "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("]
+ )
+ ]
+ self.generate_casing_file = settings.get("generate_casing_file")
+ self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table")
+ self.asterisk_column_order = settings.get(
+ "asterisk_column_order", "table_order"
+ )
+
+ keyword_casing = settings.get("keyword_casing", "upper").lower()
+ if keyword_casing not in ("upper", "lower", "auto"):
+ keyword_casing = "upper"
+ self.keyword_casing = keyword_casing
+ self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")
+
+ self.databases = []
+ self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
+ self.search_path = []
+ self.casing = {}
+
+ self.all_completions = set(self.keywords + self.functions)
+
+ def escape_name(self, name):
+ if name and (
+ (not self.name_pattern.match(name))
+ or (name.upper() in self.reserved_words)
+ or (name.upper() in self.functions)
+ ):
+ name = '"%s"' % name
+
+ return name
+
+ def escape_schema(self, name):
+ return "'{}'".format(self.unescape_name(name))
+
+ def unescape_name(self, name):
+ """Unquote a string."""
+ if name and name[0] == '"' and name[-1] == '"':
+ name = name[1:-1]
+
+ return name
+
+ def escaped_names(self, names):
+ return [self.escape_name(name) for name in names]
+
+ def extend_database_names(self, databases):
+ self.databases.extend(databases)
+
+ def extend_keywords(self, additional_keywords):
+ self.keywords.extend(additional_keywords)
+ self.all_completions.update(additional_keywords)
+
+ def extend_schemata(self, schemata):
+ # schemata is a list of schema names
+ schemata = self.escaped_names(schemata)
+ metadata = self.dbmetadata["tables"]
+ for schema in schemata:
+ metadata[schema] = {}
+
+ # dbmetadata.values() are the 'tables' and 'functions' dicts
+ for metadata in self.dbmetadata.values():
+ for schema in schemata:
+ metadata[schema] = {}
+
+ self.all_completions.update(schemata)
+
+ def extend_casing(self, words):
+ """extend casing data
+
+ :return:
+ """
+ # casing should be a dict {lowercasename:PreferredCasingName}
+ self.casing = {word.lower(): word for word in words}
+
+ def extend_relations(self, data, kind):
+ """extend metadata for tables or views.
+
+ :param data: list of (schema_name, rel_name) tuples
+ :param kind: either 'tables' or 'views'
+
+ :return:
+
+ """
+
+ data = [self.escaped_names(d) for d in data]
+
+ # dbmetadata['tables']['schema_name']['table_name'] should be an
+ # OrderedDict {column_name:ColumnMetaData}.
+ metadata = self.dbmetadata[kind]
+ for schema, relname in data:
+ try:
+ metadata[schema][relname] = OrderedDict()
+ except KeyError:
+ _logger.error(
+ "%r %r listed in unrecognized schema %r", kind, relname, schema
+ )
+ self.all_completions.add(relname)
+
+ def extend_columns(self, column_data, kind):
+ """extend column metadata.
+
+ :param column_data: list of (schema_name, rel_name, column_name,
+ column_type, has_default, default) tuples
+ :param kind: either 'tables' or 'views'
+
+ :return:
+
+ """
+ metadata = self.dbmetadata[kind]
+ for schema, relname, colname, datatype, has_default, default in column_data:
+ (schema, relname, colname) = self.escaped_names([schema, relname, colname])
+ column = ColumnMetadata(
+ name=colname,
+ datatype=datatype,
+ has_default=has_default,
+ default=default,
+ )
+ metadata[schema][relname][colname] = column
+ self.all_completions.add(colname)
+
+ def extend_functions(self, func_data):
+ # func_data is a list of function metadata namedtuples
+
+ # dbmetadata['schema_name']['functions']['function_name'] should return
+ # the function metadata namedtuple for the corresponding function
+ metadata = self.dbmetadata["functions"]
+
+ for f in func_data:
+ schema, func = self.escaped_names([f.schema_name, f.func_name])
+
+ if func in metadata[schema]:
+ metadata[schema][func].append(f)
+ else:
+ metadata[schema][func] = [f]
+
+ self.all_completions.add(func)
+
+ self._refresh_arg_list_cache()
+
+ def _refresh_arg_list_cache(self):
+ # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}}
+ # This is used when suggesting functions, to avoid the latency that would result
+ # if we'd recalculate the arg lists each time we suggest functions (in large DBs)
+ self._arg_list_cache = {
+ usage: {
+ meta: self._arg_list(meta, usage)
+ for sch, funcs in self.dbmetadata["functions"].items()
+ for func, metas in funcs.items()
+ for meta in metas
+ }
+ for usage in ("call", "call_display", "signature")
+ }
+
+ def extend_foreignkeys(self, fk_data):
+ # fk_data is a list of ForeignKey namedtuples, with fields
+ # parentschema, childschema, parenttable, childtable,
+ # parentcolumns, childcolumns
+
+ # These are added as a list of ForeignKey namedtuples to the
+ # ColumnMetadata namedtuple for both the child and parent
+ meta = self.dbmetadata["tables"]
+
+ for fk in fk_data:
+ e = self.escaped_names
+ parentschema, childschema = e([fk.parentschema, fk.childschema])
+ parenttable, childtable = e([fk.parenttable, fk.childtable])
+ childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
+ childcolmeta = meta[childschema][childtable][childcol]
+ parcolmeta = meta[parentschema][parenttable][parcol]
+ fk = ForeignKey(
+ parentschema, parenttable, parcol, childschema, childtable, childcol
+ )
+ childcolmeta.foreignkeys.append(fk)
+ parcolmeta.foreignkeys.append(fk)
+
+ def extend_datatypes(self, type_data):
+ # dbmetadata['datatypes'][schema_name][type_name] should store type
+ # metadata, such as composite type field names. Currently, we're not
+ # storing any metadata beyond typename, so just store None
+ meta = self.dbmetadata["datatypes"]
+
+ for t in type_data:
+ schema, type_name = self.escaped_names(t)
+ meta[schema][type_name] = None
+ self.all_completions.add(type_name)
+
+ def extend_query_history(self, text, is_init=False):
+ if is_init:
+ # During completer initialization, only load keyword preferences,
+ # not names
+ self.prioritizer.update_keywords(text)
+ else:
+ self.prioritizer.update(text)
+
+ def set_search_path(self, search_path):
+ self.search_path = self.escaped_names(search_path)
+
+ def reset_completions(self):
+ self.databases = []
+ self.special_commands = []
+ self.search_path = []
+ self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
+ self.all_completions = set(self.keywords + self.functions)
+
+ def find_matches(self, text, collection, mode="fuzzy", meta=None):
+ """Find completion matches for the given text.
+
+ Given the user's input text and a collection of available
+ completions, find completions matching the last word of the
+ text.
+
+ `collection` can be either a list of strings or a list of Candidate
+ namedtuples.
+ `mode` can be either 'fuzzy', or 'strict'
+ 'fuzzy': fuzzy matching, ties broken by name prevalance
+ `keyword`: start only matching, ties broken by keyword prevalance
+
+ yields prompt_toolkit Completion instances for any matches found
+ in the collection of available completions.
+
+ """
+ if not collection:
+ return []
+ prio_order = [
+ "keyword",
+ "function",
+ "view",
+ "table",
+ "datatype",
+ "database",
+ "schema",
+ "column",
+ "table alias",
+ "join",
+ "name join",
+ "fk join",
+ "table format",
+ ]
+ type_priority = prio_order.index(meta) if meta in prio_order else -1
+ text = last_word(text, include="most_punctuations").lower()
+ text_len = len(text)
+
+ if text and text[0] == '"':
+ # text starts with double quote; user is manually escaping a name
+ # Match on everything that follows the double-quote. Note that
+ # text_len is calculated before removing the quote, so the
+ # Completion.position value is correct
+ text = text[1:]
+
+ if mode == "fuzzy":
+ fuzzy = True
+ priority_func = self.prioritizer.name_count
+ else:
+ fuzzy = False
+ priority_func = self.prioritizer.keyword_count
+
+ # Construct a `_match` function for either fuzzy or non-fuzzy matching
+ # The match function returns a 2-tuple used for sorting the matches,
+ # or None if the item doesn't match
+ # Note: higher priority values mean more important, so use negative
+ # signs to flip the direction of the tuple
+ if fuzzy:
+ regex = ".*?".join(map(re.escape, text))
+ pat = re.compile("(%s)" % regex)
+
+ def _match(item):
+ if item.lower()[: len(text) + 1] in (text, text + " "):
+ # Exact match of first word in suggestion
+ # This is to get exact alias matches to the top
+ # E.g. for input `e`, 'Entries E' should be on top
+ # (before e.g. `EndUsers EU`)
+ return float("Infinity"), -1
+ r = pat.search(self.unescape_name(item.lower()))
+ if r:
+ return -len(r.group()), -r.start()
+
+ else:
+ match_end_limit = len(text)
+
+ def _match(item):
+ match_point = item.lower().find(text, 0, match_end_limit)
+ if match_point >= 0:
+ # Use negative infinity to force keywords to sort after all
+ # fuzzy matches
+ return -float("Infinity"), -match_point
+
+ matches = []
+ for cand in collection:
+ if isinstance(cand, _Candidate):
+ item, prio, display_meta, synonyms, prio2, display = cand
+ if display_meta is None:
+ display_meta = meta
+ syn_matches = (_match(x) for x in synonyms)
+ # Nones need to be removed to avoid max() crashing in Python 3
+ syn_matches = [m for m in syn_matches if m]
+ sort_key = max(syn_matches) if syn_matches else None
+ else:
+ item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand
+ sort_key = _match(cand)
+
+ if sort_key:
+ if display_meta and len(display_meta) > 50:
+ # Truncate meta-text to 50 characters, if necessary
+ display_meta = display_meta[:47] + "..."
+
+ # Lexical order of items in the collection, used for
+ # tiebreaking items with the same match group length and start
+ # position. Since we use *higher* priority to mean "more
+ # important," we use -ord(c) to prioritize "aa" > "ab" and end
+ # with 1 to prioritize shorter strings (ie "user" > "users").
+ # We first do a case-insensitive sort and then a
+ # case-sensitive one as a tie breaker.
+ # We also use the unescape_name to make sure quoted names have
+ # the same priority as unquoted names.
+ lexical_priority = (
+ tuple(
+ 0 if c in " _" else -ord(c)
+ for c in self.unescape_name(item.lower())
+ )
+ + (1,)
+ + tuple(c for c in item)
+ )
+
+ item = self.case(item)
+ display = self.case(display)
+ priority = (
+ sort_key,
+ type_priority,
+ prio,
+ priority_func(item),
+ prio2,
+ lexical_priority,
+ )
+ matches.append(
+ Match(
+ completion=Completion(
+ text=item,
+ start_position=-text_len,
+ display_meta=display_meta,
+ display=display,
+ ),
+ priority=priority,
+ )
+ )
+ return matches
+
+ def case(self, word):
+ return self.casing.get(word, word)
+
+ def get_completions(self, document, complete_event, smart_completion=None):
+ word_before_cursor = document.get_word_before_cursor(WORD=True)
+
+ if smart_completion is None:
+ smart_completion = self.smart_completion
+
+ # If smart_completion is off then match any word that starts with
+ # 'word_before_cursor'.
+ if not smart_completion:
+ matches = self.find_matches(
+ word_before_cursor, self.all_completions, mode="strict"
+ )
+ completions = [m.completion for m in matches]
+ return sorted(completions, key=operator.attrgetter("text"))
+
+ matches = []
+ suggestions = suggest_type(document.text, document.text_before_cursor)
+
+ for suggestion in suggestions:
+ suggestion_type = type(suggestion)
+ _logger.debug("Suggestion type: %r", suggestion_type)
+
+ # Map suggestion type to method
+ # e.g. 'table' -> self.get_table_matches
+ matcher = self.suggestion_matchers[suggestion_type]
+ matches.extend(matcher(self, suggestion, word_before_cursor))
+
+ # Sort matches so highest priorities are first
+ matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True)
+
+ return [m.completion for m in matches]
+
+ def get_column_matches(self, suggestion, word_before_cursor):
+ tables = suggestion.table_refs
+ do_qualify = (
+ suggestion.qualifiable
+ and {
+ "always": True,
+ "never": False,
+ "if_more_than_one_table": len(tables) > 1,
+ }[self.qualify_columns]
+ )
+ qualify = lambda col, tbl: (
+ (tbl + "." + self.case(col)) if do_qualify else self.case(col)
+ )
+ _logger.debug("Completion column scope: %r", tables)
+ scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)
+
+ def make_cand(name, ref):
+ synonyms = (name, generate_alias(self.case(name)))
+ return Candidate(qualify(name, ref), 0, "column", synonyms)
+
+ def flat_cols():
+ return [
+ make_cand(c.name, t.ref)
+ for t, cols in scoped_cols.items()
+ for c in cols
+ ]
+
+ if suggestion.require_last_table:
+ # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
+ # suggest only columns that appear in the last table and one more
+ ltbl = tables[-1].ref
+ other_tbl_cols = {
+ c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
+ }
+ scoped_cols = {
+ t: [col for col in cols if col.name in other_tbl_cols]
+ for t, cols in scoped_cols.items()
+ if t.ref == ltbl
+ }
+ lastword = last_word(word_before_cursor, include="most_punctuations")
+ if lastword == "*":
+ if suggestion.context == "insert":
+
+ def filter(col):
+ if not col.has_default:
+ return True
+ return not any(
+ p.match(col.default) for p in self.insert_col_skip_patterns
+ )
+
+ scoped_cols = {
+ t: [col for col in cols if filter(col)]
+ for t, cols in scoped_cols.items()
+ }
+ if self.asterisk_column_order == "alphabetic":
+ for cols in scoped_cols.values():
+ cols.sort(key=operator.attrgetter("name"))
+ if (
+ lastword != word_before_cursor
+ and len(tables) == 1
+ and word_before_cursor[-len(lastword) - 1] == "."
+ ):
+ # User typed x.*; replicate "x." for all columns except the
+ # first, which gets the original (as we only replace the "*"")
+ sep = ", " + word_before_cursor[:-1]
+ collist = sep.join(self.case(c.completion) for c in flat_cols())
+ else:
+ collist = ", ".join(
+ qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs
+ )
+
+ return [
+ Match(
+ completion=Completion(
+ collist, -1, display_meta="columns", display="*"
+ ),
+ priority=(1, 1, 1),
+ )
+ ]
+
+ return self.find_matches(word_before_cursor, flat_cols(), meta="column")
+
+ def alias(self, tbl, tbls):
+ """Generate a unique table alias
+ tbl - name of the table to alias, quoted if it needs to be
+ tbls - TableReference iterable of tables already in query
+ """
+ tbl = self.case(tbl)
+ tbls = {normalize_ref(t.ref) for t in tbls}
+ if self.generate_aliases:
+ tbl = generate_alias(self.unescape_name(tbl))
+ if normalize_ref(tbl) not in tbls:
+ return tbl
+ elif tbl[0] == '"':
+ aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
+ else:
+ aliases = (tbl + str(i) for i in count(2))
+ return next(a for a in aliases if normalize_ref(a) not in tbls)
+
+ def get_join_matches(self, suggestion, word_before_cursor):
+ tbls = suggestion.table_refs
+ cols = self.populate_scoped_cols(tbls)
+ # Set up some data structures for efficient access
+ qualified = {normalize_ref(t.ref): t.schema for t in tbls}
+ ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
+ refs = {normalize_ref(t.ref) for t in tbls}
+ other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
+ joins = []
+ # Iterate over FKs in existing tables to find potential joins
+ fks = (
+ (fk, rtbl, rcol)
+ for rtbl, rcols in cols.items()
+ for rcol in rcols
+ for fk in rcol.foreignkeys
+ )
+ col = namedtuple("col", "schema tbl col")
+ for fk, rtbl, rcol in fks:
+ right = col(rtbl.schema, rtbl.name, rcol.name)
+ child = col(fk.childschema, fk.childtable, fk.childcolumn)
+ parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
+ left = child if parent == right else parent
+ if suggestion.schema and left.schema != suggestion.schema:
+ continue
+ c = self.case
+ if self.generate_aliases or normalize_ref(left.tbl) in refs:
+ lref = self.alias(left.tbl, suggestion.table_refs)
+ join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
+ c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref
+ )
+ else:
+ join = "{0} ON {0}.{1} = {2}.{3}".format(
+ c(left.tbl), c(left.col), rtbl.ref, c(right.col)
+ )
+ alias = generate_alias(self.case(left.tbl))
+ synonyms = [
+ join,
+ "{0} ON {0}.{1} = {2}.{3}".format(
+ alias, c(left.col), rtbl.ref, c(right.col)
+ ),
+ ]
+ # Schema-qualify if (1) new table in same schema as old, and old
+ # is schema-qualified, or (2) new in other schema, except public
+ if not suggestion.schema and (
+ qualified[normalize_ref(rtbl.ref)]
+ and left.schema == right.schema
+ or left.schema not in (right.schema, "public")
+ ):
+ join = left.schema + "." + join
+ prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
+ 0 if (left.schema, left.tbl) in other_tbls else 1
+ )
+ joins.append(Candidate(join, prio, "join", synonyms=synonyms))
+
+ return self.find_matches(word_before_cursor, joins, meta="join")
+
+ def get_join_condition_matches(self, suggestion, word_before_cursor):
+ col = namedtuple("col", "schema tbl col")
+ tbls = self.populate_scoped_cols(suggestion.table_refs).items
+ cols = [(t, c) for t, cs in tbls() for c in cs]
+ try:
+ lref = (suggestion.parent or suggestion.table_refs[-1]).ref
+ ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
+ except IndexError: # The user typed an incorrect table qualifier
+ return []
+ conds, found_conds = [], set()
+
+ def add_cond(lcol, rcol, rref, prio, meta):
+ prefix = "" if suggestion.parent else ltbl.ref + "."
+ case = self.case
+ cond = prefix + case(lcol) + " = " + rref + "." + case(rcol)
+ if cond not in found_conds:
+ found_conds.add(cond)
+ conds.append(Candidate(cond, prio + ref_prio[rref], meta))
+
+ def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
+ d = defaultdict(list)
+ for pair in pairs:
+ d[pair[0]].append(pair[1])
+ return d
+
+ # Tables that are closer to the cursor get higher prio
+ ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
+ # Map (schema, table, col) to tables
+ coldict = list_dict(
+ ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
+ )
+ # For each fk from the left table, generate a join condition if
+ # the other table is also in the scope
+ fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
+ for fk, lcol in fks:
+ left = col(ltbl.schema, ltbl.name, lcol)
+ child = col(fk.childschema, fk.childtable, fk.childcolumn)
+ par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
+ left, right = (child, par) if left == child else (par, child)
+ for rtbl in coldict[right]:
+ add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
+ # For name matching, use a {(colname, coltype): TableReference} dict
+ coltyp = namedtuple("coltyp", "name datatype")
+ col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
+ # Find all name-match join conditions
+ for c in (coltyp(c.name, c.datatype) for c in lcols):
+ for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
+ prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0
+ add_cond(c.name, c.name, rtbl.ref, prio, "name join")
+
+ return self.find_matches(word_before_cursor, conds, meta="join")
+
+ def get_function_matches(self, suggestion, word_before_cursor, alias=False):
+ if suggestion.usage == "from":
+ # Only suggest functions allowed in FROM clause
+
+ def filt(f):
+ return (
+ not f.is_aggregate
+ and not f.is_window
+ and not f.is_extension
+ and (
+ f.is_public
+ or f.schema_name in self.search_path
+ or f.schema_name == suggestion.schema
+ )
+ )
+
+ else:
+ alias = False
+
+ def filt(f):
+ return not f.is_extension and (
+ f.is_public or f.schema_name == suggestion.schema
+ )
+
+ arg_mode = {"signature": "signature", "special": None}.get(
+ suggestion.usage, "call"
+ )
+
+ # Function overloading means we way have multiple functions of the same
+ # name at this point, so keep unique names only
+ all_functions = self.populate_functions(suggestion.schema, filt)
+ funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}
+
+ matches = self.find_matches(word_before_cursor, funcs, meta="function")
+
+ if not suggestion.schema and not suggestion.usage:
+ # also suggest hardcoded functions using startswith matching
+ predefined_funcs = self.find_matches(
+ word_before_cursor, self.functions, mode="strict", meta="function"
+ )
+ matches.extend(predefined_funcs)
+
+ return matches
+
+ def get_schema_matches(self, suggestion, word_before_cursor):
+ schema_names = self.dbmetadata["tables"].keys()
+
+ # Unless we're sure the user really wants them, hide schema names
+ # starting with pg_, which are mostly temporary schemas
+ if not word_before_cursor.startswith("pg_"):
+ schema_names = [s for s in schema_names if not s.startswith("pg_")]
+
+ if suggestion.quoted:
+ schema_names = [self.escape_schema(s) for s in schema_names]
+
+ return self.find_matches(word_before_cursor, schema_names, meta="schema")
+
+ def get_from_clause_item_matches(self, suggestion, word_before_cursor):
+ alias = self.generate_aliases
+ s = suggestion
+ t_sug = Table(s.schema, s.table_refs, s.local_tables)
+ v_sug = View(s.schema, s.table_refs)
+ f_sug = Function(s.schema, s.table_refs, usage="from")
+ return (
+ self.get_table_matches(t_sug, word_before_cursor, alias)
+ + self.get_view_matches(v_sug, word_before_cursor, alias)
+ + self.get_function_matches(f_sug, word_before_cursor, alias)
+ )
+
+ def _arg_list(self, func, usage):
+ """Returns a an arg list string, e.g. `(_foo:=23)` for a func.
+
+ :param func is a FunctionMetadata object
+ :param usage is 'call', 'call_display' or 'signature'
+
+ """
+ template = {
+ "call": self.call_arg_style,
+ "call_display": self.call_arg_display_style,
+ "signature": self.signature_arg_style,
+ }[usage]
+ args = func.args()
+ if not template:
+ return "()"
+ elif usage == "call" and len(args) < 2:
+ return "()"
+ elif usage == "call" and func.has_variadic():
+ return "()"
+ multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
+ max_arg_len = max(len(a.name) for a in args) if multiline else 0
+ args = (
+ self._format_arg(template, arg, arg_num + 1, max_arg_len)
+ for arg_num, arg in enumerate(args)
+ )
+ if multiline:
+ return "(" + ",".join("\n " + a for a in args if a) + "\n)"
+ else:
+ return "(" + ", ".join(a for a in args if a) + ")"
+
+ def _format_arg(self, template, arg, arg_num, max_arg_len):
+ if not template:
+ return None
+ if arg.has_default:
+ arg_default = "NULL" if arg.default is None else arg.default
+ # Remove trailing ::(schema.)type
+ arg_default = arg_default_type_strip_regex.sub("", arg_default)
+ else:
+ arg_default = ""
+ return template.format(
+ max_arg_len=max_arg_len,
+ arg_name=self.case(arg.name),
+ arg_num=arg_num,
+ arg_type=arg.datatype,
+ arg_default=arg_default,
+ )
+
+ def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
+ """Returns a Candidate namedtuple.
+
+ :param tbl is a SchemaObject
+ :param arg_mode determines what type of arg list to suffix for functions.
+ Possible values: call, signature
+
+ """
+ cased_tbl = self.case(tbl.name)
+ if do_alias:
+ alias = self.alias(cased_tbl, suggestion.table_refs)
+ synonyms = (cased_tbl, generate_alias(cased_tbl))
+ maybe_alias = (" " + alias) if do_alias else ""
+ maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else ""
+ suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
+ if arg_mode == "call":
+ display_suffix = self._arg_list_cache["call_display"][tbl.meta]
+ elif arg_mode == "signature":
+ display_suffix = self._arg_list_cache["signature"][tbl.meta]
+ else:
+ display_suffix = ""
+ item = maybe_schema + cased_tbl + suffix + maybe_alias
+ display = maybe_schema + cased_tbl + display_suffix + maybe_alias
+ prio2 = 0 if tbl.schema else 1
+ return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
+
+ def get_table_matches(self, suggestion, word_before_cursor, alias=False):
+ tables = self.populate_schema_objects(suggestion.schema, "tables")
+ tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)
+
+ # Unless we're sure the user really wants them, don't suggest the
+ # pg_catalog tables that are implicitly on the search path
+ if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
+ tables = [t for t in tables if not t.name.startswith("pg_")]
+ tables = [self._make_cand(t, alias, suggestion) for t in tables]
+ return self.find_matches(word_before_cursor, tables, meta="table")
+
+ def get_table_formats(self, _, word_before_cursor):
+ formats = TabularOutputFormatter().supported_formats
+ return self.find_matches(word_before_cursor, formats, meta="table format")
+
+ def get_view_matches(self, suggestion, word_before_cursor, alias=False):
+ views = self.populate_schema_objects(suggestion.schema, "views")
+
+ if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
+ views = [v for v in views if not v.name.startswith("pg_")]
+ views = [self._make_cand(v, alias, suggestion) for v in views]
+ return self.find_matches(word_before_cursor, views, meta="view")
+
+ def get_alias_matches(self, suggestion, word_before_cursor):
+ aliases = suggestion.aliases
+ return self.find_matches(word_before_cursor, aliases, meta="table alias")
+
+ def get_database_matches(self, _, word_before_cursor):
+ return self.find_matches(word_before_cursor, self.databases, meta="database")
+
+ def get_keyword_matches(self, suggestion, word_before_cursor):
+ keywords = self.keywords_tree.keys()
+ # Get well known following keywords for the last token. If any, narrow
+ # candidates to this list.
+ next_keywords = self.keywords_tree.get(suggestion.last_token, [])
+ if next_keywords:
+ keywords = next_keywords
+
+ casing = self.keyword_casing
+ if casing == "auto":
+ if word_before_cursor and word_before_cursor[-1].islower():
+ casing = "lower"
+ else:
+ casing = "upper"
+
+ if casing == "upper":
+ keywords = [k.upper() for k in keywords]
+ else:
+ keywords = [k.lower() for k in keywords]
+
+ return self.find_matches(
+ word_before_cursor, keywords, mode="strict", meta="keyword"
+ )
+
+ def get_path_matches(self, _, word_before_cursor):
+ completer = PathCompleter(expanduser=True)
+ document = Document(
+ text=word_before_cursor, cursor_position=len(word_before_cursor)
+ )
+ for c in completer.get_completions(document, None):
+ yield Match(completion=c, priority=(0,))
+
+ def get_special_matches(self, _, word_before_cursor):
+ if not self.pgspecial:
+ return []
+
+ commands = self.pgspecial.commands
+ cmds = commands.keys()
+ cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds]
+ return self.find_matches(word_before_cursor, cmds, mode="strict")
+
+ def get_datatype_matches(self, suggestion, word_before_cursor):
+ # suggest custom datatypes
+ types = self.populate_schema_objects(suggestion.schema, "datatypes")
+ types = [self._make_cand(t, False, suggestion) for t in types]
+ matches = self.find_matches(word_before_cursor, types, meta="datatype")
+
+ if not suggestion.schema:
+ # Also suggest hardcoded types
+ matches.extend(
+ self.find_matches(
+ word_before_cursor, self.datatypes, mode="strict", meta="datatype"
+ )
+ )
+
+ return matches
+
+ def get_namedquery_matches(self, _, word_before_cursor):
+ return self.find_matches(
+ word_before_cursor, NamedQueries.instance.list(), meta="named query"
+ )
+
+ suggestion_matchers = {
+ FromClauseItem: get_from_clause_item_matches,
+ JoinCondition: get_join_condition_matches,
+ Join: get_join_matches,
+ Column: get_column_matches,
+ Function: get_function_matches,
+ Schema: get_schema_matches,
+ Table: get_table_matches,
+ TableFormat: get_table_formats,
+ View: get_view_matches,
+ Alias: get_alias_matches,
+ Database: get_database_matches,
+ Keyword: get_keyword_matches,
+ Special: get_special_matches,
+ Datatype: get_datatype_matches,
+ NamedQuery: get_namedquery_matches,
+ Path: get_path_matches,
+ }
+
+ def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
+ """Find all columns in a set of scoped_tables.
+
+ :param scoped_tbls: list of TableReference namedtuples
+ :param local_tbls: tuple(TableMetadata)
+ :return: {TableReference:{colname:ColumnMetaData}}
+
+ """
+ ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
+ columns = OrderedDict()
+ meta = self.dbmetadata
+
+ def addcols(schema, rel, alias, reltype, cols):
+ tbl = TableReference(schema, rel, alias, reltype == "functions")
+ if tbl not in columns:
+ columns[tbl] = []
+ columns[tbl].extend(cols)
+
+ for tbl in scoped_tbls:
+ # Local tables should shadow database tables
+ if tbl.schema is None and normalize_ref(tbl.name) in ctes:
+ cols = ctes[normalize_ref(tbl.name)]
+ addcols(None, tbl.name, "CTE", tbl.alias, cols)
+ continue
+ schemas = [tbl.schema] if tbl.schema else self.search_path
+ for schema in schemas:
+ relname = self.escape_name(tbl.name)
+ schema = self.escape_name(schema)
+ if tbl.is_function:
+ # Return column names from a set-returning function
+ # Get an array of FunctionMetadata objects
+ functions = meta["functions"].get(schema, {}).get(relname)
+ for func in functions or []:
+ # func is a FunctionMetadata object
+ cols = func.fields()
+ addcols(schema, relname, tbl.alias, "functions", cols)
+ else:
+ for reltype in ("tables", "views"):
+ cols = meta[reltype].get(schema, {}).get(relname)
+ if cols:
+ cols = cols.values()
+ addcols(schema, relname, tbl.alias, reltype, cols)
+ break
+
+ return columns
+
+ def _get_schemas(self, obj_typ, schema):
+ """Returns a list of schemas from which to suggest objects.
+
+ :param schema is the schema qualification input by the user (if any)
+
+ """
+ metadata = self.dbmetadata[obj_typ]
+ if schema:
+ schema = self.escape_name(schema)
+ return [schema] if schema in metadata else []
+ return self.search_path if self.search_path_filter else metadata.keys()
+
+ def _maybe_schema(self, schema, parent):
+ return None if parent or schema in self.search_path else schema
+
+ def populate_schema_objects(self, schema, obj_type):
+ """Returns a list of SchemaObjects representing tables or views.
+
+ :param schema is the schema qualification input by the user (if any)
+
+ """
+
+ return [
+ SchemaObject(
+ name=obj, schema=(self._maybe_schema(schema=sch, parent=schema))
+ )
+ for sch in self._get_schemas(obj_type, schema)
+ for obj in self.dbmetadata[obj_type][sch].keys()
+ ]
+
+ def populate_functions(self, schema, filter_func):
+ """Returns a list of function SchemaObjects.
+
+ :param filter_func is a function that accepts a FunctionMetadata
+ namedtuple and returns a boolean indicating whether that
+ function should be kept or discarded
+
+ """
+
+ # Because of multiple dispatch, we can have multiple functions
+ # with the same name, which is why `for meta in metas` is necessary
+ # in the comprehensions below
+ return [
+ SchemaObject(
+ name=func,
+ schema=(self._maybe_schema(schema=sch, parent=schema)),
+ meta=meta,
+ )
+ for sch in self._get_schemas("functions", schema)
+ for (func, metas) in self.dbmetadata["functions"][sch].items()
+ for meta in metas
+ if filter_func(meta)
+ ]
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
new file mode 100644
index 0000000..497d681
--- /dev/null
+++ b/pgcli/pgexecute.py
@@ -0,0 +1,868 @@
+import logging
+import traceback
+from collections import namedtuple
+import re
+import pgspecial as special
+import psycopg
+import psycopg.sql
+from psycopg.conninfo import make_conninfo
+import sqlparse
+
+from .packages.parseutils.meta import FunctionMetadata, ForeignKey
+
+_logger = logging.getLogger(__name__)
+
+ViewDef = namedtuple(
+ "ViewDef", "nspname relname relkind viewdef reloptions checkoption"
+)
+
+
+# we added this funcion to strip beginning comments
+# because sqlparse didn't handle tem well. It won't be needed if sqlparse
+# does parsing of this situation better
+
+
+def remove_beginning_comments(command):
+ # Regular expression pattern to match comments
+ pattern = r"^(/\*.*?\*/|--.*?)(?:\n|$)"
+
+ # Find and remove all comments from the beginning
+ cleaned_command = command
+ comments = []
+ match = re.match(pattern, cleaned_command, re.DOTALL)
+ while match:
+ comments.append(match.group())
+ cleaned_command = cleaned_command[len(match.group()) :].lstrip()
+ match = re.match(pattern, cleaned_command, re.DOTALL)
+
+ return [cleaned_command, comments]
+
+
+def register_typecasters(connection):
+ """Casts date and timestamp values to string, resolves issues with out-of-range
+ dates (e.g. BC) which psycopg can't handle"""
+ for forced_text_type in [
+ "date",
+ "time",
+ "timestamp",
+ "timestamptz",
+ "bytea",
+ "json",
+ "jsonb",
+ ]:
+ connection.adapters.register_loader(
+ forced_text_type, psycopg.types.string.TextLoader
+ )
+
+
+# pg3: I don't know what is this
+class ProtocolSafeCursor(psycopg.Cursor):
+ """This class wraps and suppresses Protocol Errors with pgbouncer database.
+ See https://github.com/dbcli/pgcli/pull/1097.
+ Pgbouncer database is a virtual database with its own set of commands."""
+
+ def __init__(self, *args, **kwargs):
+ self.protocol_error = False
+ self.protocol_message = ""
+ super().__init__(*args, **kwargs)
+
+ def __iter__(self):
+ if self.protocol_error:
+ raise StopIteration
+ return super().__iter__()
+
+ def fetchall(self):
+ if self.protocol_error:
+ return [(self.protocol_message,)]
+ return super().fetchall()
+
+ def fetchone(self):
+ if self.protocol_error:
+ return (self.protocol_message,)
+ return super().fetchone()
+
+ # def mogrify(self, query, params):
+ # args = [Literal(v).as_string(self.connection) for v in params]
+ # return query % tuple(args)
+ #
+ def execute(self, *args, **kwargs):
+ try:
+ super().execute(*args, **kwargs)
+ self.protocol_error = False
+ self.protocol_message = ""
+ except psycopg.errors.ProtocolViolation as ex:
+ self.protocol_error = True
+ self.protocol_message = str(ex)
+ _logger.debug("%s: %s" % (ex.__class__.__name__, ex))
+
+
+class PGExecute:
+ # The boolean argument to the current_schemas function indicates whether
+ # implicit schemas, e.g. pg_catalog
+ search_path_query = """
+ SELECT * FROM unnest(current_schemas(true))"""
+
+ schemata_query = """
+ SELECT nspname
+ FROM pg_catalog.pg_namespace
+ ORDER BY 1 """
+
+ tables_query = """
+ SELECT n.nspname schema_name,
+ c.relname table_name
+ FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_namespace n
+ ON n.oid = c.relnamespace
+ WHERE c.relkind = ANY(%s)
+ ORDER BY 1,2;"""
+
+ databases_query = """
+ SELECT d.datname
+ FROM pg_catalog.pg_database d
+ ORDER BY 1"""
+
+ full_databases_query = """
+ SELECT d.datname as "Name",
+ pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
+ pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
+ d.datcollate as "Collate",
+ d.datctype as "Ctype",
+ pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges"
+ FROM pg_catalog.pg_database d
+ ORDER BY 1"""
+
+ socket_directory_query = """
+ SELECT setting
+ FROM pg_settings
+ WHERE name = 'unix_socket_directories'
+ """
+
+ view_definition_query = """
+ WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid)
+ SELECT nspname, relname, relkind,
+ pg_catalog.pg_get_viewdef(c.oid, true),
+ array_remove(array_remove(c.reloptions,'check_option=local'),
+ 'check_option=cascaded') AS reloptions,
+ CASE
+ WHEN 'check_option=local' = ANY (c.reloptions) THEN 'LOCAL'::text
+ WHEN 'check_option=cascaded' = ANY (c.reloptions) THEN 'CASCADED'::text
+ ELSE NULL
+ END AS checkoption
+ FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid)
+ JOIN v ON (c.oid = v.v_oid)"""
+
+ function_definition_query = """
+ WITH f AS
+ (SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid)
+ SELECT pg_catalog.pg_get_functiondef(f.f_oid)
+ FROM f"""
+
+ def __init__(
+ self,
+ database=None,
+ user=None,
+ password=None,
+ host=None,
+ port=None,
+ dsn=None,
+ **kwargs,
+ ):
+ self._conn_params = {}
+ self._is_virtual_database = None
+ self.conn = None
+ self.dbname = None
+ self.user = None
+ self.password = None
+ self.host = None
+ self.port = None
+ self.server_version = None
+ self.extra_args = None
+ self.connect(database, user, password, host, port, dsn, **kwargs)
+ self.reset_expanded = None
+
+ def is_virtual_database(self):
+ if self._is_virtual_database is None:
+ self._is_virtual_database = self.is_protocol_error()
+ return self._is_virtual_database
+
+ def copy(self):
+ """Returns a clone of the current executor."""
+ return self.__class__(**self._conn_params)
+
+ def connect(
+ self,
+ database=None,
+ user=None,
+ password=None,
+ host=None,
+ port=None,
+ dsn=None,
+ **kwargs,
+ ):
+ conn_params = self._conn_params.copy()
+
+ new_params = {
+ "dbname": database,
+ "user": user,
+ "password": password,
+ "host": host,
+ "port": port,
+ "dsn": dsn,
+ }
+ new_params.update(kwargs)
+
+ if new_params["dsn"]:
+ new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
+
+ if new_params["password"]:
+ new_params["dsn"] = make_conninfo(
+ new_params["dsn"], password=new_params.pop("password")
+ )
+
+ conn_params.update({k: v for k, v in new_params.items() if v})
+
+ if "dsn" in conn_params:
+ other_params = {k: v for k, v in conn_params.items() if k != "dsn"}
+ conn_info = make_conninfo(conn_params["dsn"], **other_params)
+ else:
+ conn_info = make_conninfo(**conn_params)
+ conn = psycopg.connect(conn_info)
+ conn.cursor_factory = ProtocolSafeCursor
+
+ self._conn_params = conn_params
+ if self.conn:
+ self.conn.close()
+ self.conn = conn
+ self.conn.autocommit = True
+
+ # When we connect using a DSN, we don't really know what db,
+ # user, etc. we connected to. Let's read it.
+ # Note: moved this after setting autocommit because of #664.
+ dsn_parameters = conn.info.get_parameters()
+
+ if dsn_parameters:
+ self.dbname = dsn_parameters.get("dbname")
+ self.user = dsn_parameters.get("user")
+ self.host = dsn_parameters.get("host")
+ self.port = dsn_parameters.get("port")
+ else:
+ self.dbname = conn_params.get("database")
+ self.user = conn_params.get("user")
+ self.host = conn_params.get("host")
+ self.port = conn_params.get("port")
+
+ self.password = password
+ self.extra_args = kwargs
+
+ if not self.host:
+ self.host = (
+ "pgbouncer"
+ if self.is_virtual_database()
+ else self.get_socket_directory()
+ )
+
+ self.pid = conn.info.backend_pid
+ self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1")
+ self.server_version = conn.info.parameter_status("server_version") or ""
+
+ # _set_wait_callback(self.is_virtual_database())
+
+ if not self.is_virtual_database():
+ register_typecasters(conn)
+
+ @property
+ def short_host(self):
+ if "," in self.host:
+ host, _, _ = self.host.partition(",")
+ else:
+ host = self.host
+ short_host, _, _ = host.partition(".")
+ return short_host
+
+ def _select_one(self, cur, sql):
+ """
+ Helper method to run a select and retrieve a single field value
+ :param cur: cursor
+ :param sql: string
+ :return: string
+ """
+ cur.execute(sql)
+ return cur.fetchone()
+
+ def failed_transaction(self):
+ return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
+
+ def valid_transaction(self):
+ status = self.conn.info.transaction_status
+ return (
+ status == psycopg.pq.TransactionStatus.ACTIVE
+ or status == psycopg.pq.TransactionStatus.INTRANS
+ )
+
+ def run(
+ self,
+ statement,
+ pgspecial=None,
+ exception_formatter=None,
+ on_error_resume=False,
+ explain_mode=False,
+ ):
+ """Execute the sql in the database and return the results.
+
+ :param statement: A string containing one or more sql statements
+ :param pgspecial: PGSpecial object
+ :param exception_formatter: A callable that accepts an Exception and
+ returns a formatted (title, rows, headers, status) tuple that can
+ act as a query result. If an exception_formatter is not supplied,
+ psycopg2 exceptions are always raised.
+ :param on_error_resume: Bool. If true, queries following an exception
+ (assuming exception_formatter has been supplied) continue to
+ execute.
+
+ :return: Generator yielding tuples containing
+ (title, rows, headers, status, query, success, is_special)
+ """
+
+ # Remove spaces and EOL
+ statement = statement.strip()
+ if not statement: # Empty string
+ yield None, None, None, None, statement, False, False
+
+ # sql parse doesn't split on a comment first + special
+ # so we're going to do it
+
+ removed_comments = []
+ sqlarr = []
+ cleaned_command = ""
+
+ # could skip if statement doesn't match ^-- or ^/*
+ cleaned_command, removed_comments = remove_beginning_comments(statement)
+
+ sqlarr = sqlparse.split(cleaned_command)
+
+ # now re-add the beginning comments if there are any, so that they show up in
+ # log files etc when running these commands
+
+ if len(removed_comments) > 0:
+ sqlarr = removed_comments + sqlarr
+
+ # run each sql query
+ for sql in sqlarr:
+ # Remove spaces, eol and semi-colons.
+ sql = sql.rstrip(";")
+ sql = sqlparse.format(sql, strip_comments=False).strip()
+ if not sql:
+ continue
+ try:
+ if explain_mode:
+ sql = self.explain_prefix() + sql
+ elif pgspecial:
+ # \G is treated specially since we have to set the expanded output.
+ if sql.endswith("\\G"):
+ if not pgspecial.expanded_output:
+ pgspecial.expanded_output = True
+ self.reset_expanded = True
+ sql = sql[:-2].strip()
+
+ # First try to run each query as special
+ _logger.debug("Trying a pgspecial command. sql: %r", sql)
+ try:
+ cur = self.conn.cursor()
+ except psycopg.InterfaceError:
+ # edge case when connection is already closed, but we
+ # don't need cursor for special_cmd.arg_type == NO_QUERY.
+ # See https://github.com/dbcli/pgcli/issues/1014.
+ cur = None
+ try:
+ response = pgspecial.execute(cur, sql)
+ if cur and cur.protocol_error:
+ yield None, None, None, cur.protocol_message, statement, False, False
+ # this would close connection. We should reconnect.
+ self.connect()
+ continue
+ for result in response:
+ # e.g. execute_from_file already appends these
+ if len(result) < 7:
+ yield result + (sql, True, True)
+ else:
+ yield result
+ continue
+ except special.CommandNotFound:
+ pass
+
+ # Not a special command, so execute as normal sql
+ yield self.execute_normal_sql(sql) + (sql, True, False)
+ except psycopg.DatabaseError as e:
+ _logger.error("sql: %r, error: %r", sql, e)
+ _logger.error("traceback: %r", traceback.format_exc())
+
+ if self._must_raise(e) or not exception_formatter:
+ raise
+
+ yield None, None, None, exception_formatter(e), sql, False, False
+
+ if not on_error_resume:
+ break
+ finally:
+ if self.reset_expanded:
+ pgspecial.expanded_output = False
+ self.reset_expanded = None
+
+ def _must_raise(self, e):
+ """Return true if e is an error that should not be caught in ``run``.
+
+ An uncaught error will prompt the user to reconnect; as long as we
+ detect that the connection is still open, we catch the error, as
+ reconnecting won't solve that problem.
+
+ :param e: DatabaseError. An exception raised while executing a query.
+
+ :return: Bool. True if ``run`` must raise this exception.
+
+ """
+ return self.conn.closed != 0
+
+ def execute_normal_sql(self, split_sql):
+ """Returns tuple (title, rows, headers, status)"""
+ _logger.debug("Regular sql statement. sql: %r", split_sql)
+
+ title = ""
+
+ def handle_notices(n):
+ nonlocal title
+ title = f"{n.message_primary}\n{n.message_detail}\n{title}"
+
+ self.conn.add_notice_handler(handle_notices)
+
+ if self.is_virtual_database() and "show help" in split_sql.lower():
+ # see https://github.com/psycopg/psycopg/issues/303
+ # special case "show help" in pgbouncer
+ res = self.conn.pgconn.exec_(split_sql.encode())
+ return title, None, None, res.command_status.decode()
+
+ cur = self.conn.cursor()
+ cur.execute(split_sql)
+
+ # cur.description will be None for operations that do not return
+ # rows.
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ return title, cur, headers, cur.statusmessage
+ elif cur.protocol_error:
+ _logger.debug("Protocol error, unsupported command.")
+ return title, None, None, cur.protocol_message
+ else:
+ _logger.debug("No rows in result.")
+ return title, None, None, cur.statusmessage
+
+ def search_path(self):
+ """Returns the current search path as a list of schema names"""
+
+ try:
+ with self.conn.cursor() as cur:
+ _logger.debug("Search path query. sql: %r", self.search_path_query)
+ cur.execute(self.search_path_query)
+ return [x[0] for x in cur.fetchall()]
+ except psycopg.ProgrammingError:
+ fallback = "SELECT * FROM current_schemas(true)"
+ with self.conn.cursor() as cur:
+ _logger.debug("Search path query. sql: %r", fallback)
+ cur.execute(fallback)
+ return cur.fetchone()[0]
+
+ def view_definition(self, spec):
+ """Returns the SQL defining views described by `spec`"""
+
+ # 2: relkind, v or m (materialized)
+ # 4: reloptions, null
+ # 5: checkoption: local or cascaded
+ with self.conn.cursor() as cur:
+ sql = self.view_definition_query
+ _logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec)
+ try:
+ cur.execute(sql, (spec,))
+ except psycopg.ProgrammingError:
+ raise RuntimeError(f"View {spec} does not exist.")
+ result = ViewDef(*cur.fetchone())
+ if result.relkind == "m":
+ template = "CREATE OR REPLACE MATERIALIZED VIEW {name} AS \n{stmt}"
+ else:
+ template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}"
+ return (
+ psycopg.sql.SQL(template)
+ .format(
+ name=psycopg.sql.Identifier(result.nspname, result.relname),
+ stmt=psycopg.sql.SQL(result.viewdef),
+ )
+ .as_string(self.conn)
+ )
+
+ def function_definition(self, spec):
+ """Returns the SQL defining functions described by `spec`"""
+
+ with self.conn.cursor() as cur:
+ sql = self.function_definition_query
+ _logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec)
+ try:
+ cur.execute(sql, (spec,))
+ result = cur.fetchone()
+ return result[0]
+ except psycopg.ProgrammingError:
+ raise RuntimeError(f"Function {spec} does not exist.")
+
+ def schemata(self):
+ """Returns a list of schema names in the database"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug("Schemata Query. sql: %r", self.schemata_query)
+ cur.execute(self.schemata_query)
+ return [x[0] for x in cur.fetchall()]
+
+ def _relations(self, kinds=("r", "p", "f", "v", "m")):
+ """Get table or view name metadata
+
+ :param kinds: list of postgres relkind filters:
+ 'r' - table
+ 'p' - partitioned table
+ 'f' - foreign table
+ 'v' - view
+ 'm' - materialized view
+ :return: (schema_name, rel_name) tuples
+ """
+
+ with self.conn.cursor() as cur:
+ # sql = cur.mogrify(self.tables_query, kinds)
+ # _logger.debug("Tables Query. sql: %r", sql)
+ cur.execute(self.tables_query, [kinds])
+ yield from cur
+
+ def tables(self):
+ """Yields (schema_name, table_name) tuples"""
+ yield from self._relations(kinds=["r", "p", "f"])
+
+ def views(self):
+ """Yields (schema_name, view_name) tuples.
+
+ Includes both views and and materialized views
+ """
+ yield from self._relations(kinds=["v", "m"])
+
+ def _columns(self, kinds=("r", "p", "f", "v", "m")):
+ """Get column metadata for tables and views
+
+ :param kinds: kinds: list of postgres relkind filters:
+ 'r' - table
+ 'p' - partitioned table
+ 'f' - foreign table
+ 'v' - view
+ 'm' - materialized view
+ :return: list of (schema_name, relation_name, column_name, column_type) tuples
+ """
+
+ if self.conn.info.server_version >= 80400:
+ columns_query = """
+ SELECT nsp.nspname schema_name,
+ cls.relname table_name,
+ att.attname column_name,
+ att.atttypid::regtype::text type_name,
+ att.atthasdef AS has_default,
+ pg_catalog.pg_get_expr(def.adbin, def.adrelid, true) as default
+ FROM pg_catalog.pg_attribute att
+ INNER JOIN pg_catalog.pg_class cls
+ ON att.attrelid = cls.oid
+ INNER JOIN pg_catalog.pg_namespace nsp
+ ON cls.relnamespace = nsp.oid
+ LEFT OUTER JOIN pg_attrdef def
+ ON def.adrelid = att.attrelid
+ AND def.adnum = att.attnum
+ WHERE cls.relkind = ANY(%s)
+ AND NOT att.attisdropped
+ AND att.attnum > 0
+ ORDER BY 1, 2, att.attnum"""
+ else:
+ columns_query = """
+ SELECT nsp.nspname schema_name,
+ cls.relname table_name,
+ att.attname column_name,
+ typ.typname type_name,
+ NULL AS has_default,
+ NULL AS default
+ FROM pg_catalog.pg_attribute att
+ INNER JOIN pg_catalog.pg_class cls
+ ON att.attrelid = cls.oid
+ INNER JOIN pg_catalog.pg_namespace nsp
+ ON cls.relnamespace = nsp.oid
+ INNER JOIN pg_catalog.pg_type typ
+ ON typ.oid = att.atttypid
+ WHERE cls.relkind = ANY(%s)
+ AND NOT att.attisdropped
+ AND att.attnum > 0
+ ORDER BY 1, 2, att.attnum"""
+
+ with self.conn.cursor() as cur:
+ # sql = cur.mogrify(columns_query, kinds)
+ # _logger.debug("Columns Query. sql: %r", sql)
+ cur.execute(columns_query, [kinds])
+ yield from cur
+
+ def table_columns(self):
+ yield from self._columns(kinds=["r", "p", "f"])
+
+ def view_columns(self):
+ yield from self._columns(kinds=["v", "m"])
+
+ def databases(self):
+ with self.conn.cursor() as cur:
+ _logger.debug("Databases Query. sql: %r", self.databases_query)
+ cur.execute(self.databases_query)
+ return [x[0] for x in cur.fetchall()]
+
+ def full_databases(self):
+ with self.conn.cursor() as cur:
+ _logger.debug("Databases Query. sql: %r", self.full_databases_query)
+ cur.execute(self.full_databases_query)
+ headers = [x[0] for x in cur.description]
+ return cur.fetchall(), headers, cur.statusmessage
+
+ def is_protocol_error(self):
+ query = "SELECT 1"
+ with self.conn.cursor() as cur:
+ _logger.debug("Simple Query. sql: %r", query)
+ cur.execute(query)
+ return bool(cur.protocol_error)
+
+ def get_socket_directory(self):
+ with self.conn.cursor() as cur:
+ _logger.debug(
+ "Socket directory Query. sql: %r", self.socket_directory_query
+ )
+ cur.execute(self.socket_directory_query)
+ result = cur.fetchone()
+ return result[0] if result else ""
+
+ def foreignkeys(self):
+ """Yields ForeignKey named tuples"""
+
+ if self.conn.info.server_version < 90000:
+ return
+
+ with self.conn.cursor() as cur:
+ query = """
+ SELECT s_p.nspname AS parentschema,
+ t_p.relname AS parenttable,
+ unnest((
+ select
+ array_agg(attname ORDER BY i)
+ from
+ (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
+ JOIN pg_catalog.pg_attribute c USING(attnum)
+ WHERE c.attrelid = fk.confrelid
+ )) AS parentcolumn,
+ s_c.nspname AS childschema,
+ t_c.relname AS childtable,
+ unnest((
+ select
+ array_agg(attname ORDER BY i)
+ from
+ (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
+ JOIN pg_catalog.pg_attribute c USING(attnum)
+ WHERE c.attrelid = fk.conrelid
+ )) AS childcolumn
+ FROM pg_catalog.pg_constraint fk
+ JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
+ JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
+ JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
+ JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
+ WHERE fk.contype = 'f';
+ """
+ _logger.debug("Functions Query. sql: %r", query)
+ cur.execute(query)
+ for row in cur:
+ yield ForeignKey(*row)
+
+ def functions(self):
+ """Yields FunctionMetadata named tuples"""
+
+ if self.conn.info.server_version >= 110000:
+ query = """
+ SELECT n.nspname schema_name,
+ p.proname func_name,
+ p.proargnames,
+ COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
+ p.proargmodes,
+ prorettype::regtype::text return_type,
+ p.prokind = 'a' is_aggregate,
+ p.prokind = 'w' is_window,
+ p.proretset is_set_returning,
+ d.deptype = 'e' is_extension,
+ pg_get_expr(proargdefaults, 0) AS arg_defaults
+ FROM pg_catalog.pg_proc p
+ INNER JOIN pg_catalog.pg_namespace n
+ ON n.oid = p.pronamespace
+ LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
+ WHERE p.prorettype::regtype != 'trigger'::regtype
+ ORDER BY 1, 2
+ """
+ elif self.conn.info.server_version > 90000:
+ query = """
+ SELECT n.nspname schema_name,
+ p.proname func_name,
+ p.proargnames,
+ COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
+ p.proargmodes,
+ prorettype::regtype::text return_type,
+ p.proisagg is_aggregate,
+ p.proiswindow is_window,
+ p.proretset is_set_returning,
+ d.deptype = 'e' is_extension,
+ pg_get_expr(proargdefaults, 0) AS arg_defaults
+ FROM pg_catalog.pg_proc p
+ INNER JOIN pg_catalog.pg_namespace n
+ ON n.oid = p.pronamespace
+ LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
+ WHERE p.prorettype::regtype != 'trigger'::regtype
+ ORDER BY 1, 2
+ """
+ elif self.conn.info.server_version >= 80400:
+ query = """
+ SELECT n.nspname schema_name,
+ p.proname func_name,
+ p.proargnames,
+ COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
+ p.proargmodes,
+ prorettype::regtype::text,
+ p.proisagg is_aggregate,
+ false is_window,
+ p.proretset is_set_returning,
+ d.deptype = 'e' is_extension,
+ NULL AS arg_defaults
+ FROM pg_catalog.pg_proc p
+ INNER JOIN pg_catalog.pg_namespace n
+ ON n.oid = p.pronamespace
+ LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
+ WHERE p.prorettype::regtype != 'trigger'::regtype
+ ORDER BY 1, 2
+ """
+ else:
+ query = """
+ SELECT n.nspname schema_name,
+ p.proname func_name,
+ p.proargnames,
+ NULL arg_types,
+ NULL arg_modes,
+ '' ret_type,
+ p.proisagg is_aggregate,
+ false is_window,
+ p.proretset is_set_returning,
+ d.deptype = 'e' is_extension,
+ NULL AS arg_defaults
+ FROM pg_catalog.pg_proc p
+ INNER JOIN pg_catalog.pg_namespace n
+ ON n.oid = p.pronamespace
+ LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
+ WHERE p.prorettype::regtype != 'trigger'::regtype
+ ORDER BY 1, 2
+ """
+
+ with self.conn.cursor() as cur:
+ _logger.debug("Functions Query. sql: %r", query)
+ cur.execute(query)
+ for row in cur:
+ yield FunctionMetadata(*row)
+
+ def datatypes(self):
+ """Yields tuples of (schema_name, type_name)"""
+
+ with self.conn.cursor() as cur:
+ if self.conn.info.server_version > 90000:
+ query = """
+ SELECT n.nspname schema_name,
+ t.typname type_name
+ FROM pg_catalog.pg_type t
+ INNER JOIN pg_catalog.pg_namespace n
+ ON n.oid = t.typnamespace
+ WHERE ( t.typrelid = 0 -- non-composite types
+ OR ( -- composite type, but not a table
+ SELECT c.relkind = 'c'
+ FROM pg_catalog.pg_class c
+ WHERE c.oid = t.typrelid
+ )
+ )
+ AND NOT EXISTS( -- ignore array types
+ SELECT 1
+ FROM pg_catalog.pg_type el
+ WHERE el.oid = t.typelem AND el.typarray = t.oid
+ )
+ AND n.nspname <> 'pg_catalog'
+ AND n.nspname <> 'information_schema'
+ ORDER BY 1, 2;
+ """
+ else:
+ query = """
+ SELECT n.nspname schema_name,
+ pg_catalog.format_type(t.oid, NULL) type_name
+ FROM pg_catalog.pg_type t
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+ WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
+ AND t.typname !~ '^_'
+ AND n.nspname <> 'pg_catalog'
+ AND n.nspname <> 'information_schema'
+ AND pg_catalog.pg_type_is_visible(t.oid)
+ ORDER BY 1, 2;
+ """
+ _logger.debug("Datatypes Query. sql: %r", query)
+ cur.execute(query)
+ yield from cur
+
+ def casing(self):
+ """Yields the most common casing for names used in db functions"""
+ with self.conn.cursor() as cur:
+ query = r"""
+ WITH Words AS (
+ SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1)
+ FROM pg_catalog.pg_proc P
+ JOIN pg_catalog.pg_namespace N ON N.oid = P.pronamespace
+ JOIN pg_catalog.pg_language L ON L.oid = P.prolang
+ WHERE L.lanname IN ('sql', 'plpgsql')
+ AND N.nspname NOT IN ('pg_catalog', 'information_schema')
+ GROUP BY Word
+ ),
+ OrderWords AS (
+ SELECT Word,
+ ROW_NUMBER() OVER(PARTITION BY LOWER(Word) ORDER BY Count DESC)
+ FROM Words
+ WHERE Word ~* '.*[a-z].*'
+ ),
+ Names AS (
+ --Column names
+ SELECT attname AS Name
+ FROM pg_catalog.pg_attribute
+ UNION -- Table/view names
+ SELECT relname
+ FROM pg_catalog.pg_class
+ UNION -- Function names
+ SELECT proname
+ FROM pg_catalog.pg_proc
+ UNION -- Type names
+ SELECT typname
+ FROM pg_catalog.pg_type
+ UNION -- Schema names
+ SELECT nspname
+ FROM pg_catalog.pg_namespace
+ UNION -- Parameter names
+ SELECT unnest(proargnames)
+ FROM pg_proc
+ )
+ SELECT Word
+ FROM OrderWords
+ WHERE LOWER(Word) IN (SELECT Name FROM Names)
+ AND Row_Number = 1;
+ """
+ _logger.debug("Casing Query. sql: %r", query)
+ cur.execute(query)
+ for row in cur:
+ yield row[0]
+
+ def explain_prefix(self):
+ return "EXPLAIN (ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) "
diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py
new file mode 100644
index 0000000..77874f4
--- /dev/null
+++ b/pgcli/pgstyle.py
@@ -0,0 +1,116 @@
+import logging
+
+import pygments.styles
+from pygments.token import string_to_tokentype, Token
+from pygments.style import Style as PygmentsStyle
+from pygments.util import ClassNotFound
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.styles import merge_styles, Style
+
+logger = logging.getLogger(__name__)
+
+# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
+TOKEN_TO_PROMPT_STYLE = {
+ Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
+ Token.Menu.Completions.Completion: "completion-menu.completion",
+ Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
+ Token.Menu.Completions.Meta: "completion-menu.meta.completion",
+ Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
+ Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess
+ Token.Menu.Completions.ProgressBar: "scrollbar", # best guess
+ Token.SelectedText: "selected",
+ Token.SearchMatch: "search",
+ Token.SearchMatch.Current: "search.current",
+ Token.Toolbar: "bottom-toolbar",
+ Token.Toolbar.Off: "bottom-toolbar.off",
+ Token.Toolbar.On: "bottom-toolbar.on",
+ Token.Toolbar.Search: "search-toolbar",
+ Token.Toolbar.Search.Text: "search-toolbar.text",
+ Token.Toolbar.System: "system-toolbar",
+ Token.Toolbar.Arg: "arg-toolbar",
+ Token.Toolbar.Arg.Text: "arg-toolbar.text",
+ Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
+ Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
+ Token.Output.Header: "output.header",
+ Token.Output.OddRow: "output.odd-row",
+ Token.Output.EvenRow: "output.even-row",
+ Token.Output.Null: "output.null",
+ Token.Literal.String: "literal.string",
+ Token.Literal.Number: "literal.number",
+ Token.Keyword: "keyword",
+ Token.Prompt: "prompt",
+ Token.Continuation: "continuation",
+}
+
+# reverse dict for cli_helpers, because they still expect Pygments tokens.
+PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
+
+
+def parse_pygments_style(token_name, style_object, style_dict):
+ """Parse token type and style string.
+
+ :param token_name: str name of Pygments token. Example: "Token.String"
+ :param style_object: pygments.style.Style instance to use as base
+ :param style_dict: dict of token names and their styles, customized to this cli
+
+ """
+ token_type = string_to_tokentype(token_name)
+ try:
+ other_token_type = string_to_tokentype(style_dict[token_name])
+ return token_type, style_object.styles[other_token_type]
+ except AttributeError:
+ return token_type, style_dict[token_name]
+
+
+def style_factory(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name)
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name("native")
+
+ prompt_styles = []
+ # prompt-toolkit used pygments tokens for styling before, switched to style
+ # names in 2.0. Convert old token types to new style names, for backwards compatibility.
+ for token in cli_style:
+ if token.startswith("Token."):
+ # treat as pygments token (1.0)
+ token_type, style_value = parse_pygments_style(token, style, cli_style)
+ if token_type in TOKEN_TO_PROMPT_STYLE:
+ prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
+ prompt_styles.append((prompt_style, style_value))
+ else:
+ # we don't want to support tokens anymore
+ logger.error("Unhandled style / class name: %s", token)
+ else:
+ # treat as prompt style name (2.0). See default style names here:
+ # https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/src/prompt_toolkit/styles/defaults.py
+ prompt_styles.append((token, cli_style[token]))
+
+ override_style = Style([("bottom-toolbar", "noreverse")])
+ return merge_styles(
+ [style_from_pygments_cls(style), override_style, Style(prompt_styles)]
+ )
+
+
+def style_factory_output(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name).styles
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name("native").styles
+
+ for token in cli_style:
+ if token.startswith("Token."):
+ token_type, style_value = parse_pygments_style(token, style, cli_style)
+ style.update({token_type: style_value})
+ elif token in PROMPT_STYLE_TO_TOKEN:
+ token_type = PROMPT_STYLE_TO_TOKEN[token]
+ style.update({token_type: cli_style[token]})
+ else:
+ # TODO: cli helpers will have to switch to ptk.Style
+ logger.error("Unhandled style / class name: %s", token)
+
+ class OutputStyle(PygmentsStyle):
+ default_style = ""
+ styles = style
+
+ return OutputStyle
diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py
new file mode 100644
index 0000000..4a12ff4
--- /dev/null
+++ b/pgcli/pgtoolbar.py
@@ -0,0 +1,71 @@
+from prompt_toolkit.key_binding.vi_state import InputMode
+from prompt_toolkit.application import get_app
+
+vi_modes = {
+ InputMode.INSERT: "I",
+ InputMode.NAVIGATION: "N",
+ InputMode.REPLACE: "R",
+ InputMode.INSERT_MULTIPLE: "M",
+}
+# REPLACE_SINGLE is available in prompt_toolkit >= 3.0.6
+if "REPLACE_SINGLE" in {e.name for e in InputMode}:
+ vi_modes[InputMode.REPLACE_SINGLE] = "R"
+
+
+def _get_vi_mode():
+ return vi_modes[get_app().vi_state.input_mode]
+
+
+def create_toolbar_tokens_func(pgcli):
+ """Return a function that generates the toolbar tokens."""
+
+ def get_toolbar_tokens():
+ result = []
+ result.append(("class:bottom-toolbar", " "))
+
+ if pgcli.completer.smart_completion:
+ result.append(("class:bottom-toolbar.on", "[F2] Smart Completion: ON "))
+ else:
+ result.append(("class:bottom-toolbar.off", "[F2] Smart Completion: OFF "))
+
+ if pgcli.multi_line:
+ result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON "))
+ else:
+ result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF "))
+
+ if pgcli.multi_line:
+ if pgcli.multiline_mode == "safe":
+ result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) "))
+ else:
+ result.append(
+ ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")
+ )
+
+ if pgcli.vi_mode:
+ result.append(
+ ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ") ")
+ )
+ else:
+ result.append(("class:bottom-toolbar", "[F4] Emacs-mode "))
+
+ if pgcli.explain_mode:
+ result.append(("class:bottom-toolbar", "[F5] Explain: ON "))
+ else:
+ result.append(("class:bottom-toolbar", "[F5] Explain: OFF "))
+
+ if pgcli.pgexecute.failed_transaction():
+ result.append(
+ ("class:bottom-toolbar.transaction.failed", " Failed transaction")
+ )
+
+ if pgcli.pgexecute.valid_transaction():
+ result.append(
+ ("class:bottom-toolbar.transaction.valid", " Transaction")
+ )
+
+ if pgcli.completion_refresher.is_refreshing():
+ result.append(("class:bottom-toolbar", " Refreshing completions..."))
+
+ return result
+
+ return get_toolbar_tokens
diff --git a/pgcli/pyev.py b/pgcli/pyev.py
new file mode 100644
index 0000000..2886c9c
--- /dev/null
+++ b/pgcli/pyev.py
@@ -0,0 +1,439 @@
+import textwrap
+import re
+from click import style as color
+
+DESCRIPTIONS = {
+ "Append": "Used in a UNION to merge multiple record sets by appending them together.",
+ "Limit": "Returns a specified number of rows from a record set.",
+ "Sort": "Sorts a record set based on the specified sort key.",
+ "Nested Loop": "Merges two record sets by looping through every record in the first set and trying to find a match in the second set. All matching records are returned.",
+ "Merge Join": "Merges two record sets by first sorting them on a join key.",
+ "Hash": "Generates a hash table from the records in the input recordset. Hash is used by Hash Join.",
+ "Hash Join": "Joins to record sets by hashing one of them (using a Hash Scan).",
+ "Aggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()).",
+ "Hashaggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()). Hash Aggregate uses a hash to first organize the records by a key.",
+ "Sequence Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).",
+ "Seq Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).",
+ "Index Scan": "Finds relevant records based on an Index. Index Scans perform 2 read operations: one to read the index and another to read the actual value from the table.",
+ "Index Only Scan": "Finds relevant records based on an Index. Index Only Scans perform a single read operation from the index and do not read from the corresponding table.",
+ "Bitmap Heap Scan": "Searches through the pages returned by the Bitmap Index Scan for relevant rows.",
+ "Bitmap Index Scan": "Uses a Bitmap Index (index which uses 1 bit per page) to find all relevant pages. Results of this node are fed to the Bitmap Heap Scan.",
+ "CTEScan": "Performs a sequential scan of Common Table Expression (CTE) query results. Note that results of a CTE are materialized (calculated and temporarily stored).",
+ "ProjectSet": "ProjectSet appears when the SELECT or ORDER BY clause of the query. They basically just execute the set-returning function(s) for each tuple until none of the functions return any more records.",
+ "Result": "Returns result",
+}
+
+
+class Visualizer:
+ def __init__(self, terminal_width=100, color=True):
+ self.color = color
+ self.terminal_width = terminal_width
+ self.string_lines = []
+
+ def load(self, explain_dict):
+ self.plan = explain_dict.pop("Plan")
+ self.explain = explain_dict
+ self.process_all()
+ self.generate_lines()
+
+ def process_all(self):
+ self.plan = self.process_plan(self.plan)
+ self.plan = self.calculate_outlier_nodes(self.plan)
+
+ #
+ def process_plan(self, plan):
+ plan = self.calculate_planner_estimate(plan)
+ plan = self.calculate_actuals(plan)
+ self.calculate_maximums(plan)
+ #
+ for index in range(len(plan.get("Plans", []))):
+ _plan = plan["Plans"][index]
+ plan["Plans"][index] = self.process_plan(_plan)
+ return plan
+
+ def prefix_format(self, v):
+ if self.color:
+ return color(v, fg="bright_black")
+ return v
+
+ def tag_format(self, v):
+ if self.color:
+ return color(v, fg="white", bg="red")
+ return v
+
+ def muted_format(self, v):
+ if self.color:
+ return color(v, fg="bright_black")
+ return v
+
+ def bold_format(self, v):
+ if self.color:
+ return color(v, fg="white")
+ return v
+
+ def good_format(self, v):
+ if self.color:
+ return color(v, fg="green")
+ return v
+
+ def warning_format(self, v):
+ if self.color:
+ return color(v, fg="yellow")
+ return v
+
+ def critical_format(self, v):
+ if self.color:
+ return color(v, fg="red")
+ return v
+
+ def output_format(self, v):
+ if self.color:
+ return color(v, fg="cyan")
+ return v
+
+ def calculate_planner_estimate(self, plan):
+ plan["Planner Row Estimate Factor"] = 0
+ plan["Planner Row Estimate Direction"] = "Under"
+
+ if plan["Plan Rows"] == plan["Actual Rows"]:
+ return plan
+
+ if plan["Plan Rows"] != 0:
+ plan["Planner Row Estimate Factor"] = (
+ plan["Actual Rows"] / plan["Plan Rows"]
+ )
+
+ if plan["Planner Row Estimate Factor"] < 10:
+ plan["Planner Row Estimate Factor"] = 0
+ plan["Planner Row Estimate Direction"] = "Over"
+ if plan["Actual Rows"] != 0:
+ plan["Planner Row Estimate Factor"] = (
+ plan["Plan Rows"] / plan["Actual Rows"]
+ )
+ return plan
+
+ #
+ def calculate_actuals(self, plan):
+ plan["Actual Duration"] = plan["Actual Total Time"]
+ plan["Actual Cost"] = plan["Total Cost"]
+
+ for child in plan.get("Plans", []):
+ if child["Node Type"] != "CTEScan":
+ plan["Actual Duration"] = (
+ plan["Actual Duration"] - child["Actual Total Time"]
+ )
+ plan["Actual Cost"] = plan["Actual Cost"] - child["Total Cost"]
+
+ if plan["Actual Cost"] < 0:
+ plan["Actual Cost"] = 0
+
+ plan["Actual Duration"] = plan["Actual Duration"] * plan["Actual Loops"]
+ return plan
+
+ def calculate_outlier_nodes(self, plan):
+ plan["Costliest"] = plan["Actual Cost"] == self.explain["Max Cost"]
+ plan["Largest"] = plan["Actual Rows"] == self.explain["Max Rows"]
+ plan["Slowest"] = plan["Actual Duration"] == self.explain["Max Duration"]
+
+ for index in range(len(plan.get("Plans", []))):
+ _plan = plan["Plans"][index]
+ plan["Plans"][index] = self.calculate_outlier_nodes(_plan)
+ return plan
+
+ def calculate_maximums(self, plan):
+ if not self.explain.get("Max Rows"):
+ self.explain["Max Rows"] = plan["Actual Rows"]
+ elif self.explain.get("Max Rows") < plan["Actual Rows"]:
+ self.explain["Max Rows"] = plan["Actual Rows"]
+
+ if not self.explain.get("Max Cost"):
+ self.explain["Max Cost"] = plan["Actual Cost"]
+ elif self.explain.get("Max Cost") < plan["Actual Cost"]:
+ self.explain["Max Cost"] = plan["Actual Cost"]
+
+ if not self.explain.get("Max Duration"):
+ self.explain["Max Duration"] = plan["Actual Duration"]
+ elif self.explain.get("Max Duration") < plan["Actual Duration"]:
+ self.explain["Max Duration"] = plan["Actual Duration"]
+
+ if not self.explain.get("Total Cost"):
+ self.explain["Total Cost"] = plan["Actual Cost"]
+ elif self.explain.get("Total Cost") < plan["Actual Cost"]:
+ self.explain["Total Cost"] = plan["Actual Cost"]
+
+ #
+ def duration_to_string(self, value):
+ if value < 1:
+ return self.good_format("<1 ms")
+ elif value < 100:
+ return self.good_format("%.2f ms" % value)
+ elif value < 1000:
+ return self.warning_format("%.2f ms" % value)
+ elif value < 60000:
+ return self.critical_format(
+ "%.2f s" % (value / 1000.0),
+ )
+ else:
+ return self.critical_format(
+ "%.2f m" % (value / 60000.0),
+ )
+
+ # }
+ #
+ def format_details(self, plan):
+ details = []
+
+ if plan.get("Scan Direction"):
+ details.append(plan["Scan Direction"])
+
+ if plan.get("Strategy"):
+ details.append(plan["Strategy"])
+
+ if len(details) > 0:
+ return self.muted_format(" [%s]" % ", ".join(details))
+
+ return ""
+
+ def format_tags(self, plan):
+ tags = []
+
+ if plan["Slowest"]:
+ tags.append(self.tag_format("slowest"))
+ if plan["Costliest"]:
+ tags.append(self.tag_format("costliest"))
+ if plan["Largest"]:
+ tags.append(self.tag_format("largest"))
+ if plan.get("Planner Row Estimate Factor", 0) >= 100:
+ tags.append(self.tag_format("bad estimate"))
+
+ return " ".join(tags)
+
+ def get_terminator(self, index, plan):
+ if index == 0:
+ if len(plan.get("Plans", [])) == 0:
+ return "⌡► "
+ else:
+ return "├► "
+ else:
+ if len(plan.get("Plans", [])) == 0:
+ return " "
+ else:
+ return "│ "
+
+ def wrap_string(self, line, width):
+ if width == 0:
+ return [line]
+ return textwrap.wrap(line, width)
+
+ def intcomma(self, value):
+ sep = ","
+ if not isinstance(value, str):
+ value = int(value)
+
+ orig = str(value)
+
+ new = re.sub(r"^(-?\d+)(\d{3})", rf"\g<1>{sep}\g<2>", orig)
+ if orig == new:
+ return new
+ else:
+ return self.intcomma(new)
+
+ def output_fn(self, current_prefix, string):
+ return "%s%s" % (self.prefix_format(current_prefix), string)
+
+ def create_lines(self, plan, prefix, depth, width, last_child):
+ current_prefix = prefix
+ self.string_lines.append(
+ self.output_fn(current_prefix, self.prefix_format("│"))
+ )
+
+ joint = "├"
+ if last_child:
+ joint = "└"
+ #
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s%s %s"
+ % (
+ self.prefix_format(joint + "─⌠"),
+ self.bold_format(plan["Node Type"]),
+ self.format_details(plan),
+ self.format_tags(plan),
+ ),
+ )
+ )
+ #
+ if last_child:
+ prefix += " "
+ else:
+ prefix += "│ "
+
+ current_prefix = prefix + "│ "
+
+ cols = width - len(current_prefix)
+
+ for line in self.wrap_string(
+ DESCRIPTIONS.get(plan["Node Type"], "Not found : %s" % plan["Node Type"]),
+ cols,
+ ):
+ self.string_lines.append(
+ self.output_fn(current_prefix, "%s" % self.muted_format(line))
+ )
+ #
+ if plan.get("Actual Duration"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "○ %s %s (%.0f%%)"
+ % (
+ "Duration:",
+ self.duration_to_string(plan["Actual Duration"]),
+ (plan["Actual Duration"] / self.explain["Execution Time"])
+ * 100,
+ ),
+ )
+ )
+
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "○ %s %s (%.0f%%)"
+ % (
+ "Cost:",
+ self.intcomma(plan["Actual Cost"]),
+ (plan["Actual Cost"] / self.explain["Total Cost"]) * 100,
+ ),
+ )
+ )
+
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "○ %s %s" % ("Rows:", self.intcomma(plan["Actual Rows"])),
+ )
+ )
+
+ current_prefix = current_prefix + " "
+
+ if plan.get("Join Type"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s" % (plan["Join Type"], self.muted_format("join")),
+ )
+ )
+
+ if plan.get("Relation Name"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s.%s"
+ % (
+ self.muted_format("on"),
+ plan.get("Schema", "unknown"),
+ plan["Relation Name"],
+ ),
+ )
+ )
+
+ if plan.get("Index Name"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s" % (self.muted_format("using"), plan["Index Name"]),
+ )
+ )
+
+ if plan.get("Index Condition"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s" % (self.muted_format("condition"), plan["Index Condition"]),
+ )
+ )
+
+ if plan.get("Filter"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s %s"
+ % (
+ self.muted_format("filter"),
+ plan["Filter"],
+ self.muted_format(
+ "[-%s rows]" % self.intcomma(plan["Rows Removed by Filter"])
+ ),
+ ),
+ )
+ )
+
+ if plan.get("Hash Condition"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s" % (self.muted_format("on"), plan["Hash Condition"]),
+ )
+ )
+
+ if plan.get("CTE Name"):
+ self.string_lines.append(
+ self.output_fn(current_prefix, "CTE %s" % plan["CTE Name"])
+ )
+
+ if plan.get("Planner Row Estimate Factor") != 0:
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %sestimated %s %.2fx"
+ % (
+ self.muted_format("rows"),
+ plan["Planner Row Estimate Direction"],
+ self.muted_format("by"),
+ plan["Planner Row Estimate Factor"],
+ ),
+ )
+ )
+
+ current_prefix = prefix
+
+ if len(plan.get("Output", [])) > 0:
+ for index, line in enumerate(
+ self.wrap_string(" + ".join(plan["Output"]), cols)
+ ):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ self.prefix_format(self.get_terminator(index, plan))
+ + self.output_format(line),
+ )
+ )
+
+ for index, nested_plan in enumerate(plan.get("Plans", [])):
+ self.create_lines(
+ nested_plan, prefix, depth + 1, width, index == len(plan["Plans"]) - 1
+ )
+
+ def generate_lines(self):
+ self.string_lines = [
+ "○ Total Cost: %s" % self.intcomma(self.explain["Total Cost"]),
+ "○ Planning Time: %s"
+ % self.duration_to_string(self.explain["Planning Time"]),
+ "○ Execution Time: %s"
+ % self.duration_to_string(self.explain["Execution Time"]),
+ self.prefix_format("┬"),
+ ]
+ self.create_lines(
+ self.plan,
+ "",
+ 0,
+ self.terminal_width,
+ len(self.plan.get("Plans", [])) == 1,
+ )
+
+ def get_list(self):
+ return "\n".join(self.string_lines)
+
+ def print(self):
+ for lin in self.string_lines:
+ print(lin)
diff --git a/post-install b/post-install
new file mode 100644
index 0000000..d516a3f
--- /dev/null
+++ b/post-install
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+echo "Setting up symlink to pgcli"
+ln -sf /usr/share/pgcli/bin/pgcli /usr/local/bin/pgcli
diff --git a/post-remove b/post-remove
new file mode 100644
index 0000000..1013eb4
--- /dev/null
+++ b/post-remove
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+echo "Removing symlink to pgcli"
+rm /usr/local/bin/pgcli
diff --git a/pylintrc b/pylintrc
new file mode 100644
index 0000000..4aa448d
--- /dev/null
+++ b/pylintrc
@@ -0,0 +1,2 @@
+[MESSAGES CONTROL]
+disable=missing-docstring,invalid-name \ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..8477d72
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,21 @@
+[tool.black]
+line-length = 88
+target-version = ['py38']
+include = '\.pyi?$'
+exclude = '''
+/(
+ \.eggs
+ | \.git
+ | \.hg
+ | \.mypy_cache
+ | \.tox
+ | \.venv
+ | \.cache
+ | \.pytest_cache
+ | _build
+ | buck-out
+ | build
+ | dist
+ | tests/data
+)/
+'''
diff --git a/release.py b/release.py
new file mode 100644
index 0000000..42a72a9
--- /dev/null
+++ b/release.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python
+"""A script to publish a release of pgcli to PyPI."""
+
+import io
+from optparse import OptionParser
+import re
+import subprocess
+import sys
+
+import click
+
+DEBUG = False
+CONFIRM_STEPS = False
+DRY_RUN = False
+
+
+def skip_step():
+ """
+ Asks for user's response whether to run a step. Default is yes.
+ :return: boolean
+ """
+ global CONFIRM_STEPS
+
+ if CONFIRM_STEPS:
+ return not click.confirm("--- Run this step?", default=True)
+ return False
+
+
+def run_step(*args):
+ """
+ Prints out the command and asks if it should be run.
+ If yes (default), runs it.
+ :param args: list of strings (command and args)
+ """
+ global DRY_RUN
+
+ cmd = args
+ print(" ".join(cmd))
+ if skip_step():
+ print("--- Skipping...")
+ elif DRY_RUN:
+ print("--- Pretending to run...")
+ else:
+ subprocess.check_output(cmd)
+
+
+def version(version_file):
+ _version_re = re.compile(
+ r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)'
+ )
+
+ with io.open(version_file, encoding="utf-8") as f:
+ ver = _version_re.search(f.read()).group("version")
+
+ return ver
+
+
+def commit_for_release(version_file, ver):
+ run_step("git", "reset")
+ run_step("git", "add", "-u")
+ run_step("git", "commit", "--message", "Releasing version {}".format(ver))
+
+
+def create_git_tag(tag_name):
+ run_step("git", "tag", tag_name)
+
+
+def create_distribution_files():
+ run_step("python", "setup.py", "clean", "--all", "sdist", "bdist_wheel")
+
+
+def upload_distribution_files():
+ run_step("twine", "upload", "dist/*")
+
+
+def push_to_github():
+ run_step("git", "push", "origin", "main")
+
+
+def push_tags_to_github():
+ run_step("git", "push", "--tags", "origin")
+
+
+def checklist(questions):
+ for question in questions:
+ if not click.confirm("--- {}".format(question), default=False):
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ if DEBUG:
+ subprocess.check_output = lambda x: x
+
+ checks = [
+ "Have you updated the AUTHORS file?",
+ "Have you updated the `Usage` section of the README?",
+ ]
+ checklist(checks)
+
+ ver = version("pgcli/__init__.py")
+ print("Releasing Version:", ver)
+
+ parser = OptionParser()
+ parser.add_option(
+ "-c",
+ "--confirm-steps",
+ action="store_true",
+ dest="confirm_steps",
+ default=False,
+ help=(
+ "Confirm every step. If the step is not " "confirmed, it will be skipped."
+ ),
+ )
+ parser.add_option(
+ "-d",
+ "--dry-run",
+ action="store_true",
+ dest="dry_run",
+ default=False,
+ help="Print out, but not actually run any steps.",
+ )
+
+ popts, pargs = parser.parse_args()
+ CONFIRM_STEPS = popts.confirm_steps
+ DRY_RUN = popts.dry_run
+
+ if not click.confirm("Are you sure?", default=False):
+ sys.exit(1)
+
+ commit_for_release("pgcli/__init__.py", ver)
+ create_git_tag("v{}".format(ver))
+ create_distribution_files()
+ push_to_github()
+ push_tags_to_github()
+ upload_distribution_files()
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..15505a7
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,13 @@
+pytest>=2.7.0
+tox>=1.9.2
+behave>=1.2.4
+black>=23.3.0
+pexpect==3.3; platform_system != "Windows"
+pre-commit>=1.16.0
+coverage>=5.0.4
+codecov>=1.5.1
+docutils>=0.13.1
+autopep8>=1.3.3
+twine>=1.11.0
+wheel>=0.33.6
+sshtunnel>=0.4.0
diff --git a/sanity_checks.txt b/sanity_checks.txt
new file mode 100644
index 0000000..d8a4898
--- /dev/null
+++ b/sanity_checks.txt
@@ -0,0 +1,37 @@
+# vi: ft=vimwiki
+
+* Launch pgcli with different inputs.
+ * pgcli test_db
+ * pgcli postgres://localhost/test_db
+ * pgcli postgres://localhost:5432/test_db
+ * pgcli postgres://amjith@localhost:5432/test_db
+ * pgcli postgres://amjith:password@localhost:5432/test_db
+ * pgcli non-existent-db
+
+* Test special command
+ * \d
+ * \d table_name
+ * \dt
+ * \l
+ * \c amjith
+ * \q
+
+* Simple execution:
+ 1 Execute a simple 'select * from users;' test that will pass.
+ 2 Execute a syntax error: 'insert into users ( ;'
+ 3 Execute a simple test from step 1 again to see if it still passes.
+ * Change the database and try steps 1 - 3.
+
+* Test smart-completion
+ * Sele - Must auto-complete to SELECT
+ * SELECT * FROM - Must list the table names.
+ * INSERT INTO - Must list table names.
+ * \d <tab> - Must list table names.
+ * \c <tab> - Database names.
+ * SELECT * FROM table_name WHERE <tab> - column names (all of it).
+
+* Test naive-completion - turn off smart completion (using F2 key after launch)
+ * Sele - autocomplete to select.
+ * SELECT * FROM - autocomplete list should have everything.
+
+
diff --git a/screenshots/image01.png b/screenshots/image01.png
new file mode 100644
index 0000000..58520c5
--- /dev/null
+++ b/screenshots/image01.png
Binary files differ
diff --git a/screenshots/image02.png b/screenshots/image02.png
new file mode 100644
index 0000000..c321c86
--- /dev/null
+++ b/screenshots/image02.png
Binary files differ
diff --git a/screenshots/kharkiv-destroyed.jpg b/screenshots/kharkiv-destroyed.jpg
new file mode 100644
index 0000000..4f95783
--- /dev/null
+++ b/screenshots/kharkiv-destroyed.jpg
Binary files differ
diff --git a/screenshots/pgcli.gif b/screenshots/pgcli.gif
new file mode 100644
index 0000000..9c8e66d
--- /dev/null
+++ b/screenshots/pgcli.gif
Binary files differ
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..9a398a4
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,76 @@
+import platform
+from setuptools import setup, find_packages
+
+from pgcli import __version__
+
+description = "CLI for Postgres Database. With auto-completion and syntax highlighting."
+
+install_requirements = [
+ "pgspecial>=2.0.0",
+ "click >= 4.1",
+ "Pygments>=2.0", # Pygments has to be Capitalcased. WTF?
+ # We still need to use pt-2 unless pt-3 released on Fedora32
+ # see: https://github.com/dbcli/pgcli/pull/1197
+ "prompt_toolkit>=2.0.6,<4.0.0",
+ "psycopg >= 3.0.14",
+ "sqlparse >=0.3.0,<0.5",
+ "configobj >= 5.0.6",
+ "pendulum>=2.1.0",
+ "cli_helpers[styles] >= 2.2.1",
+]
+
+
+# setproctitle is used to mask the password when running `ps` in command line.
+# But this is not necessary in Windows since the password is never shown in the
+# task manager. Also setproctitle is a hard dependency to install in Windows,
+# so we'll only install it if we're not in Windows.
+if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"):
+ install_requirements.append("setproctitle >= 1.1.9")
+
+# Windows will require the binary psycopg to run pgcli
+if platform.system() == "Windows":
+ install_requirements.append("psycopg-binary >= 3.0.14")
+
+
+setup(
+ name="pgcli",
+ author="Pgcli Core Team",
+ author_email="pgcli-dev@googlegroups.com",
+ version=__version__,
+ license="BSD",
+ url="http://pgcli.com",
+ packages=find_packages(),
+ package_data={"pgcli": ["pgclirc", "packages/pgliterals/pgliterals.json"]},
+ description=description,
+ long_description=open("README.rst").read(),
+ install_requires=install_requirements,
+ dependency_links=[
+ "http://github.com/psycopg/repo/tarball/master#egg=psycopg-3.0.10"
+ ],
+ extras_require={
+ "keyring": ["keyring >= 12.2.0"],
+ "sshtunnel": ["sshtunnel >= 0.4.0"],
+ },
+ python_requires=">=3.8",
+ entry_points="""
+ [console_scripts]
+ pgcli=pgcli.main:cli
+ """,
+ classifiers=[
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: BSD License",
+ "Operating System :: Unix",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: SQL",
+ "Topic :: Database",
+ "Topic :: Database :: Front-Ends",
+ "Topic :: Software Development",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ ],
+)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..33cddf2
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,52 @@
+import os
+import pytest
+from utils import (
+ POSTGRES_HOST,
+ POSTGRES_PORT,
+ POSTGRES_USER,
+ POSTGRES_PASSWORD,
+ create_db,
+ db_connection,
+ drop_tables,
+)
+import pgcli.pgexecute
+
+
+@pytest.fixture(scope="function")
+def connection():
+ create_db("_test_db")
+ connection = db_connection("_test_db")
+ yield connection
+
+ drop_tables(connection)
+ connection.close()
+
+
+@pytest.fixture
+def cursor(connection):
+ with connection.cursor() as cur:
+ return cur
+
+
+@pytest.fixture
+def executor(connection):
+ return pgcli.pgexecute.PGExecute(
+ database="_test_db",
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ port=POSTGRES_PORT,
+ dsn=None,
+ )
+
+
+@pytest.fixture
+def exception_formatter():
+ return lambda e: str(e)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def temp_config(tmpdir_factory):
+ # this function runs on start of test session.
+ # use temporary directory for config home so user config will not be used
+ os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data"))
diff --git a/tests/features/__init__.py b/tests/features/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/features/__init__.py
diff --git a/tests/features/auto_vertical.feature b/tests/features/auto_vertical.feature
new file mode 100644
index 0000000..aa95718
--- /dev/null
+++ b/tests/features/auto_vertical.feature
@@ -0,0 +1,12 @@
+Feature: auto_vertical mode:
+ on, off
+
+ Scenario: auto_vertical on with small query
+ When we run dbcli with --auto-vertical-output
+ and we execute a small query
+ then we see small results in horizontal format
+
+ Scenario: auto_vertical on with large query
+ When we run dbcli with --auto-vertical-output
+ and we execute a large query
+ then we see large results in vertical format
diff --git a/tests/features/basic_commands.feature b/tests/features/basic_commands.feature
new file mode 100644
index 0000000..ee497b9
--- /dev/null
+++ b/tests/features/basic_commands.feature
@@ -0,0 +1,81 @@
+Feature: run the cli,
+ call the help command,
+ exit the cli
+
+ Scenario: run "\?" command
+ When we send "\?" command
+ then we see help output
+
+ Scenario: run source command
+ When we send source command
+ then we see help output
+
+ Scenario: run partial select command
+ When we send partial select command
+ then we see error message
+ then we see dbcli prompt
+
+ Scenario: check our application_name
+ When we run query to check application_name
+ then we see found
+
+ Scenario: run the cli and exit
+ When we send "ctrl + d"
+ then dbcli exits
+
+ Scenario: confirm exit when a transaction is ongoing
+ When we begin transaction
+ and we try to send "ctrl + d"
+ then we see ongoing transaction message
+ when we send "c"
+ then dbcli exits
+
+ Scenario: cancel exit when a transaction is ongoing
+ When we begin transaction
+ and we try to send "ctrl + d"
+ then we see ongoing transaction message
+ when we send "a"
+ then we see dbcli prompt
+ when we rollback transaction
+ when we send "ctrl + d"
+ then dbcli exits
+
+ Scenario: interrupt current query via "ctrl + c"
+ When we send sleep query
+ and we send "ctrl + c"
+ then we see cancelled query warning
+ when we check for any non-idle sleep queries
+ then we don't see any non-idle sleep queries
+
+ Scenario: list databases
+ When we list databases
+ then we see list of databases
+
+ Scenario: run the cli with --username
+ When we launch dbcli using --username
+ and we send "\?" command
+ then we see help output
+
+ Scenario: run the cli with --user
+ When we launch dbcli using --user
+ and we send "\?" command
+ then we see help output
+
+ Scenario: run the cli with --port
+ When we launch dbcli using --port
+ and we send "\?" command
+ then we see help output
+
+ Scenario: run the cli with --password
+ When we launch dbcli using --password
+ then we send password
+ and we see dbcli prompt
+ when we send "\?" command
+ then we see help output
+
+ Scenario: run the cli with dsn and password
+ When we launch dbcli using dsn_password
+ then we send password
+ and we see dbcli prompt
+ when we send "\?" command
+ then we see help output
diff --git a/tests/features/crud_database.feature b/tests/features/crud_database.feature
new file mode 100644
index 0000000..87da4e3
--- /dev/null
+++ b/tests/features/crud_database.feature
@@ -0,0 +1,17 @@
+Feature: manipulate databases:
+ create, drop, connect, disconnect
+
+ Scenario: create and drop temporary database
+ When we create database
+ then we see database created
+ when we drop database
+ then we respond to the destructive warning: y
+ then we see database dropped
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: connect and disconnect from test database
+ When we connect to test database
+ then we see database connected
+ when we connect to dbserver
+ then we see database connected
diff --git a/tests/features/crud_table.feature b/tests/features/crud_table.feature
new file mode 100644
index 0000000..8a43c5c
--- /dev/null
+++ b/tests/features/crud_table.feature
@@ -0,0 +1,45 @@
+Feature: manipulate tables:
+ create, insert, update, select, delete from, drop
+
+ Scenario: create, insert, select from, update, drop table
+ When we connect to test database
+ then we see database connected
+ when we create table
+ then we see table created
+ when we insert into table
+ then we see record inserted
+ when we select from table
+ then we see data selected: initial
+ when we update table
+ then we see record updated
+ when we select from table
+ then we see data selected: updated
+ when we delete from table
+ then we respond to the destructive warning: y
+ then we see record deleted
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: transaction handling, with cancelling on a destructive warning.
+ When we connect to test database
+ then we see database connected
+ when we create table
+ then we see table created
+ when we begin transaction
+ then we see transaction began
+ when we insert into table
+ then we see record inserted
+ when we delete from table
+ then we respond to the destructive warning: n
+ when we select from table
+ then we see data selected: initial
+ when we rollback transaction
+ then we see transaction rolled back
+ when we select from table
+ then we see select output without data
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py
new file mode 100644
index 0000000..595c6c2
--- /dev/null
+++ b/tests/features/db_utils.py
@@ -0,0 +1,87 @@
+from psycopg import connect
+
+
+def create_db(
+ hostname="localhost", username=None, password=None, dbname=None, port=None
+):
+ """Create test database.
+
+ :param hostname: string
+ :param username: string
+ :param password: string
+ :param dbname: string
+ :param port: int
+ :return:
+
+ """
+ cn = create_cn(hostname, password, username, "postgres", port)
+
+ cn.autocommit = True
+ with cn.cursor() as cr:
+ cr.execute(f"drop database if exists {dbname}")
+ cr.execute(f"create database {dbname}")
+
+ cn.close()
+
+ cn = create_cn(hostname, password, username, dbname, port)
+ return cn
+
+
+def create_cn(hostname, password, username, dbname, port):
+ """
+ Open connection to database.
+ :param hostname:
+ :param password:
+ :param username:
+ :param dbname: string
+ :return: psycopg2.connection
+ """
+ cn = connect(
+ host=hostname, user=username, dbname=dbname, password=password, port=port
+ )
+
+ print(f"Created connection: {cn.info.get_parameters()}.")
+ return cn
+
+
+def pgbouncer_available(hostname="localhost", password=None, username="postgres"):
+ cn = None
+ try:
+ cn = create_cn(hostname, password, username, "pgbouncer", 6432)
+ return True
+ except:
+ print("Pgbouncer is not available.")
+ finally:
+ if cn:
+ cn.close()
+ return False
+
+
+def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None):
+ """
+ Drop database.
+ :param hostname: string
+ :param username: string
+ :param password: string
+ :param dbname: string
+ """
+ cn = create_cn(hostname, password, username, "postgres", port)
+
+ # Needed for DB drop.
+ cn.autocommit = True
+
+ with cn.cursor() as cr:
+ cr.execute(f"drop database if exists {dbname}")
+
+ close_cn(cn)
+
+
+def close_cn(cn=None):
+ """
+ Close connection.
+ :param connection: psycopg2.connection
+ """
+ if cn:
+ cn_params = cn.info.get_parameters()
+ cn.close()
+ print(f"Closed connection: {cn_params}.")
diff --git a/tests/features/environment.py b/tests/features/environment.py
new file mode 100644
index 0000000..50ac5fa
--- /dev/null
+++ b/tests/features/environment.py
@@ -0,0 +1,227 @@
+import copy
+import os
+import sys
+import db_utils as dbutils
+import fixture_utils as fixutils
+import pexpect
+import tempfile
+import shutil
+import signal
+
+
+from steps import wrappers
+
+
+def before_all(context):
+ """Set env parameters."""
+ env_old = copy.deepcopy(dict(os.environ))
+ os.environ["LINES"] = "100"
+ os.environ["COLUMNS"] = "100"
+ os.environ["PAGER"] = "cat"
+ os.environ["EDITOR"] = "ex"
+ os.environ["VISUAL"] = "ex"
+ os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
+
+ context.package_root = os.path.abspath(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+ )
+ fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data")
+
+ print("package root:", context.package_root)
+ print("fixture dir:", fixture_dir)
+
+ os.environ["COVERAGE_PROCESS_START"] = os.path.join(
+ context.package_root, ".coveragerc"
+ )
+
+ context.exit_sent = False
+
+ vi = "_".join([str(x) for x in sys.version_info[:3]])
+ db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
+ db_name_full = f"{db_name}_{vi}"
+
+ # Store get params from config.
+ context.conf = {
+ "host": context.config.userdata.get(
+ "pg_test_host", os.getenv("PGHOST", "localhost")
+ ),
+ "user": context.config.userdata.get(
+ "pg_test_user", os.getenv("PGUSER", "postgres")
+ ),
+ "pass": context.config.userdata.get(
+ "pg_test_pass", os.getenv("PGPASSWORD", None)
+ ),
+ "port": context.config.userdata.get(
+ "pg_test_port", os.getenv("PGPORT", "5432")
+ ),
+ "cli_command": (
+ context.config.userdata.get("pg_cli_command", None)
+ or '{python} -c "{startup}"'.format(
+ python=sys.executable,
+ startup="; ".join(
+ [
+ "import coverage",
+ "coverage.process_startup()",
+ "import pgcli.main",
+ "pgcli.main.cli(auto_envvar_prefix='BEHAVE')",
+ ]
+ ),
+ )
+ ),
+ "dbname": db_name_full,
+ "dbname_tmp": db_name_full + "_tmp",
+ "vi": vi,
+ "pager_boundary": "---boundary---",
+ }
+ os.environ["PAGER"] = "{0} {1} {2}".format(
+ sys.executable,
+ os.path.join(context.package_root, "tests/features/wrappager.py"),
+ context.conf["pager_boundary"],
+ )
+
+ # Store old env vars.
+ context.pgenv = {
+ "PGDATABASE": os.environ.get("PGDATABASE", None),
+ "PGUSER": os.environ.get("PGUSER", None),
+ "PGHOST": os.environ.get("PGHOST", None),
+ "PGPASSWORD": os.environ.get("PGPASSWORD", None),
+ "PGPORT": os.environ.get("PGPORT", None),
+ "XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None),
+ "PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None),
+ }
+
+ # Set new env vars.
+ os.environ["PGDATABASE"] = context.conf["dbname"]
+ os.environ["PGUSER"] = context.conf["user"]
+ os.environ["PGHOST"] = context.conf["host"]
+ os.environ["PGPORT"] = context.conf["port"]
+ os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf")
+
+ if context.conf["pass"]:
+ os.environ["PGPASSWORD"] = context.conf["pass"]
+ else:
+ if "PGPASSWORD" in os.environ:
+ del os.environ["PGPASSWORD"]
+ os.environ["BEHAVE_WARN"] = "moderate"
+
+ context.cn = dbutils.create_db(
+ context.conf["host"],
+ context.conf["user"],
+ context.conf["pass"],
+ context.conf["dbname"],
+ context.conf["port"],
+ )
+ context.pgbouncer_available = dbutils.pgbouncer_available(
+ hostname=context.conf["host"],
+ password=context.conf["pass"],
+ username=context.conf["user"],
+ )
+ context.fixture_data = fixutils.read_fixture_files()
+
+ # use temporary directory as config home
+ context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_")
+ os.environ["XDG_CONFIG_HOME"] = context.env_config_home
+ show_env_changes(env_old, dict(os.environ))
+
+
+def show_env_changes(env_old, env_new):
+ """Print out all test-specific env values."""
+ print("--- os.environ changed values: ---")
+ all_keys = env_old.keys() | env_new.keys()
+ for k in sorted(all_keys):
+ old_value = env_old.get(k, "")
+ new_value = env_new.get(k, "")
+ if new_value and old_value != new_value:
+ print(f'{k}="{new_value}"')
+ print("-" * 20)
+
+
+def after_all(context):
+ """
+ Unset env parameters.
+ """
+ dbutils.close_cn(context.cn)
+ dbutils.drop_db(
+ context.conf["host"],
+ context.conf["user"],
+ context.conf["pass"],
+ context.conf["dbname"],
+ context.conf["port"],
+ )
+
+ # Remove temp config directory
+ shutil.rmtree(context.env_config_home)
+
+ # Restore env vars.
+ for k, v in context.pgenv.items():
+ if k in os.environ and v is None:
+ del os.environ[k]
+ elif v:
+ os.environ[k] = v
+
+
+def before_step(context, _):
+ context.atprompt = False
+
+
+def is_known_problem(scenario):
+ """TODO: why is this not working in 3.12?"""
+ if sys.version_info >= (3, 12):
+ return scenario.name in (
+ 'interrupt current query via "ctrl + c"',
+ "run the cli with --username",
+ "run the cli with --user",
+ "run the cli with --port",
+ )
+ return False
+
+
+def before_scenario(context, scenario):
+ if scenario.name == "list databases":
+ # not using the cli for that
+ return
+ if is_known_problem(scenario):
+ scenario.skip()
+ currentdb = None
+ if "pgbouncer" in scenario.feature.tags:
+ if context.pgbouncer_available:
+ os.environ["PGDATABASE"] = "pgbouncer"
+ os.environ["PGPORT"] = "6432"
+ currentdb = "pgbouncer"
+ else:
+ scenario.skip()
+ else:
+ # set env vars back to normal test database
+ os.environ["PGDATABASE"] = context.conf["dbname"]
+ os.environ["PGPORT"] = context.conf["port"]
+ wrappers.run_cli(context, currentdb=currentdb)
+ wrappers.wait_prompt(context)
+
+
+def after_scenario(context, scenario):
+ """Cleans up after each scenario completes."""
+ if hasattr(context, "cli") and context.cli and not context.exit_sent:
+ # Quit nicely.
+ if not getattr(context, "atprompt", False):
+ dbname = context.currentdb
+ context.cli.expect_exact(f"{dbname}>", timeout=5)
+ try:
+ context.cli.sendcontrol("c")
+ context.cli.sendcontrol("d")
+ except Exception as x:
+ print("Failed cleanup after scenario:")
+ print(x)
+ try:
+ context.cli.expect_exact(pexpect.EOF, timeout=5)
+ except pexpect.TIMEOUT:
+ print(f"--- after_scenario {scenario.name}: kill cli")
+ context.cli.kill(signal.SIGKILL)
+ if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
+ context.tmpfile_sql_help.close()
+ context.tmpfile_sql_help = None
+
+
+# # TODO: uncomment to debug a failure
+# def after_step(context, step):
+# if step.status == "failed":
+# import pdb; pdb.set_trace()
diff --git a/tests/features/expanded.feature b/tests/features/expanded.feature
new file mode 100644
index 0000000..e486048
--- /dev/null
+++ b/tests/features/expanded.feature
@@ -0,0 +1,29 @@
+Feature: expanded mode:
+ on, off, auto
+
+ Scenario: expanded on
+ When we prepare the test data
+ and we set expanded on
+ and we select from table
+ then we see expanded data selected
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
+
+ Scenario: expanded off
+ When we prepare the test data
+ and we set expanded off
+ and we select from table
+ then we see nonexpanded data selected
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
+
+ Scenario: expanded auto
+ When we prepare the test data
+ and we set expanded auto
+ and we select from table
+ then we see auto data selected
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
diff --git a/tests/features/fixture_data/help.txt b/tests/features/fixture_data/help.txt
new file mode 100644
index 0000000..bebb976
--- /dev/null
+++ b/tests/features/fixture_data/help.txt
@@ -0,0 +1,25 @@
++--------------------------+------------------------------------------------+
+| Command | Description |
+|--------------------------+------------------------------------------------|
+| \# | Refresh auto-completions. |
+| \? | Show Help. |
+| \T [format] | Change the table format used to output results |
+| \c[onnect] database_name | Change to a new database. |
+| \d [pattern] | List or describe tables, views and sequences. |
+| \dT[S+] [pattern] | List data types |
+| \df[+] [pattern] | List functions. |
+| \di[+] [pattern] | List indexes. |
+| \dn[+] [pattern] | List schemas. |
+| \ds[+] [pattern] | List sequences. |
+| \dt[+] [pattern] | List tables. |
+| \du[+] [pattern] | List roles. |
+| \dv[+] [pattern] | List views. |
+| \e [file] | Edit the query with external editor. |
+| \l | List databases. |
+| \n[+] [name] | List or execute named queries. |
+| \nd [name [query]] | Delete a named query. |
+| \ns name query | Save a named query. |
+| \refresh | Refresh auto-completions. |
+| \timing | Toggle timing of commands. |
+| \x | Toggle expanded output. |
++--------------------------+------------------------------------------------+
diff --git a/tests/features/fixture_data/help_commands.txt b/tests/features/fixture_data/help_commands.txt
new file mode 100644
index 0000000..e076661
--- /dev/null
+++ b/tests/features/fixture_data/help_commands.txt
@@ -0,0 +1,64 @@
+Command
+Description
+\#
+Refresh auto-completions.
+\?
+Show Commands.
+\T [format]
+Change the table format used to output results
+\c[onnect] database_name
+Change to a new database.
+\copy [tablename] to/from [filename]
+Copy data between a file and a table.
+\d[+] [pattern]
+List or describe tables, views and sequences.
+\dT[S+] [pattern]
+List data types
+\db[+] [pattern]
+List tablespaces.
+\df[+] [pattern]
+List functions.
+\di[+] [pattern]
+List indexes.
+\dm[+] [pattern]
+List materialized views.
+\dn[+] [pattern]
+List schemas.
+\ds[+] [pattern]
+List sequences.
+\dt[+] [pattern]
+List tables.
+\du[+] [pattern]
+List roles.
+\dv[+] [pattern]
+List views.
+\dx[+] [pattern]
+List extensions.
+\e [file]
+Edit the query with external editor.
+\h
+Show SQL syntax and help.
+\i filename
+Execute commands from file.
+\l
+List databases.
+\n[+] [name] [param1 param2 ...]
+List or execute named queries.
+\nd [name]
+Delete a named query.
+\ns name query
+Save a named query.
+\o [filename]
+Send all query results to file.
+\pager [command]
+Set PAGER. Print the query results via PAGER.
+\pset [key] [value]
+A limited version of traditional \pset
+\refresh
+Refresh auto-completions.
+\sf[+] FUNCNAME
+Show a function's definition.
+\timing
+Toggle timing of commands.
+\x
+Toggle expanded output.
diff --git a/tests/features/fixture_data/mock_pg_service.conf b/tests/features/fixture_data/mock_pg_service.conf
new file mode 100644
index 0000000..15f9811
--- /dev/null
+++ b/tests/features/fixture_data/mock_pg_service.conf
@@ -0,0 +1,4 @@
+[mock_postgres]
+dbname=postgres
+host=localhost
+user=postgres
diff --git a/tests/features/fixture_utils.py b/tests/features/fixture_utils.py
new file mode 100644
index 0000000..70b603d
--- /dev/null
+++ b/tests/features/fixture_utils.py
@@ -0,0 +1,28 @@
+import os
+import codecs
+
+
+def read_fixture_lines(filename):
+ """
+ Read lines of text from file.
+ :param filename: string name
+ :return: list of strings
+ """
+ lines = []
+ for line in codecs.open(filename, "rb", encoding="utf-8"):
+ lines.append(line.strip())
+ return lines
+
+
+def read_fixture_files():
+ """Read all files inside fixture_data directory."""
+ current_dir = os.path.dirname(__file__)
+ fixture_dir = os.path.join(current_dir, "fixture_data/")
+ print(f"reading fixture data: {fixture_dir}")
+ fixture_dict = {}
+ for filename in os.listdir(fixture_dir):
+ if filename not in [".", ".."]:
+ fullname = os.path.join(fixture_dir, filename)
+ fixture_dict[filename] = read_fixture_lines(fullname)
+
+ return fixture_dict
diff --git a/tests/features/iocommands.feature b/tests/features/iocommands.feature
new file mode 100644
index 0000000..dad7d10
--- /dev/null
+++ b/tests/features/iocommands.feature
@@ -0,0 +1,17 @@
+Feature: I/O commands
+
+ Scenario: edit sql in file with external editor
+ When we start external editor providing a file name
+ and we type sql in the editor
+ and we exit the editor
+ then we see dbcli prompt
+ and we see the sql in prompt
+
+ Scenario: tee output from query
+ When we tee output
+ and we wait for prompt
+ and we query "select 123456"
+ and we wait for prompt
+ and we stop teeing output
+ and we wait for prompt
+ then we see 123456 in tee output
diff --git a/tests/features/named_queries.feature b/tests/features/named_queries.feature
new file mode 100644
index 0000000..74201b9
--- /dev/null
+++ b/tests/features/named_queries.feature
@@ -0,0 +1,10 @@
+Feature: named queries:
+ save, use and delete named queries
+
+ Scenario: save, use and delete named queries
+ When we connect to test database
+ then we see database connected
+ when we save a named query
+ then we see the named query saved
+ when we delete a named query
+ then we see the named query deleted
diff --git a/tests/features/pgbouncer.feature b/tests/features/pgbouncer.feature
new file mode 100644
index 0000000..14cc5ad
--- /dev/null
+++ b/tests/features/pgbouncer.feature
@@ -0,0 +1,12 @@
+@pgbouncer
+Feature: run pgbouncer,
+ call the help command,
+ exit the cli
+
+ Scenario: run "show help" command
+ When we send "show help" command
+ then we see the pgbouncer help output
+
+ Scenario: run the cli and exit
+ When we send "ctrl + d"
+ then dbcli exits
diff --git a/tests/features/specials.feature b/tests/features/specials.feature
new file mode 100644
index 0000000..63c5cdc
--- /dev/null
+++ b/tests/features/specials.feature
@@ -0,0 +1,6 @@
+Feature: Special commands
+
+ Scenario: run refresh command
+ When we refresh completions
+ and we wait for prompt
+ then we see completions refresh started
diff --git a/tests/features/steps/__init__.py b/tests/features/steps/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/features/steps/__init__.py
diff --git a/tests/features/steps/auto_vertical.py b/tests/features/steps/auto_vertical.py
new file mode 100644
index 0000000..d7cdccd
--- /dev/null
+++ b/tests/features/steps/auto_vertical.py
@@ -0,0 +1,99 @@
+from textwrap import dedent
+from behave import then, when
+import wrappers
+
+
+@when("we run dbcli with {arg}")
+def step_run_cli_with_arg(context, arg):
+ wrappers.run_cli(context, run_args=arg.split("="))
+
+
+@when("we execute a small query")
+def step_execute_small_query(context):
+ context.cli.sendline("select 1")
+
+
+@when("we execute a large query")
+def step_execute_large_query(context):
+ context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
+
+
+@then("we see small results in horizontal format")
+def step_see_small_results(context):
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ +----------+\r
+ | ?column? |\r
+ |----------|\r
+ | 1 |\r
+ +----------+\r
+ SELECT 1\r
+ """
+ ),
+ timeout=5,
+ )
+
+
+@then("we see large results in vertical format")
+def step_see_large_results(context):
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ -[ RECORD 1 ]-------------------------\r
+ ?column? | 1\r
+ ?column? | 2\r
+ ?column? | 3\r
+ ?column? | 4\r
+ ?column? | 5\r
+ ?column? | 6\r
+ ?column? | 7\r
+ ?column? | 8\r
+ ?column? | 9\r
+ ?column? | 10\r
+ ?column? | 11\r
+ ?column? | 12\r
+ ?column? | 13\r
+ ?column? | 14\r
+ ?column? | 15\r
+ ?column? | 16\r
+ ?column? | 17\r
+ ?column? | 18\r
+ ?column? | 19\r
+ ?column? | 20\r
+ ?column? | 21\r
+ ?column? | 22\r
+ ?column? | 23\r
+ ?column? | 24\r
+ ?column? | 25\r
+ ?column? | 26\r
+ ?column? | 27\r
+ ?column? | 28\r
+ ?column? | 29\r
+ ?column? | 30\r
+ ?column? | 31\r
+ ?column? | 32\r
+ ?column? | 33\r
+ ?column? | 34\r
+ ?column? | 35\r
+ ?column? | 36\r
+ ?column? | 37\r
+ ?column? | 38\r
+ ?column? | 39\r
+ ?column? | 40\r
+ ?column? | 41\r
+ ?column? | 42\r
+ ?column? | 43\r
+ ?column? | 44\r
+ ?column? | 45\r
+ ?column? | 46\r
+ ?column? | 47\r
+ ?column? | 48\r
+ ?column? | 49\r
+ SELECT 1\r
+ """
+ ),
+ timeout=5,
+ )
diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py
new file mode 100644
index 0000000..687bdc0
--- /dev/null
+++ b/tests/features/steps/basic_commands.py
@@ -0,0 +1,231 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+import pexpect
+import subprocess
+import tempfile
+
+from behave import when, then
+from textwrap import dedent
+import wrappers
+
+
+@when("we list databases")
+def step_list_databases(context):
+ cmd = ["pgcli", "--list"]
+ context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root)
+
+
+@then("we see list of databases")
+def step_see_list_databases(context):
+ assert b"List of databases" in context.cmd_output
+ assert b"postgres" in context.cmd_output
+ context.cmd_output = None
+
+
+@when("we run dbcli")
+def step_run_cli(context):
+ wrappers.run_cli(context)
+
+
+@when("we launch dbcli using {arg}")
+def step_run_cli_using_arg(context, arg):
+ prompt_check = False
+ currentdb = None
+ if arg == "--username":
+ arg = "--username={}".format(context.conf["user"])
+ if arg == "--user":
+ arg = "--user={}".format(context.conf["user"])
+ if arg == "--port":
+ arg = "--port={}".format(context.conf["port"])
+ if arg == "--password":
+ arg = "--password"
+ prompt_check = False
+ # This uses the mock_pg_service.conf file in fixtures folder.
+ if arg == "dsn_password":
+ arg = "service=mock_postgres --password"
+ prompt_check = False
+ currentdb = "postgres"
+ wrappers.run_cli(
+ context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb
+ )
+
+
+@when("we wait for prompt")
+def step_wait_prompt(context):
+ wrappers.wait_prompt(context)
+
+
+@when('we send "ctrl + d"')
+def step_ctrl_d(context):
+ """
+ Send Ctrl + D to hopefully exit.
+ """
+ step_try_to_ctrl_d(context)
+ context.cli.expect(pexpect.EOF, timeout=5)
+ context.exit_sent = True
+
+
+@when('we try to send "ctrl + d"')
+def step_try_to_ctrl_d(context):
+ """
+ Send Ctrl + D, perhaps exiting, perhaps not (if a transaction is
+ ongoing).
+ """
+ # turn off pager before exiting
+ context.cli.sendcontrol("c")
+ context.cli.sendline(r"\pset pager off")
+ wrappers.wait_prompt(context)
+ context.cli.sendcontrol("d")
+
+
+@when('we send "ctrl + c"')
+def step_ctrl_c(context):
+ """Send Ctrl + c to hopefully interrupt."""
+ context.cli.sendcontrol("c")
+
+
+@then("we see cancelled query warning")
+def step_see_cancelled_query_warning(context):
+ """
+ Make sure we receive the warning that the current query was cancelled.
+ """
+ wrappers.expect_exact(context, "cancelled query", timeout=2)
+
+
+@then("we see ongoing transaction message")
+def step_see_ongoing_transaction_error(context):
+ """
+ Make sure we receive the warning that a transaction is ongoing.
+ """
+ context.cli.expect("A transaction is ongoing.", timeout=2)
+
+
+@when("we send sleep query")
+def step_send_sleep_15_seconds(context):
+ """
+ Send query to sleep for 15 seconds.
+ """
+ context.cli.sendline("select pg_sleep(15)")
+
+
+@when("we check for any non-idle sleep queries")
+def step_check_for_active_sleep_queries(context):
+ """
+ Send query to check for any non-idle pg_sleep queries.
+ """
+ context.cli.sendline(
+ "select state from pg_stat_activity where query not like '%pg_stat_activity%' and query like '%pg_sleep%' and state != 'idle';"
+ )
+
+
+@then("we don't see any non-idle sleep queries")
+def step_no_active_sleep_queries(context):
+ """Confirm that any pg_sleep queries are either idle or not active."""
+ wrappers.expect_exact(
+ context,
+ context.conf["pager_boundary"]
+ + "\r"
+ + dedent(
+ """
+ +-------+\r
+ | state |\r
+ |-------|\r
+ +-------+\r
+ SELECT 0\r
+ """
+ )
+ + context.conf["pager_boundary"],
+ timeout=5,
+ )
+
+
+@when(r'we send "\?" command')
+def step_send_help(context):
+ r"""
+ Send \? to see help.
+ """
+ context.cli.sendline(r"\?")
+
+
+@when("we send partial select command")
+def step_send_partial_select_command(context):
+ """
+ Send `SELECT a` to see completion.
+ """
+ context.cli.sendline("SELECT a")
+
+
+@then("we see error message")
+def step_see_error_message(context):
+ wrappers.expect_exact(context, 'column "a" does not exist', timeout=2)
+
+
+@when("we send source command")
+def step_send_source_command(context):
+ context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
+ context.tmpfile_sql_help.write(rb"\?")
+ context.tmpfile_sql_help.flush()
+ context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}")
+ wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
+
+
+@when("we run query to check application_name")
+def step_check_application_name(context):
+ context.cli.sendline(
+ "SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;"
+ )
+
+
+@then("we see found")
+def step_see_found(context):
+ wrappers.expect_exact(
+ context,
+ context.conf["pager_boundary"]
+ + "\r"
+ + dedent(
+ """
+ +----------+\r
+ | ?column? |\r
+ |----------|\r
+ | found |\r
+ +----------+\r
+ SELECT 1\r
+ """
+ )
+ + context.conf["pager_boundary"],
+ timeout=5,
+ )
+
+
+@then("we respond to the destructive warning: {response}")
+def step_resppond_to_destructive_command(context, response):
+ """Respond to destructive command."""
+ wrappers.expect_exact(
+ context,
+ "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:",
+ timeout=2,
+ )
+ context.cli.sendline(response.strip())
+
+
+@then("we send password")
+def step_send_password(context):
+ wrappers.expect_exact(context, "Password for", timeout=5)
+ context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER")
+
+
+@when('we send "{text}"')
+def step_send_text(context, text):
+ context.cli.sendline(text)
+ # Try to detect whether we are exiting. If so, set `exit_sent`
+ # so that `after_scenario` correctly cleans up.
+ try:
+ context.cli.expect(pexpect.EOF, timeout=0.2)
+ except pexpect.TIMEOUT:
+ pass
+ else:
+ context.exit_sent = True
diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py
new file mode 100644
index 0000000..87cdc85
--- /dev/null
+++ b/tests/features/steps/crud_database.py
@@ -0,0 +1,93 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+import pexpect
+
+from behave import when, then
+import wrappers
+
+
+@when("we create database")
+def step_db_create(context):
+ """
+ Send create database.
+ """
+ context.cli.sendline("create database {};".format(context.conf["dbname_tmp"]))
+
+ context.response = {"database_name": context.conf["dbname_tmp"]}
+
+
+@when("we drop database")
+def step_db_drop(context):
+ """
+ Send drop database.
+ """
+ context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"]))
+
+
+@when("we connect to test database")
+def step_db_connect_test(context):
+ """
+ Send connect to database.
+ """
+ db_name = context.conf["dbname"]
+ context.cli.sendline(f"\\connect {db_name}")
+
+
+@when("we connect to dbserver")
+def step_db_connect_dbserver(context):
+ """
+ Send connect to database.
+ """
+ context.cli.sendline("\\connect postgres")
+ context.currentdb = "postgres"
+
+
+@then("dbcli exits")
+def step_wait_exit(context):
+ """
+ Make sure the cli exits.
+ """
+ wrappers.expect_exact(context, pexpect.EOF, timeout=5)
+
+
+@then("we see dbcli prompt")
+def step_see_prompt(context):
+ """
+ Wait to see the prompt.
+ """
+ db_name = getattr(context, "currentdb", context.conf["dbname"])
+ wrappers.expect_exact(context, f"{db_name}>", timeout=5)
+ context.atprompt = True
+
+
+@then("we see help output")
+def step_see_help(context):
+ for expected_line in context.fixture_data["help_commands.txt"]:
+ wrappers.expect_exact(context, expected_line, timeout=2)
+
+
+@then("we see database created")
+def step_see_db_created(context):
+ """
+ Wait to see create database output.
+ """
+ wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5)
+
+
+@then("we see database dropped")
+def step_see_db_dropped(context):
+ """
+ Wait to see drop database output.
+ """
+ wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2)
+
+
+@then("we see database connected")
+def step_see_db_connected(context):
+ """
+ Wait to see drop database output.
+ """
+ wrappers.expect_exact(context, "You are now connected to database", timeout=2)
diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py
new file mode 100644
index 0000000..27d543e
--- /dev/null
+++ b/tests/features/steps/crud_table.py
@@ -0,0 +1,185 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+from behave import when, then
+from textwrap import dedent
+import wrappers
+
+
+INITIAL_DATA = "xxx"
+UPDATED_DATA = "yyy"
+
+
+@when("we create table")
+def step_create_table(context):
+ """
+ Send create table.
+ """
+ context.cli.sendline("create table a(x text);")
+
+
+@when("we insert into table")
+def step_insert_into_table(context):
+ """
+ Send insert into table.
+ """
+ context.cli.sendline(f"""insert into a(x) values('{INITIAL_DATA}');""")
+
+
+@when("we update table")
+def step_update_table(context):
+ """
+ Send insert into table.
+ """
+ context.cli.sendline(
+ f"""update a set x = '{UPDATED_DATA}' where x = '{INITIAL_DATA}';"""
+ )
+
+
+@when("we select from table")
+def step_select_from_table(context):
+ """
+ Send select from table.
+ """
+ context.cli.sendline("select * from a;")
+
+
+@when("we delete from table")
+def step_delete_from_table(context):
+ """
+ Send deete from table.
+ """
+ context.cli.sendline(f"""delete from a where x = '{UPDATED_DATA}';""")
+
+
+@when("we drop table")
+def step_drop_table(context):
+ """
+ Send drop table.
+ """
+ context.cli.sendline("drop table a;")
+
+
+@when("we alter the table")
+def step_alter_table(context):
+ """
+ Alter the table by adding a column.
+ """
+ context.cli.sendline("""alter table a add column y varchar;""")
+
+
+@when("we begin transaction")
+def step_begin_transaction(context):
+ """
+ Begin transaction
+ """
+ context.cli.sendline("begin;")
+
+
+@when("we rollback transaction")
+def step_rollback_transaction(context):
+ """
+ Rollback transaction
+ """
+ context.cli.sendline("rollback;")
+
+
+@then("we see table created")
+def step_see_table_created(context):
+ """
+ Wait to see create table output.
+ """
+ wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
+
+
+@then("we see record inserted")
+def step_see_record_inserted(context):
+ """
+ Wait to see insert output.
+ """
+ wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
+
+
+@then("we see record updated")
+def step_see_record_updated(context):
+ """
+ Wait to see update output.
+ """
+ wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2)
+
+
+@then("we see data selected: {data}")
+def step_see_data_selected(context, data):
+ """
+ Wait to see select output with initial or updated data.
+ """
+ x = UPDATED_DATA if data == "updated" else INITIAL_DATA
+ wrappers.expect_pager(
+ context,
+ dedent(
+ f"""\
+ +-----+\r
+ | x |\r
+ |-----|\r
+ | {x} |\r
+ +-----+\r
+ SELECT 1\r
+ """
+ ),
+ timeout=1,
+ )
+
+
+@then("we see select output without data")
+def step_see_no_data_selected(context):
+ """
+ Wait to see select output without data.
+ """
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ +---+\r
+ | x |\r
+ |---|\r
+ +---+\r
+ SELECT 0\r
+ """
+ ),
+ timeout=1,
+ )
+
+
+@then("we see record deleted")
+def step_see_data_deleted(context):
+ """
+ Wait to see delete output.
+ """
+ wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2)
+
+
+@then("we see table dropped")
+def step_see_table_dropped(context):
+ """
+ Wait to see drop output.
+ """
+ wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2)
+
+
+@then("we see transaction began")
+def step_see_transaction_began(context):
+ """
+ Wait to see transaction began.
+ """
+ wrappers.expect_pager(context, "BEGIN\r\n", timeout=2)
+
+
+@then("we see transaction rolled back")
+def step_see_transaction_rolled_back(context):
+ """
+ Wait to see transaction rollback.
+ """
+ wrappers.expect_pager(context, "ROLLBACK\r\n", timeout=2)
diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py
new file mode 100644
index 0000000..302cab9
--- /dev/null
+++ b/tests/features/steps/expanded.py
@@ -0,0 +1,70 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+from behave import when, then
+from textwrap import dedent
+import wrappers
+
+
+@when("we prepare the test data")
+def step_prepare_data(context):
+ """Create table, insert a record."""
+ context.cli.sendline("drop table if exists a;")
+ wrappers.expect_exact(
+ context,
+ "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:",
+ timeout=2,
+ )
+ context.cli.sendline("y")
+
+ wrappers.wait_prompt(context)
+ context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));")
+ wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
+ context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""")
+ wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
+
+
+@when("we set expanded {mode}")
+def step_set_expanded(context, mode):
+ """Set expanded to mode."""
+ context.cli.sendline("\\" + f"x {mode}")
+ wrappers.expect_exact(context, "Expanded display is", timeout=2)
+ wrappers.wait_prompt(context)
+
+
+@then("we see {which} data selected")
+def step_see_data(context, which):
+ """Select data from expanded test table."""
+ if which == "expanded":
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ -[ RECORD 1 ]-------------------------\r
+ x | 1\r
+ y | 1.0\r
+ z | 1.0000\r
+ SELECT 1\r
+ """
+ ),
+ timeout=1,
+ )
+ else:
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ +---+-----+--------+\r
+ | x | y | z |\r
+ |---+-----+--------|\r
+ | 1 | 1.0 | 1.0000 |\r
+ +---+-----+--------+\r
+ SELECT 1\r
+ """
+ ),
+ timeout=1,
+ )
diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py
new file mode 100644
index 0000000..a614490
--- /dev/null
+++ b/tests/features/steps/iocommands.py
@@ -0,0 +1,80 @@
+import os
+import os.path
+
+from behave import when, then
+import wrappers
+
+
+@when("we start external editor providing a file name")
+def step_edit_file(context):
+ """Edit file with external editor."""
+ context.editor_file_name = os.path.join(
+ context.package_root, "test_file_{0}.sql".format(context.conf["vi"])
+ )
+ if os.path.exists(context.editor_file_name):
+ os.remove(context.editor_file_name)
+ context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name)))
+ wrappers.expect_exact(
+ context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
+ )
+ wrappers.expect_exact(context, ":", timeout=2)
+
+
+@when("we type sql in the editor")
+def step_edit_type_sql(context):
+ context.cli.sendline("i")
+ context.cli.sendline("select * from abc")
+ context.cli.sendline(".")
+ wrappers.expect_exact(context, ":", timeout=2)
+
+
+@when("we exit the editor")
+def step_edit_quit(context):
+ context.cli.sendline("x")
+ wrappers.expect_exact(context, "written", timeout=2)
+
+
+@then("we see the sql in prompt")
+def step_edit_done_sql(context):
+ for match in "select * from abc".split(" "):
+ wrappers.expect_exact(context, match, timeout=1)
+ # Cleanup the command line.
+ context.cli.sendcontrol("c")
+ # Cleanup the edited file.
+ if context.editor_file_name and os.path.exists(context.editor_file_name):
+ os.remove(context.editor_file_name)
+ context.atprompt = True
+
+
+@when("we tee output")
+def step_tee_ouptut(context):
+ context.tee_file_name = os.path.join(
+ context.package_root, "tee_file_{0}.sql".format(context.conf["vi"])
+ )
+ if os.path.exists(context.tee_file_name):
+ os.remove(context.tee_file_name)
+ context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name)))
+ wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
+ wrappers.expect_exact(context, "Writing to file", timeout=5)
+ wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
+ wrappers.expect_exact(context, "Time", timeout=5)
+
+
+@when('we query "select 123456"')
+def step_query_select_123456(context):
+ context.cli.sendline("select 123456")
+
+
+@when("we stop teeing output")
+def step_notee_output(context):
+ context.cli.sendline(r"\o")
+ wrappers.expect_exact(context, "Time", timeout=5)
+
+
+@then("we see 123456 in tee output")
+def step_see_123456_in_ouput(context):
+ with open(context.tee_file_name) as f:
+ assert "123456" in f.read()
+ if os.path.exists(context.tee_file_name):
+ os.remove(context.tee_file_name)
+ context.atprompt = True
diff --git a/tests/features/steps/named_queries.py b/tests/features/steps/named_queries.py
new file mode 100644
index 0000000..3f52859
--- /dev/null
+++ b/tests/features/steps/named_queries.py
@@ -0,0 +1,57 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+from behave import when, then
+import wrappers
+
+
+@when("we save a named query")
+def step_save_named_query(context):
+ """
+ Send \ns command
+ """
+ context.cli.sendline("\\ns foo SELECT 12345")
+
+
+@when("we use a named query")
+def step_use_named_query(context):
+ """
+ Send \n command
+ """
+ context.cli.sendline("\\n foo")
+
+
+@when("we delete a named query")
+def step_delete_named_query(context):
+ """
+ Send \nd command
+ """
+ context.cli.sendline("\\nd foo")
+
+
+@then("we see the named query saved")
+def step_see_named_query_saved(context):
+ """
+ Wait to see query saved.
+ """
+ wrappers.expect_exact(context, "Saved.", timeout=2)
+
+
+@then("we see the named query executed")
+def step_see_named_query_executed(context):
+ """
+ Wait to see select output.
+ """
+ wrappers.expect_exact(context, "12345", timeout=1)
+ wrappers.expect_exact(context, "SELECT 1", timeout=1)
+
+
+@then("we see the named query deleted")
+def step_see_named_query_deleted(context):
+ """
+ Wait to see query deleted.
+ """
+ wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1)
diff --git a/tests/features/steps/pgbouncer.py b/tests/features/steps/pgbouncer.py
new file mode 100644
index 0000000..f156982
--- /dev/null
+++ b/tests/features/steps/pgbouncer.py
@@ -0,0 +1,22 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+from behave import when, then
+import wrappers
+
+
+@when('we send "show help" command')
+def step_send_help_command(context):
+ context.cli.sendline("show help")
+
+
+@then("we see the pgbouncer help output")
+def see_pgbouncer_help(context):
+ wrappers.expect_exact(
+ context,
+ "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
+ timeout=3,
+ )
diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py
new file mode 100644
index 0000000..a85f371
--- /dev/null
+++ b/tests/features/steps/specials.py
@@ -0,0 +1,31 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+from behave import when, then
+import wrappers
+
+
+@when("we refresh completions")
+def step_refresh_completions(context):
+ """
+ Send refresh command.
+ """
+ context.cli.sendline("\\refresh")
+
+
+@then("we see completions refresh started")
+def step_see_refresh_started(context):
+ """
+ Wait to see refresh output.
+ """
+ wrappers.expect_pager(
+ context,
+ [
+ "Auto-completion refresh started in the background.\r\n",
+ "Auto-completion refresh restarted.\r\n",
+ ],
+ timeout=2,
+ )
diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py
new file mode 100644
index 0000000..3ebcc92
--- /dev/null
+++ b/tests/features/steps/wrappers.py
@@ -0,0 +1,71 @@
+import re
+import pexpect
+from pgcli.main import COLOR_CODE_REGEX
+import textwrap
+
+from io import StringIO
+
+
+def expect_exact(context, expected, timeout):
+ timedout = False
+ try:
+ context.cli.expect_exact(expected, timeout=timeout)
+ except pexpect.TIMEOUT:
+ timedout = True
+ if timedout:
+ # Strip color codes out of the output.
+ actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
+ raise Exception(
+ textwrap.dedent(
+ """\
+ Expected:
+ ---
+ {0!r}
+ ---
+ Actual:
+ ---
+ {1!r}
+ ---
+ Full log:
+ ---
+ {2!r}
+ ---
+ """
+ ).format(expected, actual, context.logfile.getvalue())
+ )
+
+
+def expect_pager(context, expected, timeout):
+ formatted = expected if isinstance(expected, list) else [expected]
+ formatted = [
+ f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
+ for t in formatted
+ ]
+
+ expect_exact(
+ context,
+ formatted,
+ timeout=timeout,
+ )
+
+
+def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
+ """Run the process using pexpect."""
+ run_args = run_args or []
+ cli_cmd = context.conf.get("cli_command")
+ cmd_parts = [cli_cmd] + run_args
+ cmd = " ".join(cmd_parts)
+ context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
+ context.logfile = StringIO()
+ context.cli.logfile = context.logfile
+ context.exit_sent = False
+ context.currentdb = currentdb or context.conf["dbname"]
+ context.cli.sendline(r"\pset pager always")
+ if prompt_check:
+ wait_prompt(context)
+
+
+def wait_prompt(context):
+ """Make sure prompt is displayed."""
+ prompt_str = "{0}>".format(context.currentdb)
+ expect_exact(context, [prompt_str + " ", prompt_str, pexpect.EOF], timeout=3)
diff --git a/tests/features/wrappager.py b/tests/features/wrappager.py
new file mode 100644
index 0000000..51d4909
--- /dev/null
+++ b/tests/features/wrappager.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+import sys
+
+
+def wrappager(boundary):
+ print(boundary)
+ while 1:
+ buf = sys.stdin.read(2048)
+ if not buf:
+ break
+ sys.stdout.write(buf)
+ print(boundary)
+
+
+if __name__ == "__main__":
+ wrappager(sys.argv[1])
diff --git a/tests/formatter/__init__.py b/tests/formatter/__init__.py
new file mode 100644
index 0000000..9bad579
--- /dev/null
+++ b/tests/formatter/__init__.py
@@ -0,0 +1 @@
+# coding=utf-8
diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py
new file mode 100644
index 0000000..016ed95
--- /dev/null
+++ b/tests/formatter/test_sqlformatter.py
@@ -0,0 +1,111 @@
+# coding=utf-8
+
+from pgcli.packages.formatter.sqlformatter import escape_for_sql_statement
+
+from cli_helpers.tabular_output import TabularOutputFormatter
+from pgcli.packages.formatter.sqlformatter import adapter, register_new_formatter
+
+
+def test_escape_for_sql_statement_bytes():
+ bts = b"837124ab3e8dc0f"
+ escaped_bytes = escape_for_sql_statement(bts)
+ assert escaped_bytes == "X'383337313234616233653864633066'"
+
+
+def test_escape_for_sql_statement_number():
+ num = 2981
+ escaped_bytes = escape_for_sql_statement(num)
+ assert escaped_bytes == "'2981'"
+
+
+def test_escape_for_sql_statement_str():
+ example_str = "example str"
+ escaped_bytes = escape_for_sql_statement(example_str)
+ assert escaped_bytes == "'example str'"
+
+
+def test_output_sql_insert():
+ global formatter
+ formatter = TabularOutputFormatter
+ register_new_formatter(formatter)
+ data = [
+ [
+ 1,
+ "Jackson",
+ "jackson_test@gmail.com",
+ "132454789",
+ None,
+ "2022-09-09 19:44:32.712343+08",
+ "2022-09-09 19:44:32.712343+08",
+ ]
+ ]
+ header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"]
+ table_format = "sql-insert"
+ kwargs = {
+ "column_types": [int, str, str, str, str, str, str],
+ "sep_title": "RECORD {n}",
+ "sep_character": "-",
+ "sep_length": (1, 25),
+ "missing_value": "<null>",
+ "integer_format": "",
+ "float_format": "",
+ "disable_numparse": True,
+ "preserve_whitespace": True,
+ "max_field_width": 500,
+ }
+ formatter.query = 'SELECT * FROM "user";'
+ output = adapter(data, header, table_format=table_format, **kwargs)
+ output_list = [l for l in output]
+ expected = [
+ 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES',
+ " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, "
+ + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')",
+ ";",
+ ]
+ assert expected == output_list
+
+
+def test_output_sql_update():
+ global formatter
+ formatter = TabularOutputFormatter
+ register_new_formatter(formatter)
+ data = [
+ [
+ 1,
+ "Jackson",
+ "jackson_test@gmail.com",
+ "132454789",
+ "",
+ "2022-09-09 19:44:32.712343+08",
+ "2022-09-09 19:44:32.712343+08",
+ ]
+ ]
+ header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"]
+ table_format = "sql-update"
+ kwargs = {
+ "column_types": [int, str, str, str, str, str, str],
+ "sep_title": "RECORD {n}",
+ "sep_character": "-",
+ "sep_length": (1, 25),
+ "missing_value": "<null>",
+ "integer_format": "",
+ "float_format": "",
+ "disable_numparse": True,
+ "preserve_whitespace": True,
+ "max_field_width": 500,
+ }
+ formatter.query = 'SELECT * FROM "user";'
+ output = adapter(data, header, table_format=table_format, **kwargs)
+ output_list = [l for l in output]
+ print(output_list)
+ expected = [
+ 'UPDATE "user" SET',
+ " \"name\" = 'Jackson'",
+ ", \"email\" = 'jackson_test@gmail.com'",
+ ", \"phone\" = '132454789'",
+ ", \"description\" = ''",
+ ", \"created_at\" = '2022-09-09 19:44:32.712343+08'",
+ ", \"updated_at\" = '2022-09-09 19:44:32.712343+08'",
+ "WHERE \"id\" = '1';",
+ ]
+ assert expected == output_list
diff --git a/tests/metadata.py b/tests/metadata.py
new file mode 100644
index 0000000..4ebcccd
--- /dev/null
+++ b/tests/metadata.py
@@ -0,0 +1,255 @@
+from functools import partial
+from itertools import product
+from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+from unittest.mock import Mock
+import pytest
+
+parametrize = pytest.mark.parametrize
+
+qual = ["if_more_than_one_table", "always"]
+no_qual = ["if_more_than_one_table", "never"]
+
+
+def escape(name):
+ if not name.islower() or name in ("select", "localtimestamp"):
+ return '"' + name + '"'
+ return name
+
+
+def completion(display_meta, text, pos=0):
+ return Completion(text, start_position=pos, display_meta=display_meta)
+
+
+def function(text, pos=0, display=None):
+ return Completion(
+ text, display=display or text, start_position=pos, display_meta="function"
+ )
+
+
+def get_result(completer, text, position=None):
+ position = len(text) if position is None else position
+ return completer.get_completions(
+ Document(text=text, cursor_position=position), Mock()
+ )
+
+
+def result_set(completer, text, position=None):
+ return set(get_result(completer, text, position))
+
+
+# The code below is quivalent to
+# def schema(text, pos=0):
+# return completion('schema', text, pos)
+# and so on
+schema = partial(completion, "schema")
+table = partial(completion, "table")
+view = partial(completion, "view")
+column = partial(completion, "column")
+keyword = partial(completion, "keyword")
+datatype = partial(completion, "datatype")
+alias = partial(completion, "table alias")
+name_join = partial(completion, "name join")
+fk_join = partial(completion, "fk join")
+join = partial(completion, "join")
+
+
+def wildcard_expansion(cols, pos=-1):
+ return Completion(cols, start_position=pos, display_meta="columns", display="*")
+
+
+class MetaData:
+ def __init__(self, metadata):
+ self.metadata = metadata
+
+ def builtin_functions(self, pos=0):
+ return [function(f, pos) for f in self.completer.functions]
+
+ def builtin_datatypes(self, pos=0):
+ return [datatype(dt, pos) for dt in self.completer.datatypes]
+
+ def keywords(self, pos=0):
+ return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()]
+
+ def specials(self, pos=0):
+ return [
+ Completion(text=k, start_position=pos, display_meta=v.description)
+ for k, v in self.completer.pgspecial.commands.items()
+ ]
+
+ def columns(self, tbl, parent="public", typ="tables", pos=0):
+ if typ == "functions":
+ fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0]
+ cols = fun[1]
+ else:
+ cols = self.metadata[typ][parent][tbl]
+ return [column(escape(col), pos) for col in cols]
+
+ def datatypes(self, parent="public", pos=0):
+ return [
+ datatype(escape(x), pos)
+ for x in self.metadata.get("datatypes", {}).get(parent, [])
+ ]
+
+ def tables(self, parent="public", pos=0):
+ return [
+ table(escape(x), pos)
+ for x in self.metadata.get("tables", {}).get(parent, [])
+ ]
+
+ def views(self, parent="public", pos=0):
+ return [
+ view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, [])
+ ]
+
+ def functions(self, parent="public", pos=0):
+ return [
+ function(
+ escape(x[0])
+ + "("
+ + ", ".join(
+ arg_name + " := "
+ for (arg_name, arg_mode) in zip(x[1], x[3])
+ if arg_mode in ("b", "i")
+ )
+ + ")",
+ pos,
+ escape(x[0])
+ + "("
+ + ", ".join(
+ arg_name
+ for (arg_name, arg_mode) in zip(x[1], x[3])
+ if arg_mode in ("b", "i")
+ )
+ + ")",
+ )
+ for x in self.metadata.get("functions", {}).get(parent, [])
+ ]
+
+ def schemas(self, pos=0):
+ schemas = {sch for schs in self.metadata.values() for sch in schs}
+ return [schema(escape(s), pos=pos) for s in schemas]
+
+ def functions_and_keywords(self, parent="public", pos=0):
+ return (
+ self.functions(parent, pos)
+ + self.builtin_functions(pos)
+ + self.keywords(pos)
+ )
+
+ # Note that the filtering parameters here only apply to the columns
+ def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0):
+ return self.functions_and_keywords(pos=pos) + self.columns(
+ tbl, parent, typ, pos
+ )
+
+ def from_clause_items(self, parent="public", pos=0):
+ return (
+ self.functions(parent, pos)
+ + self.views(parent, pos)
+ + self.tables(parent, pos)
+ )
+
+ def schemas_and_from_clause_items(self, parent="public", pos=0):
+ return self.from_clause_items(parent, pos) + self.schemas(pos)
+
+ def types(self, parent="public", pos=0):
+ return self.datatypes(parent, pos) + self.tables(parent, pos)
+
+ @property
+ def completer(self):
+ return self.get_completer()
+
+ def get_completers(self, casing):
+ """
+ Returns a function taking three bools `casing`, `filtr`, `aliasing` and
+ the list `qualify`, all defaulting to None.
+ Returns a list of completers.
+ These parameters specify the allowed values for the corresponding
+ completer parameters, `None` meaning any, i.e. (None, None, None, None)
+ results in all 24 possible completers, whereas e.g.
+ (True, False, True, ['never']) results in the one completer with
+ casing, without `search_path` filtering of objects, with table
+ aliasing, and without column qualification.
+ """
+
+ def _cfg(_casing, filtr, aliasing, qualify):
+ cfg = {"settings": {}}
+ if _casing:
+ cfg["casing"] = casing
+ cfg["settings"]["search_path_filter"] = filtr
+ cfg["settings"]["generate_aliases"] = aliasing
+ cfg["settings"]["qualify_columns"] = qualify
+ return cfg
+
+ def _cfgs(casing, filtr, aliasing, qualify):
+ casings = [True, False] if casing is None else [casing]
+ filtrs = [True, False] if filtr is None else [filtr]
+ aliases = [True, False] if aliasing is None else [aliasing]
+ qualifys = qualify or ["always", "if_more_than_one_table", "never"]
+ return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)]
+
+ def completers(casing=None, filtr=None, aliasing=None, qualify=None):
+ get_comp = self.get_completer
+ return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)]
+
+ return completers
+
+ def _make_col(self, sch, tbl, col):
+ defaults = self.metadata.get("defaults", {}).get(sch, {})
+ return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col)))
+
+ def get_completer(self, settings=None, casing=None):
+ metadata = self.metadata
+ from pgcli.pgcompleter import PGCompleter
+ from pgspecial import PGSpecial
+
+ comp = PGCompleter(
+ smart_completion=True, settings=settings, pgspecial=PGSpecial()
+ )
+
+ schemata, tables, tbl_cols, views, view_cols = [], [], [], [], []
+
+ for sch, tbls in metadata["tables"].items():
+ schemata.append(sch)
+
+ for tbl, cols in tbls.items():
+ tables.append((sch, tbl))
+ # Let all columns be text columns
+ tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols])
+
+ for sch, tbls in metadata.get("views", {}).items():
+ for tbl, cols in tbls.items():
+ views.append((sch, tbl))
+ # Let all columns be text columns
+ view_cols.extend([self._make_col(sch, tbl, col) for col in cols])
+
+ functions = [
+ FunctionMetadata(sch, *func_meta, arg_defaults=None)
+ for sch, funcs in metadata["functions"].items()
+ for func_meta in funcs
+ ]
+
+ datatypes = [
+ (sch, typ)
+ for sch, datatypes in metadata["datatypes"].items()
+ for typ in datatypes
+ ]
+
+ foreignkeys = [
+ ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks
+ ]
+
+ comp.extend_schemata(schemata)
+ comp.extend_relations(tables, kind="tables")
+ comp.extend_relations(views, kind="views")
+ comp.extend_columns(tbl_cols, kind="tables")
+ comp.extend_columns(view_cols, kind="views")
+ comp.extend_functions(functions)
+ comp.extend_datatypes(datatypes)
+ comp.extend_foreignkeys(foreignkeys)
+ comp.set_search_path(["public"])
+ comp.extend_casing(casing or [])
+
+ return comp
diff --git a/tests/parseutils/test_ctes.py b/tests/parseutils/test_ctes.py
new file mode 100644
index 0000000..3e89cca
--- /dev/null
+++ b/tests/parseutils/test_ctes.py
@@ -0,0 +1,137 @@
+import pytest
+from sqlparse import parse
+from pgcli.packages.parseutils.ctes import (
+ token_start_pos,
+ extract_ctes,
+ extract_column_names as _extract_column_names,
+)
+
+
+def extract_column_names(sql):
+ p = parse(sql)[0]
+ return _extract_column_names(p)
+
+
+def test_token_str_pos():
+ sql = "SELECT * FROM xxx"
+ p = parse(sql)[0]
+ idx = p.token_index(p.tokens[-1])
+ assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ")
+
+ sql = "SELECT * FROM \nxxx"
+ p = parse(sql)[0]
+ idx = p.token_index(p.tokens[-1])
+ assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n")
+
+
+def test_single_column_name_extraction():
+ sql = "SELECT abc FROM xxx"
+ assert extract_column_names(sql) == ("abc",)
+
+
+def test_aliased_single_column_name_extraction():
+ sql = "SELECT abc def FROM xxx"
+ assert extract_column_names(sql) == ("def",)
+
+
+def test_aliased_expression_name_extraction():
+ sql = "SELECT 99 abc FROM xxx"
+ assert extract_column_names(sql) == ("abc",)
+
+
+def test_multiple_column_name_extraction():
+ sql = "SELECT abc, def FROM xxx"
+ assert extract_column_names(sql) == ("abc", "def")
+
+
+def test_missing_column_name_handled_gracefully():
+ sql = "SELECT abc, 99 FROM xxx"
+ assert extract_column_names(sql) == ("abc",)
+
+ sql = "SELECT abc, 99, def FROM xxx"
+ assert extract_column_names(sql) == ("abc", "def")
+
+
+def test_aliased_multiple_column_name_extraction():
+ sql = "SELECT abc def, ghi jkl FROM xxx"
+ assert extract_column_names(sql) == ("def", "jkl")
+
+
+def test_table_qualified_column_name_extraction():
+ sql = "SELECT abc.def, ghi.jkl FROM xxx"
+ assert extract_column_names(sql) == ("def", "jkl")
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
+ "DELETE FROM foo WHERE x > y RETURNING x, y",
+ "UPDATE foo SET x = 9 RETURNING x, y",
+ ],
+)
+def test_extract_column_names_from_returning_clause(sql):
+ assert extract_column_names(sql) == ("x", "y")
+
+
+def test_simple_cte_extraction():
+ sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
+ start_pos = len("WITH a AS ")
+ stop_pos = len("WITH a AS (SELECT abc FROM xxx)")
+ ctes, remainder = extract_ctes(sql)
+
+ assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),)
+ assert remainder.strip() == "SELECT * FROM a"
+
+
+def test_cte_extraction_around_comments():
+ sql = """--blah blah blah
+ WITH a AS (SELECT abc def FROM x)
+ SELECT * FROM a"""
+ start_pos = len(
+ """--blah blah blah
+ WITH a AS """
+ )
+ stop_pos = len(
+ """--blah blah blah
+ WITH a AS (SELECT abc def FROM x)"""
+ )
+
+ ctes, remainder = extract_ctes(sql)
+ assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),)
+ assert remainder.strip() == "SELECT * FROM a"
+
+
+def test_multiple_cte_extraction():
+ sql = """WITH
+ x AS (SELECT abc, def FROM x),
+ y AS (SELECT ghi, jkl FROM y)
+ SELECT * FROM a, b"""
+
+ start1 = len(
+ """WITH
+ x AS """
+ )
+
+ stop1 = len(
+ """WITH
+ x AS (SELECT abc, def FROM x)"""
+ )
+
+ start2 = len(
+ """WITH
+ x AS (SELECT abc, def FROM x),
+ y AS """
+ )
+
+ stop2 = len(
+ """WITH
+ x AS (SELECT abc, def FROM x),
+ y AS (SELECT ghi, jkl FROM y)"""
+ )
+
+ ctes, remainder = extract_ctes(sql)
+ assert tuple(ctes) == (
+ ("x", ("abc", "def"), start1, stop1),
+ ("y", ("ghi", "jkl"), start2, stop2),
+ )
diff --git a/tests/parseutils/test_function_metadata.py b/tests/parseutils/test_function_metadata.py
new file mode 100644
index 0000000..0350e2a
--- /dev/null
+++ b/tests/parseutils/test_function_metadata.py
@@ -0,0 +1,19 @@
+from pgcli.packages.parseutils.meta import FunctionMetadata
+
+
+def test_function_metadata_eq():
+ f1 = FunctionMetadata(
+ "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
+ )
+ f2 = FunctionMetadata(
+ "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
+ )
+ f3 = FunctionMetadata(
+ "s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None
+ )
+ assert f1 == f2
+ assert f1 != f3
+ assert not (f1 != f2)
+ assert not (f1 == f3)
+ assert hash(f1) == hash(f2)
+ assert hash(f1) != hash(f3)
diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py
new file mode 100644
index 0000000..349cbd0
--- /dev/null
+++ b/tests/parseutils/test_parseutils.py
@@ -0,0 +1,310 @@
+import pytest
+from pgcli.packages.parseutils import (
+ is_destructive,
+ parse_destructive_warning,
+ BASE_KEYWORDS,
+ ALL_KEYWORDS,
+)
+from pgcli.packages.parseutils.tables import extract_tables
+from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
+
+
+def test_empty_string():
+ tables = extract_tables("")
+ assert tables == ()
+
+
+def test_simple_select_single_table():
+ tables = extract_tables("select * from abc")
+ assert tables == ((None, "abc", None, False),)
+
+
+@pytest.mark.parametrize(
+ "sql", ['select * from "abc"."def"', 'select * from abc."def"']
+)
+def test_simple_select_single_table_schema_qualified_quoted_table(sql):
+ tables = extract_tables(sql)
+ assert tables == (("abc", "def", '"def"', False),)
+
+
+@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def'])
+def test_simple_select_single_table_schema_qualified(sql):
+ tables = extract_tables(sql)
+ assert tables == (("abc", "def", None, False),)
+
+
+def test_simple_select_single_table_double_quoted():
+ tables = extract_tables('select * from "Abc"')
+ assert tables == ((None, "Abc", None, False),)
+
+
+def test_simple_select_multiple_tables():
+ tables = extract_tables("select * from abc, def")
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
+
+
+def test_simple_select_multiple_tables_double_quoted():
+ tables = extract_tables('select * from "Abc", "Def"')
+ assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)}
+
+
+def test_simple_select_single_table_deouble_quoted_aliased():
+ tables = extract_tables('select * from "Abc" a')
+ assert tables == ((None, "Abc", "a", False),)
+
+
+def test_simple_select_multiple_tables_deouble_quoted_aliased():
+ tables = extract_tables('select * from "Abc" a, "Def" d')
+ assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)}
+
+
+def test_simple_select_multiple_tables_schema_qualified():
+ tables = extract_tables("select * from abc.def, ghi.jkl")
+ assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)}
+
+
+def test_simple_select_with_cols_single_table():
+ tables = extract_tables("select a,b from abc")
+ assert tables == ((None, "abc", None, False),)
+
+
+def test_simple_select_with_cols_single_table_schema_qualified():
+ tables = extract_tables("select a,b from abc.def")
+ assert tables == (("abc", "def", None, False),)
+
+
+def test_simple_select_with_cols_multiple_tables():
+ tables = extract_tables("select a,b from abc, def")
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
+
+
+def test_simple_select_with_cols_multiple_qualified_tables():
+ tables = extract_tables("select a,b from abc.def, def.ghi")
+ assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)}
+
+
+def test_select_with_hanging_comma_single_table():
+ tables = extract_tables("select a, from abc")
+ assert tables == ((None, "abc", None, False),)
+
+
+def test_select_with_hanging_comma_multiple_tables():
+ tables = extract_tables("select a, from abc, def")
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
+
+
+def test_select_with_hanging_period_multiple_tables():
+ tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
+ assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)}
+
+
+def test_simple_insert_single_table():
+ tables = extract_tables('insert into abc (id, name) values (1, "def")')
+
+ # sqlparse mistakenly assigns an alias to the table
+ # AND mistakenly identifies the field list as
+ # assert tables == ((None, 'abc', 'abc', False),)
+
+ assert tables == ((None, "abc", "abc", False),)
+
+
+@pytest.mark.xfail
+def test_simple_insert_single_table_schema_qualified():
+ tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
+ assert tables == (("abc", "def", None, False),)
+
+
+def test_simple_update_table_no_schema():
+ tables = extract_tables("update abc set id = 1")
+ assert tables == ((None, "abc", None, False),)
+
+
+def test_simple_update_table_with_schema():
+ tables = extract_tables("update abc.def set id = 1")
+ assert tables == (("abc", "def", None, False),)
+
+
+@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
+def test_join_table(join_type):
+ sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num"
+ tables = extract_tables(sql)
+ assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)}
+
+
+def test_join_table_schema_qualified():
+ tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
+ assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)}
+
+
+def test_incomplete_join_clause():
+ sql = """select a.x, b.y
+ from abc a join bcd b
+ on a.id = """
+ tables = extract_tables(sql)
+ assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False))
+
+
+def test_join_as_table():
+ tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
+ assert tables == ((None, "my_table", "m", False),)
+
+
+def test_multiple_joins():
+ sql = """select * from t1
+ inner join t2 ON
+ t1.id = t2.t1_id
+ inner join t3 ON
+ t2.id = t3."""
+ tables = extract_tables(sql)
+ assert tables == (
+ (None, "t1", None, False),
+ (None, "t2", None, False),
+ (None, "t3", None, False),
+ )
+
+
+def test_subselect_tables():
+ sql = "SELECT * FROM (SELECT FROM abc"
+ tables = extract_tables(sql)
+ assert tables == ((None, "abc", None, False),)
+
+
+@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"])
+def test_extract_no_tables(text):
+ tables = extract_tables(text)
+ assert tables == tuple()
+
+
+@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
+def test_simple_function_as_table(arg_list):
+ tables = extract_tables(f"SELECT * FROM foo({arg_list})")
+ assert tables == ((None, "foo", None, True),)
+
+
+@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
+def test_simple_schema_qualified_function_as_table(arg_list):
+ tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})")
+ assert tables == (("foo", "bar", None, True),)
+
+
+@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
+def test_simple_aliased_function_as_table(arg_list):
+ tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar")
+ assert tables == ((None, "foo", "bar", True),)
+
+
+def test_simple_table_and_function():
+ tables = extract_tables("SELECT * FROM foo JOIN bar()")
+ assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)}
+
+
+def test_complex_table_and_function():
+ tables = extract_tables(
+ """SELECT * FROM foo.bar baz
+ JOIN bar.qux(x, y, z) quux"""
+ )
+ assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)}
+
+
+def test_find_prev_keyword_using():
+ q = "select * from tbl1 inner join tbl2 using (col1, "
+ kw, q2 = find_prev_keyword(q)
+ assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using ("
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select * from foo where bar",
+ "select * from foo where bar = 1 and baz or ",
+ "select * from foo where bar = 1 and baz between qux and ",
+ ],
+)
+def test_find_prev_keyword_where(sql):
+ kw, stripped = find_prev_keyword(sql)
+ assert kw.value == "where" and stripped == "select * from foo where"
+
+
+@pytest.mark.parametrize(
+ "sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "]
+)
+def test_find_prev_keyword_open_parens(sql):
+ kw, _ = find_prev_keyword(sql)
+ assert kw.value == "("
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "",
+ "$$ foo $$",
+ "$$ 'foo' $$",
+ '$$ "foo" $$',
+ "$$ $a$ $$",
+ "$a$ $$ $a$",
+ "foo bar $$ baz $$",
+ ],
+)
+def test_is_open_quote__closed(sql):
+ assert not is_open_quote(sql)
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "$$",
+ ";;;$$",
+ "foo $$ bar $$; foo $$",
+ "$$ foo $a$",
+ "foo 'bar baz",
+ "$a$ foo ",
+ '$$ "foo" ',
+ "$$ $a$ ",
+ "foo bar $$ baz",
+ ],
+)
+def test_is_open_quote__open(sql):
+ assert is_open_quote(sql)
+
+
+@pytest.mark.parametrize(
+ ("sql", "keywords", "expected"),
+ [
+ ("update abc set x = 1", ALL_KEYWORDS, True),
+ ("update abc set x = 1 where y = 2", ALL_KEYWORDS, True),
+ ("update abc set x = 1", BASE_KEYWORDS, True),
+ ("update abc set x = 1 where y = 2", BASE_KEYWORDS, False),
+ ("select x, y, z from abc", ALL_KEYWORDS, False),
+ ("drop abc", ALL_KEYWORDS, True),
+ ("alter abc", ALL_KEYWORDS, True),
+ ("delete abc", ALL_KEYWORDS, True),
+ ("truncate abc", ALL_KEYWORDS, True),
+ ("insert into abc values (1, 2, 3)", ALL_KEYWORDS, False),
+ ("insert into abc values (1, 2, 3)", BASE_KEYWORDS, False),
+ ("insert into abc values (1, 2, 3)", ["insert"], True),
+ ("insert into abc values (1, 2, 3)", ["insert"], True),
+ ],
+)
+def test_is_destructive(sql, keywords, expected):
+ assert is_destructive(sql, keywords) == expected
+
+
+@pytest.mark.parametrize(
+ ("warning_level", "expected"),
+ [
+ ("true", ALL_KEYWORDS),
+ ("false", []),
+ ("all", ALL_KEYWORDS),
+ ("moderate", BASE_KEYWORDS),
+ ("off", []),
+ ("", []),
+ (None, []),
+ (ALL_KEYWORDS, ALL_KEYWORDS),
+ (BASE_KEYWORDS, BASE_KEYWORDS),
+ ("insert", ["insert"]),
+ ("drop,alter,delete", ["drop", "alter", "delete"]),
+ (["drop", "alter", "delete"], ["drop", "alter", "delete"]),
+ ],
+)
+def test_parse_destructive_warning(warning_level, expected):
+ assert parse_destructive_warning(warning_level) == expected
diff --git a/tests/pytest.ini b/tests/pytest.ini
new file mode 100644
index 0000000..f787740
--- /dev/null
+++ b/tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts=--capture=sys --showlocals \ No newline at end of file
diff --git a/tests/test_auth.py b/tests/test_auth.py
new file mode 100644
index 0000000..a517a89
--- /dev/null
+++ b/tests/test_auth.py
@@ -0,0 +1,40 @@
+import pytest
+from unittest import mock
+from pgcli import auth
+
+
+@pytest.mark.parametrize("enabled,call_count", [(True, 1), (False, 0)])
+def test_keyring_initialize(enabled, call_count):
+ logger = mock.MagicMock()
+
+ with mock.patch("importlib.import_module", return_value=True) as import_method:
+ auth.keyring_initialize(enabled, logger=logger)
+ assert import_method.call_count == call_count
+
+
+def test_keyring_get_password_ok():
+ with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()):
+ with mock.patch("pgcli.auth.keyring.get_password", return_value="abc123"):
+ assert auth.keyring_get_password("test") == "abc123"
+
+
+def test_keyring_get_password_exception():
+ with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()):
+ with mock.patch(
+ "pgcli.auth.keyring.get_password", side_effect=Exception("Boom!")
+ ):
+ assert auth.keyring_get_password("test") == ""
+
+
+def test_keyring_set_password_ok():
+ with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()):
+ with mock.patch("pgcli.auth.keyring.set_password"):
+ auth.keyring_set_password("test", "abc123")
+
+
+def test_keyring_set_password_exception():
+ with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()):
+ with mock.patch(
+ "pgcli.auth.keyring.set_password", side_effect=Exception("Boom!")
+ ):
+ auth.keyring_set_password("test", "abc123")
diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py
new file mode 100644
index 0000000..a5529d6
--- /dev/null
+++ b/tests/test_completion_refresher.py
@@ -0,0 +1,95 @@
+import time
+import pytest
+from unittest.mock import Mock, patch
+
+
+@pytest.fixture
+def refresher():
+ from pgcli.completion_refresher import CompletionRefresher
+
+ return CompletionRefresher()
+
+
+def test_ctor(refresher):
+ """
+ Refresher object should contain a few handlers
+ :param refresher:
+ :return:
+ """
+ assert len(refresher.refreshers) > 0
+ actual_handlers = list(refresher.refreshers.keys())
+ expected_handlers = [
+ "schemata",
+ "tables",
+ "views",
+ "types",
+ "databases",
+ "casing",
+ "functions",
+ ]
+ assert expected_handlers == actual_handlers
+
+
+def test_refresh_called_once(refresher):
+ """
+
+ :param refresher:
+ :return:
+ """
+ callbacks = Mock()
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
+ special = Mock()
+
+ with patch.object(refresher, "_bg_refresh") as bg_refresh:
+ actual = refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual) == 1
+ assert len(actual[0]) == 4
+ assert actual[0][3] == "Auto-completion refresh started in the background."
+ bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None)
+
+
+def test_refresh_called_twice(refresher):
+ """
+ If refresh is called a second time, it should be restarted
+ :param refresher:
+ :return:
+ """
+ callbacks = Mock()
+
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
+ special = Mock()
+
+ def dummy_bg_refresh(*args):
+ time.sleep(3) # seconds
+
+ refresher._bg_refresh = dummy_bg_refresh
+
+ actual1 = refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual1) == 1
+ assert len(actual1[0]) == 4
+ assert actual1[0][3] == "Auto-completion refresh started in the background."
+
+ actual2 = refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual2) == 1
+ assert len(actual2[0]) == 4
+ assert actual2[0][3] == "Auto-completion refresh restarted."
+
+
+def test_refresh_with_callbacks(refresher):
+ """
+ Callbacks must be called
+ :param refresher:
+ """
+ callbacks = [Mock()]
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
+ pgexecute.extra_args = {}
+ special = Mock()
+
+ # Set refreshers to 0: we're not testing refresh logic here
+ refresher.refreshers = {}
+ refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert callbacks[0].call_count == 1
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..08fe74e
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,43 @@
+import io
+import os
+import stat
+
+import pytest
+
+from pgcli.config import ensure_dir_exists, skip_initial_comment
+
+
+def test_ensure_file_parent(tmpdir):
+ subdir = tmpdir.join("subdir")
+ rcfile = subdir.join("rcfile")
+ ensure_dir_exists(str(rcfile))
+
+
+def test_ensure_existing_dir(tmpdir):
+ rcfile = str(tmpdir.mkdir("subdir").join("rcfile"))
+
+ # should just not raise
+ ensure_dir_exists(rcfile)
+
+
+def test_ensure_other_create_error(tmpdir):
+ subdir = tmpdir.join('subdir"')
+ rcfile = subdir.join("rcfile")
+
+ # trigger an oserror that isn't "directory already exists"
+ os.chmod(str(tmpdir), stat.S_IREAD)
+
+ with pytest.raises(OSError):
+ ensure_dir_exists(str(rcfile))
+
+
+@pytest.mark.parametrize(
+ "text, skipped_lines",
+ (
+ ("abc\n", 1),
+ ("#[section]\ndef\n[section]", 2),
+ ("[section]", 0),
+ ),
+)
+def test_skip_initial_comment(text, skipped_lines):
+ assert skip_initial_comment(io.StringIO(text)) == skipped_lines
diff --git a/tests/test_fuzzy_completion.py b/tests/test_fuzzy_completion.py
new file mode 100644
index 0000000..8f8f2cd
--- /dev/null
+++ b/tests/test_fuzzy_completion.py
@@ -0,0 +1,87 @@
+import pytest
+
+
+@pytest.fixture
+def completer():
+ import pgcli.pgcompleter as pgcompleter
+
+ return pgcompleter.PGCompleter()
+
+
+def test_ranking_ignores_identifier_quotes(completer):
+ """When calculating result rank, identifier quotes should be ignored.
+
+ The result ranking algorithm ignores identifier quotes. Without this
+ correction, the match "user", which Postgres requires to be quoted
+ since it is also a reserved word, would incorrectly fall below the
+ match user_action because the literal quotation marks in "user"
+ alter the position of the match.
+
+ This test checks that the fuzzy ranking algorithm correctly ignores
+ quotation marks when computing match ranks.
+
+ """
+
+ text = "user"
+ collection = ["user_action", '"user"']
+ matches = completer.find_matches(text, collection)
+ assert len(matches) == 2
+
+
+def test_ranking_based_on_shortest_match(completer):
+ """Fuzzy result rank should be based on shortest match.
+
+ Result ranking in fuzzy searching is partially based on the length
+ of matches: shorter matches are considered more relevant than
+ longer ones. When searching for the text 'user', the length
+ component of the match 'user_group' could be either 4 ('user') or
+ 7 ('user_gr').
+
+ This test checks that the fuzzy ranking algorithm uses the shorter
+ match when calculating result rank.
+
+ """
+
+ text = "user"
+ collection = ["api_user", "user_group"]
+ matches = completer.find_matches(text, collection)
+
+ assert matches[1].priority > matches[0].priority
+
+
+@pytest.mark.parametrize(
+ "collection",
+ [["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]],
+)
+def test_should_break_ties_using_lexical_order(completer, collection):
+ """Fuzzy result rank should use lexical order to break ties.
+
+ When fuzzy matching, if multiple matches have the same match length and
+ start position, present them in lexical (rather than arbitrary) order. For
+ example, if we have tables 'user', 'user_action', and 'user_group', a
+ search for the text 'user' should present these tables in this order.
+
+ The input collections to this test are out of order; each run checks that
+ the search text 'user' results in the input tables being reordered
+ lexically.
+
+ """
+
+ text = "user"
+ matches = completer.find_matches(text, collection)
+
+ assert matches[1].priority > matches[0].priority
+
+
+def test_matching_should_be_case_insensitive(completer):
+ """Fuzzy matching should keep matches even if letter casing doesn't match.
+
+ This test checks that variations of the text which have different casing
+ are still matched.
+ """
+
+ text = "foo"
+ collection = ["Foo", "FOO", "fOO"]
+ matches = completer.find_matches(text, collection)
+
+ assert len(matches) == 3
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..cbf20a6
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,490 @@
+import os
+import platform
+from unittest import mock
+
+import pytest
+
+try:
+ import setproctitle
+except ImportError:
+ setproctitle = None
+
+from pgcli.main import (
+ obfuscate_process_password,
+ format_output,
+ PGCli,
+ OutputSettings,
+ COLOR_CODE_REGEX,
+)
+from pgcli.pgexecute import PGExecute
+from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS
+from utils import dbtest, run
+from collections import namedtuple
+
+
+@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows")
+@pytest.mark.skipif(not setproctitle, reason="setproctitle not available")
+def test_obfuscate_process_password():
+ original_title = setproctitle.getproctitle()
+
+ setproctitle.setproctitle("pgcli user=root password=secret host=localhost")
+ obfuscate_process_password()
+ title = setproctitle.getproctitle()
+ expected = "pgcli user=root password=xxxx host=localhost"
+ assert title == expected
+
+ setproctitle.setproctitle("pgcli user=root password=top secret host=localhost")
+ obfuscate_process_password()
+ title = setproctitle.getproctitle()
+ expected = "pgcli user=root password=xxxx host=localhost"
+ assert title == expected
+
+ setproctitle.setproctitle("pgcli user=root password=top secret")
+ obfuscate_process_password()
+ title = setproctitle.getproctitle()
+ expected = "pgcli user=root password=xxxx"
+ assert title == expected
+
+ setproctitle.setproctitle("pgcli postgres://root:secret@localhost/db")
+ obfuscate_process_password()
+ title = setproctitle.getproctitle()
+ expected = "pgcli postgres://root:xxxx@localhost/db"
+ assert title == expected
+
+ setproctitle.setproctitle(original_title)
+
+
+def test_format_output():
+ settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g")
+ results = format_output(
+ "Title", [("abc", "def")], ["head1", "head2"], "test status", settings
+ )
+ expected = [
+ "Title",
+ "+-------+-------+",
+ "| head1 | head2 |",
+ "|-------+-------|",
+ "| abc | def |",
+ "+-------+-------+",
+ "test status",
+ ]
+ assert list(results) == expected
+
+
+def test_format_output_truncate_on():
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10
+ )
+ results = format_output(
+ None,
+ [("first field value", "second field value")],
+ ["head1", "head2"],
+ None,
+ settings,
+ )
+ expected = [
+ "+------------+------------+",
+ "| head1 | head2 |",
+ "|------------+------------|",
+ "| first f... | second ... |",
+ "+------------+------------+",
+ ]
+ assert list(results) == expected
+
+
+def test_format_output_truncate_off():
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None
+ )
+ long_field_value = ("first field " * 100).strip()
+ results = format_output(None, [(long_field_value,)], ["head1"], None, settings)
+ lines = list(results)
+ assert lines[3] == f"| {long_field_value} |"
+
+
+@dbtest
+def test_format_array_output(executor):
+ statement = """
+ SELECT
+ array[1, 2, 3]::bigint[] as bigint_array,
+ '{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
+ '{å,魚,текст}'::text[] as 配列
+ UNION ALL
+ SELECT '{}', NULL, array[NULL]
+ """
+ results = run(executor, statement)
+ expected = [
+ "+--------------+----------------------+--------------+",
+ "| bigint_array | nested_numeric_array | 配列 |",
+ "|--------------+----------------------+--------------|",
+ "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |",
+ "| {} | <null> | {<null>} |",
+ "+--------------+----------------------+--------------+",
+ "SELECT 2",
+ ]
+ assert list(results) == expected
+
+
+@dbtest
+def test_format_array_output_expanded(executor):
+ statement = """
+ SELECT
+ array[1, 2, 3]::bigint[] as bigint_array,
+ '{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
+ '{å,魚,текст}'::text[] as 配列
+ UNION ALL
+ SELECT '{}', NULL, array[NULL]
+ """
+ results = run(executor, statement, expanded=True)
+ expected = [
+ "-[ RECORD 1 ]-------------------------",
+ "bigint_array | {1,2,3}",
+ "nested_numeric_array | {{1,2},{3,4}}",
+ "配列 | {å,魚,текст}",
+ "-[ RECORD 2 ]-------------------------",
+ "bigint_array | {}",
+ "nested_numeric_array | <null>",
+ "配列 | {<null>}",
+ "SELECT 2",
+ ]
+ assert "\n".join(results) == "\n".join(expected)
+
+
+def test_format_output_auto_expand():
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100
+ )
+ table_results = format_output(
+ "Title", [("abc", "def")], ["head1", "head2"], "test status", settings
+ )
+ table = [
+ "Title",
+ "+-------+-------+",
+ "| head1 | head2 |",
+ "|-------+-------|",
+ "| abc | def |",
+ "+-------+-------+",
+ "test status",
+ ]
+ assert list(table_results) == table
+ expanded_results = format_output(
+ "Title",
+ [("abc", "def")],
+ ["head1", "head2"],
+ "test status",
+ settings._replace(max_width=1),
+ )
+ expanded = [
+ "Title",
+ "-[ RECORD 1 ]-------------------------",
+ "head1 | abc",
+ "head2 | def",
+ "test status",
+ ]
+ assert "\n".join(expanded_results) == "\n".join(expanded)
+
+
+termsize = namedtuple("termsize", ["rows", "columns"])
+test_line = "-" * 10
+test_data = [
+ (10, 10, "\n".join([test_line] * 7)),
+ (10, 10, "\n".join([test_line] * 6)),
+ (10, 10, "\n".join([test_line] * 5)),
+ (10, 10, "-" * 11),
+ (10, 10, "-" * 10),
+ (10, 10, "-" * 9),
+]
+
+# 4 lines are reserved at the bottom of the terminal for pgcli's prompt
+use_pager_when_on = [True, True, False, True, False, False]
+
+# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL
+test_ids = [
+ "Output longer than terminal height",
+ "Output equal to terminal height",
+ "Output shorter than terminal height",
+ "Output longer than terminal width",
+ "Output equal to terminal width",
+ "Output shorter than terminal width",
+]
+
+
+@pytest.fixture
+def pset_pager_mocks():
+ cli = PGCli()
+ cli.watch_command = None
+ with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch(
+ "pgcli.main.click.echo_via_pager"
+ ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app:
+ yield cli, mock_echo, mock_echo_via_pager, mock_app
+
+
+@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
+def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks):
+ cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
+ mock_cli.output.get_size.return_value = termsize(
+ rows=term_height, columns=term_width
+ )
+
+ with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF):
+ cli.echo_via_pager(text)
+
+ mock_echo.assert_called()
+ mock_echo_via_pager.assert_not_called()
+
+
+@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
+def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks):
+ cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
+ mock_cli.output.get_size.return_value = termsize(
+ rows=term_height, columns=term_width
+ )
+
+ with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS):
+ cli.echo_via_pager(text)
+
+ mock_echo.assert_not_called()
+ mock_echo_via_pager.assert_called()
+
+
+pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)]
+
+
+@pytest.mark.parametrize(
+ "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids
+)
+def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks):
+ cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
+ mock_cli.output.get_size.return_value = termsize(
+ rows=term_height, columns=term_width
+ )
+
+ with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT):
+ cli.echo_via_pager(text)
+
+ if use_pager:
+ mock_echo.assert_not_called()
+ mock_echo_via_pager.assert_called()
+ else:
+ mock_echo_via_pager.assert_not_called()
+ mock_echo.assert_called()
+
+
+@pytest.mark.parametrize(
+ "text,expected_length",
+ [
+ (
+ "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s",
+ 78,
+ ),
+ ("=\u001b[m=", 2),
+ ("-\u001b]23\u0007-", 2),
+ ],
+)
+def test_color_pattern(text, expected_length, pset_pager_mocks):
+ cli = pset_pager_mocks[0]
+ assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length
+
+
+@dbtest
+def test_i_works(tmpdir, executor):
+ sqlfile = tmpdir.join("test.sql")
+ sqlfile.write("SELECT NOW()")
+ rcfile = str(tmpdir.join("rcfile"))
+ cli = PGCli(pgexecute=executor, pgclirc_file=rcfile)
+ statement = r"\i {0}".format(sqlfile)
+ run(executor, statement, pgspecial=cli.pgspecial)
+
+
+@dbtest
+def test_echo_works(executor):
+ cli = PGCli(pgexecute=executor)
+ statement = r"\echo asdf"
+ result = run(executor, statement, pgspecial=cli.pgspecial)
+ assert result == ["asdf"]
+
+
+@dbtest
+def test_qecho_works(executor):
+ cli = PGCli(pgexecute=executor)
+ statement = r"\qecho asdf"
+ result = run(executor, statement, pgspecial=cli.pgspecial)
+ assert result == ["asdf"]
+
+
+@dbtest
+def test_watch_works(executor):
+ cli = PGCli(pgexecute=executor)
+
+ def run_with_watch(
+ query, target_call_count=1, expected_output="", expected_timing=None
+ ):
+ """
+ :param query: Input to the CLI
+ :param target_call_count: Number of times the user lets the command run before Ctrl-C
+ :param expected_output: Substring expected to be found for each executed query
+ :param expected_timing: value `time.sleep` expected to be called with on every invocation
+ """
+ with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch(
+ "pgcli.main.sleep"
+ ) as mock_sleep:
+ mock_sleep.side_effect = [None] * (target_call_count - 1) + [
+ KeyboardInterrupt
+ ]
+ cli.handle_watch_command(query)
+ # Validate that sleep was called with the right timing
+ for i in range(target_call_count - 1):
+ assert mock_sleep.call_args_list[i][0][0] == expected_timing
+ # Validate that the output of the query was expected
+ assert mock_echo.call_count == target_call_count
+ for i in range(target_call_count):
+ assert expected_output in mock_echo.call_args_list[i][0][0]
+
+ # With no history, it errors.
+ with mock.patch("pgcli.main.click.secho") as mock_secho:
+ cli.handle_watch_command(r"\watch 2")
+ mock_secho.assert_called()
+ assert (
+ r"\watch cannot be used with an empty query"
+ in mock_secho.call_args_list[0][0][0]
+ )
+
+ # Usage 1: Run a query and then re-run it with \watch across two prompts.
+ run_with_watch("SELECT 111", expected_output="111")
+ run_with_watch(
+ "\\watch 10", target_call_count=2, expected_output="111", expected_timing=10
+ )
+
+ # Usage 2: Run a query and \watch via the same prompt.
+ run_with_watch(
+ "SELECT 222; \\watch 4",
+ target_call_count=3,
+ expected_output="222",
+ expected_timing=4,
+ )
+
+ # Usage 3: Re-run the last watched command with a new timing
+ run_with_watch(
+ "\\watch 5", target_call_count=4, expected_output="222", expected_timing=5
+ )
+
+
+def test_missing_rc_dir(tmpdir):
+ rcfile = str(tmpdir.join("subdir").join("rcfile"))
+
+ PGCli(pgclirc_file=rcfile)
+ assert os.path.exists(rcfile)
+
+
+def test_quoted_db_uri(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B")
+ mock_connect.assert_called_with(
+ database="testdb[", host="baz.com", user="bar^", passwd="]foo"
+ )
+
+
+def test_pg_service_file(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
+ service_conf.write(
+ """File begins with a comment
+ that is not a comment
+ # or maybe a comment after all
+ because psql is crazy
+
+ [myservice]
+ host=a_host
+ user=a_user
+ port=5433
+ password=much_secure
+ dbname=a_dbname
+
+ [my_other_service]
+ host=b_host
+ user=b_user
+ port=5435
+ dbname=b_dbname
+ """
+ )
+ os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath
+ cli.connect_service("myservice", "another_user")
+ mock_connect.assert_called_with(
+ database="a_dbname",
+ host="a_host",
+ user="another_user",
+ port="5433",
+ passwd="much_secure",
+ )
+
+ with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
+ mock_pgexecute.return_value = None
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ os.environ["PGPASSWORD"] = "very_secure"
+ cli.connect_service("my_other_service", None)
+ mock_pgexecute.assert_called_with(
+ "b_dbname",
+ "b_user",
+ "very_secure",
+ "b_host",
+ "5435",
+ "",
+ application_name="pgcli",
+ )
+ del os.environ["PGPASSWORD"]
+ del os.environ["PGSERVICEFILE"]
+
+
+def test_ssl_db_uri(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri(
+ "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?"
+ "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem"
+ )
+ mock_connect.assert_called_with(
+ database="testdb[",
+ host="baz.com",
+ user="bar^",
+ passwd="]foo",
+ sslmode="verify-full",
+ sslcert="my.pem",
+ sslkey="my-key.pem",
+ sslrootcert="ca.pem",
+ )
+
+
+def test_port_db_uri(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb")
+ mock_connect.assert_called_with(
+ database="testdb", host="baz.com", user="bar", passwd="foo", port="2543"
+ )
+
+
+def test_multihost_db_uri(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri(
+ "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb"
+ )
+ mock_connect.assert_called_with(
+ database="testdb",
+ host="baz1.com,baz2.com,baz3.com",
+ user="bar",
+ passwd="foo",
+ port="2543,2543,2543",
+ )
+
+
+def test_application_name_db_uri(tmpdir):
+ with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
+ mock_pgexecute.return_value = None
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri("postgres://bar@baz.com/?application_name=cow")
+ mock_pgexecute.assert_called_with(
+ "bar", "bar", "", "baz.com", "", "", application_name="cow"
+ )
diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py
new file mode 100644
index 0000000..5b93661
--- /dev/null
+++ b/tests/test_naive_completion.py
@@ -0,0 +1,133 @@
+import pytest
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+from utils import completions_to_set
+
+
+@pytest.fixture
+def completer():
+ import pgcli.pgcompleter as pgcompleter
+
+ return pgcompleter.PGCompleter(smart_completion=False)
+
+
+@pytest.fixture
+def complete_event():
+ from unittest.mock import Mock
+
+ return Mock()
+
+
+def test_empty_string_completion(completer, complete_event):
+ text = ""
+ position = 0
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == completions_to_set(map(Completion, completer.all_completions))
+
+
+def test_select_keyword_completion(completer, complete_event):
+ text = "SEL"
+ position = len("SEL")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == completions_to_set([Completion(text="SELECT", start_position=-3)])
+
+
+def test_function_name_completion(completer, complete_event):
+ text = "SELECT MA"
+ position = len("SELECT MA")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == completions_to_set(
+ [
+ Completion(text="MATERIALIZED VIEW", start_position=-2),
+ Completion(text="MAX", start_position=-2),
+ Completion(text="MAXEXTENTS", start_position=-2),
+ Completion(text="MAKE_DATE", start_position=-2),
+ Completion(text="MAKE_TIME", start_position=-2),
+ Completion(text="MAKE_TIMESTAMPTZ", start_position=-2),
+ Completion(text="MAKE_INTERVAL", start_position=-2),
+ Completion(text="MASKLEN", start_position=-2),
+ Completion(text="MAKE_TIMESTAMP", start_position=-2),
+ ]
+ )
+
+
+def test_column_name_completion(completer, complete_event):
+ text = "SELECT FROM users"
+ position = len("SELECT ")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == completions_to_set(map(Completion, completer.all_completions))
+
+
+def test_alter_well_known_keywords_completion(completer, complete_event):
+ text = "ALTER "
+ position = len(text)
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event,
+ smart_completion=True,
+ )
+ )
+ assert result > completions_to_set(
+ [
+ Completion(text="DATABASE", display_meta="keyword"),
+ Completion(text="TABLE", display_meta="keyword"),
+ Completion(text="SYSTEM", display_meta="keyword"),
+ ]
+ )
+ assert (
+ completions_to_set([Completion(text="CREATE", display_meta="keyword")])
+ not in result
+ )
+
+
+def test_special_name_completion(completer, complete_event):
+ text = "\\"
+ position = len("\\")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ # Special commands will NOT be suggested during naive completion mode.
+ assert result == completions_to_set([])
+
+
+def test_datatype_name_completion(completer, complete_event):
+ text = "SELECT price::IN"
+ position = len("SELECT price::IN")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event,
+ smart_completion=True,
+ )
+ )
+ assert result == completions_to_set(
+ [
+ Completion(text="INET", display_meta="datatype"),
+ Completion(text="INT", display_meta="datatype"),
+ Completion(text="INT2", display_meta="datatype"),
+ Completion(text="INT4", display_meta="datatype"),
+ Completion(text="INT8", display_meta="datatype"),
+ Completion(text="INTEGER", display_meta="datatype"),
+ Completion(text="INTERNAL", display_meta="datatype"),
+ Completion(text="INTERVAL", display_meta="datatype"),
+ ]
+ )
diff --git a/tests/test_pgcompleter.py b/tests/test_pgcompleter.py
new file mode 100644
index 0000000..909fa0b
--- /dev/null
+++ b/tests/test_pgcompleter.py
@@ -0,0 +1,76 @@
+import pytest
+from pgcli import pgcompleter
+
+
+def test_load_alias_map_file_missing_file():
+ with pytest.raises(
+ pgcompleter.InvalidMapFile,
+ match=r"Cannot read alias_map_file - /path/to/non-existent/file.json does not exist$",
+ ):
+ pgcompleter.load_alias_map_file("/path/to/non-existent/file.json")
+
+
+def test_load_alias_map_file_invalid_json(tmp_path):
+ fpath = tmp_path / "foo.json"
+ fpath.write_text("this is not valid json")
+ with pytest.raises(pgcompleter.InvalidMapFile, match=r".*is not valid json$"):
+ pgcompleter.load_alias_map_file(str(fpath))
+
+
+@pytest.mark.parametrize(
+ "table_name, alias",
+ [
+ ("SomE_Table", "SET"),
+ ("SOmeTabLe", "SOTL"),
+ ("someTable", "T"),
+ ],
+)
+def test_generate_alias_uses_upper_case_letters_from_name(table_name, alias):
+ assert pgcompleter.generate_alias(table_name) == alias
+
+
+@pytest.mark.parametrize(
+ "table_name, alias",
+ [
+ ("some_tab_le", "stl"),
+ ("s_ome_table", "sot"),
+ ("sometable", "s"),
+ ],
+)
+def test_generate_alias_uses_first_char_and_every_preceded_by_underscore(
+ table_name, alias
+):
+ assert pgcompleter.generate_alias(table_name) == alias
+
+
+@pytest.mark.parametrize(
+ "table_name, alias_map, alias",
+ [
+ ("some_table", {"some_table": "my_alias"}, "my_alias"),
+ ],
+)
+def test_generate_alias_can_use_alias_map(table_name, alias_map, alias):
+ assert pgcompleter.generate_alias(table_name, alias_map) == alias
+
+
+@pytest.mark.parametrize(
+ "table_name, alias_map, alias",
+ [
+ ("SomeTable", {"SomeTable": "my_alias"}, "my_alias"),
+ ],
+)
+def test_generate_alias_prefers_alias_over_upper_case_name(
+ table_name, alias_map, alias
+):
+ assert pgcompleter.generate_alias(table_name, alias_map) == alias
+
+
+@pytest.mark.parametrize(
+ "table_name, alias",
+ [
+ ("Some_tablE", "SE"),
+ ("SomeTab_le", "ST"),
+ ],
+)
+def test_generate_alias_prefers_upper_case_name_over_underscore_name(table_name, alias):
+ assert pgcompleter.generate_alias(table_name) == alias
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
new file mode 100644
index 0000000..636795b
--- /dev/null
+++ b/tests/test_pgexecute.py
@@ -0,0 +1,773 @@
+from textwrap import dedent
+
+import psycopg
+import pytest
+from unittest.mock import patch, MagicMock
+from pgspecial.main import PGSpecial, NO_QUERY
+from utils import run, dbtest, requires_json, requires_jsonb
+
+from pgcli.main import PGCli
+from pgcli.packages.parseutils.meta import FunctionMetadata
+
+
+def function_meta_data(
+ func_name,
+ schema_name="public",
+ arg_names=None,
+ arg_types=None,
+ arg_modes=None,
+ return_type=None,
+ is_aggregate=False,
+ is_window=False,
+ is_set_returning=False,
+ is_extension=False,
+ arg_defaults=None,
+):
+ return FunctionMetadata(
+ schema_name,
+ func_name,
+ arg_names,
+ arg_types,
+ arg_modes,
+ return_type,
+ is_aggregate,
+ is_window,
+ is_set_returning,
+ is_extension,
+ arg_defaults,
+ )
+
+
+@dbtest
+def test_conn(executor):
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc')""")
+ assert run(executor, """select * from test""", join=True) == dedent(
+ """\
+ +-----+
+ | a |
+ |-----|
+ | abc |
+ +-----+
+ SELECT 1"""
+ )
+
+
+@dbtest
+def test_copy(executor):
+ executor_copy = executor.copy()
+ run(executor_copy, """create table test(a text)""")
+ run(executor_copy, """insert into test values('abc')""")
+ assert run(executor_copy, """select * from test""", join=True) == dedent(
+ """\
+ +-----+
+ | a |
+ |-----|
+ | abc |
+ +-----+
+ SELECT 1"""
+ )
+
+
+@dbtest
+def test_bools_are_treated_as_strings(executor):
+ run(executor, """create table test(a boolean)""")
+ run(executor, """insert into test values(True)""")
+ assert run(executor, """select * from test""", join=True) == dedent(
+ """\
+ +------+
+ | a |
+ |------|
+ | True |
+ +------+
+ SELECT 1"""
+ )
+
+
+@dbtest
+def test_expanded_slash_G(executor, pgspecial):
+ # Tests whether we reset the expanded output after a \G.
+ run(executor, """create table test(a boolean)""")
+ run(executor, """insert into test values(True)""")
+ results = run(executor, r"""select * from test \G""", pgspecial=pgspecial)
+ assert pgspecial.expanded_output == False
+
+
+@dbtest
+def test_schemata_table_views_and_columns_query(executor):
+ run(executor, "create table a(x text, y text)")
+ run(executor, "create table b(z text)")
+ run(executor, "create view d as select 1 as e")
+ run(executor, "create schema schema1")
+ run(executor, "create table schema1.c (w text DEFAULT 'meow')")
+ run(executor, "create schema schema2")
+
+ # schemata
+ # don't enforce all members of the schemas since they may include postgres
+ # temporary schemas
+ assert set(executor.schemata()) >= {
+ "public",
+ "pg_catalog",
+ "information_schema",
+ "schema1",
+ "schema2",
+ }
+ assert executor.search_path() == ["pg_catalog", "public"]
+
+ # tables
+ assert set(executor.tables()) >= {
+ ("public", "a"),
+ ("public", "b"),
+ ("schema1", "c"),
+ }
+
+ assert set(executor.table_columns()) >= {
+ ("public", "a", "x", "text", False, None),
+ ("public", "a", "y", "text", False, None),
+ ("public", "b", "z", "text", False, None),
+ ("schema1", "c", "w", "text", True, "'meow'::text"),
+ }
+
+ # views
+ assert set(executor.views()) >= {("public", "d")}
+
+ assert set(executor.view_columns()) >= {
+ ("public", "d", "e", "integer", False, None)
+ }
+
+
+@dbtest
+def test_foreign_key_query(executor):
+ run(executor, "create schema schema1")
+ run(executor, "create schema schema2")
+ run(executor, "create table schema1.parent(parentid int PRIMARY KEY)")
+ run(
+ executor,
+ "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
+ )
+
+ assert set(executor.foreignkeys()) >= {
+ ("schema1", "parent", "parentid", "schema2", "child", "motherid")
+ }
+
+
+@dbtest
+def test_functions_query(executor):
+ run(
+ executor,
+ """create function func1() returns int
+ language sql as $$select 1$$""",
+ )
+ run(executor, "create schema schema1")
+ run(
+ executor,
+ """create function schema1.func2() returns int
+ language sql as $$select 2$$""",
+ )
+
+ run(
+ executor,
+ """create function func3()
+ returns table(x int, y int) language sql
+ as $$select 1, 2 from generate_series(1,5)$$;""",
+ )
+
+ run(
+ executor,
+ """create function func4(x int) returns setof int language sql
+ as $$select generate_series(1,5)$$;""",
+ )
+
+ funcs = set(executor.functions())
+ assert funcs >= {
+ function_meta_data(func_name="func1", return_type="integer"),
+ function_meta_data(
+ func_name="func3",
+ arg_names=["x", "y"],
+ arg_types=["integer", "integer"],
+ arg_modes=["t", "t"],
+ return_type="record",
+ is_set_returning=True,
+ ),
+ function_meta_data(
+ schema_name="public",
+ func_name="func4",
+ arg_names=("x",),
+ arg_types=("integer",),
+ return_type="integer",
+ is_set_returning=True,
+ ),
+ function_meta_data(
+ schema_name="schema1", func_name="func2", return_type="integer"
+ ),
+ }
+
+
+@dbtest
+def test_datatypes_query(executor):
+ run(executor, "create type foo AS (a int, b text)")
+
+ types = list(executor.datatypes())
+ assert types == [("public", "foo")]
+
+
+@dbtest
+def test_database_list(executor):
+ databases = executor.databases()
+ assert "_test_db" in databases
+
+
+@dbtest
+def test_invalid_syntax(executor, exception_formatter):
+ result = run(executor, "invalid syntax!", exception_formatter=exception_formatter)
+ assert 'syntax error at or near "invalid"' in result[0]
+
+
+@dbtest
+def test_invalid_column_name(executor, exception_formatter):
+ result = run(
+ executor, "select invalid command", exception_formatter=exception_formatter
+ )
+ assert 'column "invalid" does not exist' in result[0]
+
+
+@pytest.fixture(params=[True, False])
+def expanded(request):
+ return request.param
+
+
+@dbtest
+def test_unicode_support_in_output(executor, expanded):
+ run(executor, "create table unicodechars(t text)")
+ run(executor, "insert into unicodechars (t) values ('é')")
+
+ # See issue #24, this raises an exception without proper handling
+ assert "é" in run(
+ executor, "select * from unicodechars", join=True, expanded=expanded
+ )
+
+
+@dbtest
+def test_not_is_special(executor, pgspecial):
+ """is_special is set to false for database queries."""
+ query = "select 1"
+ result = list(executor.run(query, pgspecial=pgspecial))
+ success, is_special = result[0][5:]
+ assert success == True
+ assert is_special == False
+
+
+@dbtest
+def test_execute_from_file_no_arg(executor, pgspecial):
+ r"""\i without a filename returns an error."""
+ result = list(executor.run(r"\i", pgspecial=pgspecial))
+ status, sql, success, is_special = result[0][3:]
+ assert "missing required argument" in status
+ assert success == False
+ assert is_special == True
+
+
+@dbtest
+@patch("pgcli.main.os")
+def test_execute_from_file_io_error(os, executor, pgspecial):
+ r"""\i with an os_error returns an error."""
+ # Inject an OSError.
+ os.path.expanduser.side_effect = OSError("test")
+
+ # Check the result.
+ result = list(executor.run(r"\i test", pgspecial=pgspecial))
+ status, sql, success, is_special = result[0][3:]
+ assert status == "test"
+ assert success == False
+ assert is_special == True
+
+
+@dbtest
+def test_execute_from_commented_file_that_executes_another_file(
+ executor, pgspecial, tmpdir
+):
+ # https://github.com/dbcli/pgcli/issues/1336
+ sqlfile1 = tmpdir.join("test01.sql")
+ sqlfile1.write("-- asdf \n\\h")
+ sqlfile2 = tmpdir.join("test00.sql")
+ sqlfile2.write("--An useless comment;\nselect now();\n-- another useless comment")
+
+ rcfile = str(tmpdir.join("rcfile"))
+ print(rcfile)
+ cli = PGCli(pgexecute=executor, pgclirc_file=rcfile)
+ assert cli != None
+ statement = "--comment\n\\h"
+ result = run(executor, statement, pgspecial=cli.pgspecial)
+ assert result != None
+ assert result[0].find("ALTER TABLE")
+
+
+@dbtest
+def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir):
+ # just some base cases that should work also
+ statement = "--comment\nselect now();"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("now") >= 0
+
+ statement = "/*comment*/\nselect now();"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("now") >= 0
+
+ # https://github.com/dbcli/pgcli/issues/1362
+ statement = "--comment\n\\h"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = "--comment1\n--comment2\n\\h"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = "/*comment*/\n\h;"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = """/*comment1
+ comment2*/
+ \h"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = """/*comment1
+ comment2*/
+ /*comment 3
+ comment4*/
+ \\h"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = " /*comment*/\n\h;"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = "/*comment\ncomment line2*/\n\h;"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = " /*comment\ncomment line2*/\n\h;"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = """\\h /*comment4 */"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ print(result)
+ assert result != None
+ assert result[0].find("No help") >= 0
+
+ # TODO: we probably don't want to do this but sqlparse is not parsing things well
+ # we relly want it to find help but right now, sqlparse isn't dropping the /*comment*/
+ # style comments after command
+
+ statement = """/*comment1*/
+ \h
+ /*comment4 */"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[0].find("No help") >= 0
+
+ # TODO: same for this one
+ statement = """/*comment1
+ comment3
+ comment2*/
+ \\h
+ /*comment4
+ comment5
+ comment6*/"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[0].find("No help") >= 0
+
+
+@dbtest
+def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
+ # https://github.com/dbcli/pgcli/issues/1403
+
+ # just some base cases that should work also
+ statement = "--comment\nselect now();"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("now") >= 0
+
+ statement = "/*comment*/\nselect now();"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("now") >= 0
+
+ # this simulates the original error (1403) without having to add/drop tables
+ # since it was just an error on reading input files and not the actual
+ # command itself
+
+ # test that the statement works
+ statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # test the statement with a \n in the middle
+ statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # test the statement with a newline in the middle
+ statement = """VALUES (1, 'one'),
+ (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # now add a single comment line
+ statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # doing without special char \n
+ statement = """--comment
+ VALUES (1,'one'),
+ (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # two comment lines
+ statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # doing without special char \n
+ statement = """--comment
+ --comment2
+ VALUES (1,'one'), (2, 'two'), (3, 'three');
+ """
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # multiline comment + newline in middle of the statement
+ statement = """/*comment
+comment2
+comment3*/
+VALUES (1,'one'),
+(2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # multiline comment + newline in middle of the statement
+ # + comments after the statement
+ statement = """/*comment
+comment2
+comment3*/
+VALUES (1,'one'),
+(2, 'two'), (3, 'three');
+--comment4
+--comment5"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+
+@dbtest
+def test_multiple_queries_same_line(executor):
+ result = run(executor, "select 'foo'; select 'bar'")
+ assert len(result) == 12 # 2 * (output+status) * 3 lines
+ assert "foo" in result[3]
+ assert "bar" in result[9]
+
+
+@dbtest
+def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
+ result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial)
+ assert len(result) == 11 # 2 * (output+status) * 3 lines
+ assert "foo" in result[3]
+ # This is a lame check. :(
+ assert "Schema" in result[7]
+
+
+@dbtest
+def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter):
+ result = run(
+ executor,
+ "select 'fooé'; invalid syntax é",
+ exception_formatter=exception_formatter,
+ )
+ assert "fooé" in result[3]
+ assert 'syntax error at or near "invalid"' in result[-1]
+
+
+@pytest.fixture
+def pgspecial():
+ return PGCli().pgspecial
+
+
+@dbtest
+def test_special_command_help(executor, pgspecial):
+ result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|")
+ assert "Command" in result[1]
+ assert "Description" in result[2]
+
+
+@dbtest
+def test_bytea_field_support_in_output(executor):
+ run(executor, "create table binarydata(c bytea)")
+ run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))")
+
+ assert "\\xdeadbeef" in run(executor, "select * from binarydata", join=True)
+
+
+@dbtest
+def test_unicode_support_in_unknown_type(executor):
+ assert "日本語" in run(executor, "SELECT '日本語' AS japanese;", join=True)
+
+
+@dbtest
+def test_unicode_support_in_enum_type(executor):
+ run(executor, "CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')")
+ run(executor, "CREATE TABLE person (name TEXT, current_mood mood)")
+ run(executor, "INSERT INTO person VALUES ('Moe', '日本語')")
+ assert "日本語" in run(executor, "SELECT * FROM person", join=True)
+
+
+@requires_json
+def test_json_renders_without_u_prefix(executor, expanded):
+ run(executor, "create table jsontest(d json)")
+ run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""")
+ result = run(
+ executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded
+ )
+
+ assert '{"name": "Éowyn"}' in result
+
+
+@requires_jsonb
+def test_jsonb_renders_without_u_prefix(executor, expanded):
+ run(executor, "create table jsonbtest(d jsonb)")
+ run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""")
+ result = run(
+ executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded
+ )
+
+ assert '{"name": "Éowyn"}' in result
+
+
+@dbtest
+def test_date_time_types(executor):
+ run(executor, "SET TIME ZONE UTC")
+ assert (
+ run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3]
+ == "| 00:00:00 |"
+ )
+ assert (
+ run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split(
+ "\n"
+ )[3]
+ == "| 00:00:00+14:59 |"
+ )
+ assert (
+ run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[
+ 3
+ ]
+ == "| 4713-01-01 BC |"
+ )
+ assert (
+ run(
+ executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True
+ ).split("\n")[3]
+ == "| 4713-01-01 00:00:00 BC |"
+ )
+ assert (
+ run(
+ executor,
+ "SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))",
+ join=True,
+ ).split("\n")[3]
+ == "| 4713-01-01 00:00:00+00 BC |"
+ )
+ assert (
+ run(
+ executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True
+ ).split("\n")[3]
+ == "| -123456789 days, 12:23:56 |"
+ )
+
+
+@dbtest
+@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
+def test_large_numbers_render_directly(executor, value):
+ run(executor, "create table numbertest(a numeric)")
+ run(executor, f"insert into numbertest (a) values ({value})")
+ assert value in run(executor, "select * from numbertest", join=True)
+
+
+@dbtest
+@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"])
+@pytest.mark.parametrize("verbose", ["", "+"])
+@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"])
+def test_describe_special(executor, command, verbose, pattern, pgspecial):
+ # We don't have any tests for the output of any of the special commands,
+ # but we can at least make sure they run without error
+ sql = r"\{command}{verbose} {pattern}".format(**locals())
+ list(executor.run(sql, pgspecial=pgspecial))
+
+
+@dbtest
+@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"])
+def test_raises_with_no_formatter(executor, sql):
+ with pytest.raises(psycopg.ProgrammingError):
+ list(executor.run(sql))
+
+
+@dbtest
+def test_on_error_resume(executor, exception_formatter):
+ sql = "select 1; error; select 1;"
+ result = list(
+ executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter)
+ )
+ assert len(result) == 3
+
+
+@dbtest
+def test_on_error_stop(executor, exception_formatter):
+ sql = "select 1; error; select 1;"
+ result = list(
+ executor.run(
+ sql, on_error_resume=False, exception_formatter=exception_formatter
+ )
+ )
+ assert len(result) == 2
+
+
+# @dbtest
+# def test_unicode_notices(executor):
+# sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;"
+# result = list(executor.run(sql))
+# assert result[0][0] == u'NOTICE: 有人更改\n'
+
+
+@dbtest
+def test_nonexistent_function_definition(executor):
+ with pytest.raises(RuntimeError):
+ result = executor.view_definition("there_is_no_such_function")
+
+
+@dbtest
+def test_function_definition(executor):
+ run(
+ executor,
+ """
+ CREATE OR REPLACE FUNCTION public.the_number_three()
+ RETURNS int
+ LANGUAGE sql
+ AS $function$
+ select 3;
+ $function$
+ """,
+ )
+ result = executor.function_definition("the_number_three")
+
+
+@dbtest
+def test_view_definition(executor):
+ run(executor, "create table tbl1 (a text, b numeric)")
+ run(executor, "create view vw1 AS SELECT * FROM tbl1")
+ run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1")
+ result = executor.view_definition("vw1")
+ assert 'VIEW "public"."vw1" AS' in result
+ assert "FROM tbl1" in result
+ # import pytest; pytest.set_trace()
+ result = executor.view_definition("mvw1")
+ assert "MATERIALIZED VIEW" in result
+
+
+@dbtest
+def test_nonexistent_view_definition(executor):
+ with pytest.raises(RuntimeError):
+ result = executor.view_definition("there_is_no_such_view")
+ with pytest.raises(RuntimeError):
+ result = executor.view_definition("mvw1")
+
+
+@dbtest
+def test_short_host(executor):
+ with patch.object(executor, "host", "localhost"):
+ assert executor.short_host == "localhost"
+ with patch.object(executor, "host", "localhost.example.org"):
+ assert executor.short_host == "localhost"
+ with patch.object(
+ executor, "host", "localhost1.example.org,localhost2.example.org"
+ ):
+ assert executor.short_host == "localhost1"
+
+
+class VirtualCursor:
+ """Mock a cursor to virtual database like pgbouncer."""
+
+ def __init__(self):
+ self.protocol_error = False
+ self.protocol_message = ""
+ self.description = None
+ self.status = None
+ self.statusmessage = "Error"
+
+ def execute(self, *args, **kwargs):
+ self.protocol_error = True
+ self.protocol_message = "Command not supported"
+
+
+@dbtest
+def test_exit_without_active_connection(executor):
+ quit_handler = MagicMock()
+ pgspecial = PGSpecial()
+ pgspecial.register(
+ quit_handler,
+ "\\q",
+ "\\q",
+ "Quit pgcli.",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+ aliases=(":q",),
+ )
+
+ with patch.object(
+ executor.conn, "cursor", side_effect=psycopg.InterfaceError("I'm broken!")
+ ):
+ # we should be able to quit the app, even without active connection
+ run(executor, "\\q", pgspecial=pgspecial)
+ quit_handler.assert_called_once()
+
+ # an exception should be raised when running a query without active connection
+ with pytest.raises(psycopg.InterfaceError):
+ run(executor, "select 1", pgspecial=pgspecial)
+
+
+@dbtest
+def test_virtual_database(executor):
+ virtual_connection = MagicMock()
+ virtual_connection.cursor.return_value = VirtualCursor()
+ with patch.object(executor, "conn", virtual_connection):
+ result = run(executor, "select 1")
+ assert "Command not supported" in result
diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py
new file mode 100644
index 0000000..cd99e32
--- /dev/null
+++ b/tests/test_pgspecial.py
@@ -0,0 +1,78 @@
+import pytest
+from pgcli.packages.sqlcompletion import (
+ suggest_type,
+ Special,
+ Database,
+ Schema,
+ Table,
+ View,
+ Function,
+ Datatype,
+)
+
+
+def test_slash_suggests_special():
+ suggestions = suggest_type("\\", "\\")
+ assert set(suggestions) == {Special()}
+
+
+def test_slash_d_suggests_special():
+ suggestions = suggest_type("\\d", "\\d")
+ assert set(suggestions) == {Special()}
+
+
+def test_dn_suggests_schemata():
+ suggestions = suggest_type("\\dn ", "\\dn ")
+ assert suggestions == (Schema(),)
+
+ suggestions = suggest_type("\\dn xxx", "\\dn xxx")
+ assert suggestions == (Schema(),)
+
+
+def test_d_suggests_tables_views_and_schemas():
+ suggestions = suggest_type(r"\d ", r"\d ")
+ assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
+
+ suggestions = suggest_type(r"\d xxx", r"\d xxx")
+ assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
+
+
+def test_d_dot_suggests_schema_qualified_tables_or_views():
+ suggestions = suggest_type(r"\d myschema.", r"\d myschema.")
+ assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
+
+ suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx")
+ assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
+
+
+def test_df_suggests_schema_or_function():
+ suggestions = suggest_type("\\df xxx", "\\df xxx")
+ assert set(suggestions) == {Function(schema=None, usage="special"), Schema()}
+
+ suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
+ assert suggestions == (Function(schema="myschema", usage="special"),)
+
+
+def test_leading_whitespace_ok():
+ cmd = "\\dn "
+ whitespace = " "
+ suggestions = suggest_type(whitespace + cmd, whitespace + cmd)
+ assert suggestions == suggest_type(cmd, cmd)
+
+
+def test_dT_suggests_schema_or_datatypes():
+ text = "\\dT "
+ suggestions = suggest_type(text, text)
+ assert set(suggestions) == {Schema(), Datatype(schema=None)}
+
+
+def test_schema_qualified_dT_suggests_datatypes():
+ text = "\\dT foo."
+ suggestions = suggest_type(text, text)
+ assert suggestions == (Datatype(schema="foo"),)
+
+
+@pytest.mark.parametrize("command", ["\\c ", "\\connect "])
+def test_c_suggests_databases(command):
+ suggestions = suggest_type(command, command)
+ assert suggestions == (Database(),)
diff --git a/tests/test_prioritization.py b/tests/test_prioritization.py
new file mode 100644
index 0000000..f5b6700
--- /dev/null
+++ b/tests/test_prioritization.py
@@ -0,0 +1,20 @@
+from pgcli.packages.prioritization import PrevalenceCounter
+
+
+def test_prevalence_counter():
+ counter = PrevalenceCounter()
+ sql = """SELECT * FROM foo WHERE bar GROUP BY baz;
+ select * from foo;
+ SELECT * FROM foo WHERE bar GROUP
+ BY baz"""
+ counter.update(sql)
+
+ keywords = ["SELECT", "FROM", "GROUP BY"]
+ expected = [3, 3, 2]
+ kw_counts = [counter.keyword_count(x) for x in keywords]
+ assert kw_counts == expected
+ assert counter.keyword_count("NOSUCHKEYWORD") == 0
+
+ names = ["foo", "bar", "baz"]
+ name_counts = [counter.name_count(x) for x in names]
+ assert name_counts == [3, 2, 2]
diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py
new file mode 100644
index 0000000..91abe37
--- /dev/null
+++ b/tests/test_prompt_utils.py
@@ -0,0 +1,17 @@
+import click
+
+from pgcli.packages.prompt_utils import confirm_destructive_query
+
+
+def test_confirm_destructive_query_notty():
+ stdin = click.get_text_stream("stdin")
+ if not stdin.isatty():
+ sql = "drop database foo;"
+ assert confirm_destructive_query(sql, [], None) is None
+
+
+def test_confirm_destructive_query_with_alias():
+ stdin = click.get_text_stream("stdin")
+ if not stdin.isatty():
+ sql = "drop database foo;"
+ assert confirm_destructive_query(sql, ["drop"], "test") is None
diff --git a/tests/test_rowlimit.py b/tests/test_rowlimit.py
new file mode 100644
index 0000000..da916b4
--- /dev/null
+++ b/tests/test_rowlimit.py
@@ -0,0 +1,79 @@
+import pytest
+from unittest.mock import Mock
+
+from pgcli.main import PGCli
+
+
+# We need this fixtures because we need PGCli object to be created
+# after test collection so it has config loaded from temp directory
+
+
+@pytest.fixture(scope="module")
+def default_pgcli_obj():
+ return PGCli()
+
+
+@pytest.fixture(scope="module")
+def DEFAULT(default_pgcli_obj):
+ return default_pgcli_obj.row_limit
+
+
+@pytest.fixture(scope="module")
+def LIMIT(DEFAULT):
+ return DEFAULT + 1000
+
+
+@pytest.fixture(scope="module")
+def over_default(DEFAULT):
+ over_default_cursor = Mock()
+ over_default_cursor.configure_mock(rowcount=DEFAULT + 10)
+ return over_default_cursor
+
+
+@pytest.fixture(scope="module")
+def over_limit(LIMIT):
+ over_limit_cursor = Mock()
+ over_limit_cursor.configure_mock(rowcount=LIMIT + 10)
+ return over_limit_cursor
+
+
+@pytest.fixture(scope="module")
+def low_count():
+ low_count_cursor = Mock()
+ low_count_cursor.configure_mock(rowcount=1)
+ return low_count_cursor
+
+
+def test_row_limit_with_LIMIT_clause(LIMIT, over_limit):
+ cli = PGCli(row_limit=LIMIT)
+ stmt = "SELECT * FROM students LIMIT 1000"
+
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
+
+ cli = PGCli(row_limit=0)
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
+
+
+def test_row_limit_without_LIMIT_clause(LIMIT, over_limit):
+ cli = PGCli(row_limit=LIMIT)
+ stmt = "SELECT * FROM students"
+
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is True
+
+ cli = PGCli(row_limit=0)
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
+
+
+def test_row_limit_on_non_select(over_limit):
+ cli = PGCli()
+ stmt = "UPDATE students SET name='Boby'"
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
+
+ cli = PGCli(row_limit=0)
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
new file mode 100644
index 0000000..5c9c9af
--- /dev/null
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -0,0 +1,757 @@
+import itertools
+from metadata import (
+ MetaData,
+ alias,
+ name_join,
+ fk_join,
+ join,
+ schema,
+ table,
+ function,
+ wildcard_expansion,
+ column,
+ get_result,
+ result_set,
+ qual,
+ no_qual,
+ parametrize,
+)
+from utils import completions_to_set
+
+metadata = {
+ "tables": {
+ "public": {
+ "users": ["id", "email", "first_name", "last_name"],
+ "orders": ["id", "ordered_date", "status", "datestamp"],
+ "select": ["id", "localtime", "ABC"],
+ },
+ "custom": {
+ "users": ["id", "phone_number"],
+ "Users": ["userid", "username"],
+ "products": ["id", "product_name", "price"],
+ "shipments": ["id", "address", "user_id"],
+ },
+ "Custom": {"projects": ["projectid", "name"]},
+ "blog": {
+ "entries": ["entryid", "entrytitle", "entrytext"],
+ "tags": ["tagid", "name"],
+ "entrytags": ["entryid", "tagid"],
+ "entacclog": ["entryid", "username", "datestamp"],
+ },
+ },
+ "functions": {
+ "public": [
+ ["func1", [], [], [], "", False, False, False, False],
+ ["func2", [], [], [], "", False, False, False, False],
+ ],
+ "custom": [
+ ["func3", [], [], [], "", False, False, False, False],
+ [
+ "set_returning_func",
+ ["x"],
+ ["integer"],
+ ["o"],
+ "integer",
+ False,
+ False,
+ True,
+ False,
+ ],
+ ],
+ "Custom": [["func4", [], [], [], "", False, False, False, False]],
+ "blog": [
+ [
+ "extract_entry_symbols",
+ ["_entryid", "symbol"],
+ ["integer", "text"],
+ ["i", "o"],
+ "",
+ False,
+ False,
+ True,
+ False,
+ ],
+ [
+ "enter_entry",
+ ["_title", "_text", "entryid"],
+ ["text", "text", "integer"],
+ ["i", "i", "o"],
+ "",
+ False,
+ False,
+ False,
+ False,
+ ],
+ ],
+ },
+ "datatypes": {"public": ["typ1", "typ2"], "custom": ["typ3", "typ4"]},
+ "foreignkeys": {
+ "custom": [("public", "users", "id", "custom", "shipments", "user_id")],
+ "blog": [
+ ("blog", "entries", "entryid", "blog", "entacclog", "entryid"),
+ ("blog", "entries", "entryid", "blog", "entrytags", "entryid"),
+ ("blog", "tags", "tagid", "blog", "entrytags", "tagid"),
+ ],
+ },
+ "defaults": {
+ "public": {
+ ("orders", "id"): "nextval('orders_id_seq'::regclass)",
+ ("orders", "datestamp"): "now()",
+ ("orders", "status"): "'PENDING'::text",
+ }
+ },
+}
+
+testdata = MetaData(metadata)
+cased_schemas = [schema(x) for x in ("public", "blog", "CUSTOM", '"Custom"')]
+casing = (
+ "SELECT",
+ "Orders",
+ "User_Emails",
+ "CUSTOM",
+ "Func1",
+ "Entries",
+ "Tags",
+ "EntryTags",
+ "EntAccLog",
+ "EntryID",
+ "EntryTitle",
+ "EntryText",
+)
+completers = testdata.get_completers(casing)
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+@parametrize("table", ["users", '"users"'])
+def test_suggested_column_names_from_shadowed_visible_table(completer, table):
+ result = get_result(completer, "SELECT FROM " + table, len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+@parametrize(
+ "text",
+ [
+ "SELECT from custom.users",
+ "WITH users as (SELECT 1 AS foo) SELECT from custom.users",
+ ],
+)
+def test_suggested_column_names_from_qualified_shadowed_table(completer, text):
+ result = get_result(completer, text, position=text.find(" ") + 1)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"])
+def test_suggested_column_names_from_cte(completer, text):
+ result = completions_to_set(get_result(completer, text, text.find(" ") + 1))
+ assert result == completions_to_set(
+ [column("foo")] + testdata.functions_and_keywords()
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users JOIN custom.shipments ON ",
+ """SELECT *
+ FROM public.users
+ JOIN custom.shipments ON """,
+ ],
+)
+def test_suggested_join_conditions(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ alias("users"),
+ alias("shipments"),
+ name_join("shipments.id = users.id"),
+ fk_join("shipments.user_id = users.id"),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+@parametrize(
+ ("query", "tbl"),
+ itertools.product(
+ (
+ "SELECT * FROM public.{0} RIGHT OUTER JOIN ",
+ """SELECT *
+ FROM {0}
+ JOIN """,
+ ),
+ ("users", '"users"', "Users"),
+ ),
+)
+def test_suggested_joins(completer, query, tbl):
+ result = get_result(completer, query.format(tbl))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+def test_suggested_column_names_from_schema_qualifed_table(completer):
+ result = get_result(completer, "SELECT from custom.products", len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("products", "custom")
+ )
+
+
+@parametrize(
+ "text",
+ [
+ "INSERT INTO orders(",
+ "INSERT INTO orders (",
+ "INSERT INTO public.orders(",
+ "INSERT INTO public.orders (",
+ ],
+)
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_suggested_columns_with_insert(completer, text):
+ assert completions_to_set(get_result(completer, text)) == completions_to_set(
+ testdata.columns("orders")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+def test_suggested_column_names_in_function(completer):
+ result = get_result(
+ completer, "SELECT MAX( from custom.products", len("SELECT MAX(")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("products", "custom")
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize(
+ "text",
+ ["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'],
+)
+@parametrize("use_leading_double_quote", [False, True])
+def test_suggested_table_names_with_schema_dot(
+ completer, text, use_leading_double_quote
+):
+ if use_leading_double_quote:
+ text += '"'
+ start_position = -1
+ else:
+ start_position = 0
+
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.from_clause_items("custom", start_position)
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize("text", ['SELECT * FROM "Custom".'])
+@parametrize("use_leading_double_quote", [False, True])
+def test_suggested_table_names_with_schema_dot2(
+ completer, text, use_leading_double_quote
+):
+ if use_leading_double_quote:
+ text += '"'
+ start_position = -1
+ else:
+ start_position = 0
+
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.from_clause_items("Custom", start_position)
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_suggested_column_names_with_qualified_alias(completer):
+ result = get_result(completer, "SELECT p. from custom.products p", len("SELECT p."))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("products", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+def test_suggested_multiple_column_names(completer):
+ result = get_result(
+ completer, "SELECT id, from custom.products", len("SELECT id, ")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("products", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_suggested_multiple_column_names_with_alias(completer):
+ result = get_result(
+ completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("products", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ",
+ "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id",
+ ],
+)
+def test_suggestions_after_on(completer, text):
+ position = len(
+ "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON "
+ )
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ alias("x"),
+ alias("y"),
+ name_join("y.price = x.price"),
+ name_join("y.product_name = x.product_name"),
+ name_join("y.id = x.id"),
+ ]
+ )
+
+
+@parametrize("completer", completers())
+def test_suggested_aliases_after_on_right_side(completer):
+ text = "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = "
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set([alias("x"), alias("y")])
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_table_names_after_from(completer):
+ text = "SELECT * FROM "
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_schema_qualified_function_name(completer):
+ text = "SELECT custom.func"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("func3()", -len("func")),
+ function("set_returning_func()", -len("func")),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_schema_qualified_function_name_after_from(completer):
+ text = "SELECT * FROM custom.set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("set_returning_func()", -len("func")),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_unqualified_function_name_not_returned(completer):
+ text = "SELECT * FROM set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set([])
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_unqualified_function_name_in_search_path(completer):
+ completer.search_path = ["public", "custom"]
+ text = "SELECT * FROM set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("set_returning_func()", -len("func")),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT 1::custom.",
+ "CREATE TABLE foo (bar custom.",
+ "CREATE FUNCTION foo (bar INT, baz custom.",
+ "ALTER TABLE foo ALTER COLUMN bar TYPE custom.",
+ ],
+)
+def test_schema_qualified_type_name(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(testdata.types("custom"))
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_suggest_columns_from_aliased_set_returning_function(completer):
+ result = get_result(
+ completer, "select f. from custom.set_returning_func() f", len("select f.")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("set_returning_func", "custom", "functions")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM custom.set_returning_func()",
+ "SELECT * FROM Custom.set_returning_func()",
+ "SELECT * FROM Custom.Set_Returning_Func()",
+ ],
+)
+def test_wildcard_column_expansion_with_function(completer, text):
+ position = len("SELECT *")
+
+ completions = get_result(completer, text, position)
+
+ col_list = "x"
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_wildcard_column_expansion_with_alias_qualifier(completer):
+ text = "SELECT p.* FROM custom.products p"
+ position = len("SELECT p.*")
+
+ completions = get_result(completer, text, position)
+
+ col_list = "id, p.product_name, p.price"
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text",
+ [
+ """
+ SELECT count(1) FROM users;
+ CREATE FUNCTION foo(custom.products _products) returns custom.shipments
+ LANGUAGE SQL
+ AS $foo$
+ SELECT 1 FROM custom.shipments;
+ INSERT INTO public.orders(*) values(-1, now(), 'preliminary');
+ SELECT 2 FROM custom.users;
+ $foo$;
+ SELECT count(1) FROM custom.shipments;
+ """,
+ "INSERT INTO public.orders(*",
+ "INSERT INTO public.Orders(*",
+ "INSERT INTO public.orders (*",
+ "INSERT INTO public.Orders (*",
+ "INSERT INTO orders(*",
+ "INSERT INTO Orders(*",
+ "INSERT INTO orders (*",
+ "INSERT INTO Orders (*",
+ "INSERT INTO public.orders(*)",
+ "INSERT INTO public.Orders(*)",
+ "INSERT INTO public.orders (*)",
+ "INSERT INTO public.Orders (*)",
+ "INSERT INTO orders(*)",
+ "INSERT INTO Orders(*)",
+ "INSERT INTO orders (*)",
+ "INSERT INTO Orders (*)",
+ ],
+)
+def test_wildcard_column_expansion_with_insert(completer, text):
+ position = text.index("*") + 1
+ completions = get_result(completer, text, position)
+
+ expected = [wildcard_expansion("ordered_date, status")]
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_wildcard_column_expansion_with_table_qualifier(completer):
+ text = 'SELECT "select".* FROM public."select"'
+ position = len('SELECT "select".*')
+
+ completions = get_result(completer, text, position)
+
+ col_list = 'id, "select"."localtime", "select"."ABC"'
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=qual))
+def test_wildcard_column_expansion_with_two_tables(completer):
+ text = 'SELECT * FROM public."select" JOIN custom.users ON true'
+ position = len("SELECT *")
+
+ completions = get_result(completer, text, position)
+
+ cols = (
+ '"select".id, "select"."localtime", "select"."ABC", '
+ "users.id, users.phone_number"
+ )
+ expected = [wildcard_expansion(cols)]
+ assert completions == expected
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_wildcard_column_expansion_with_two_tables_and_parent(completer):
+ text = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true'
+ position = len('SELECT "select".*')
+
+ completions = get_result(completer, text, position)
+
+ col_list = 'id, "select"."localtime", "select"."ABC"'
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT U. FROM custom.Users U",
+ "SELECT U. FROM custom.USERS U",
+ "SELECT U. FROM custom.users U",
+ 'SELECT U. FROM "custom".Users U',
+ 'SELECT U. FROM "custom".USERS U',
+ 'SELECT U. FROM "custom".users U',
+ ],
+)
+def test_suggest_columns_from_unquoted_table(completer, text):
+ position = len("SELECT U.")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("users", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U']
+)
+def test_suggest_columns_from_quoted_table(completer, text):
+ position = len("SELECT U.")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("Users", "custom")
+ )
+
+
+texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "]
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+@parametrize("text", texts)
+def test_schema_or_visible_table_completion(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ )
+
+
+@parametrize("completer", completers(aliasing=True, casing=False, filtr=True))
+@parametrize("text", texts)
+def test_table_aliases(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas()
+ + [
+ table("users u"),
+ table("orders o" if text == "SELECT * FROM " else "orders o2"),
+ table('"select" s'),
+ function("func1() f"),
+ function("func2() f"),
+ ]
+ )
+
+
+@parametrize("completer", completers(aliasing=True, casing=True, filtr=True))
+@parametrize("text", texts)
+def test_aliases_with_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ cased_schemas
+ + [
+ table("users u"),
+ table("Orders O" if text == "SELECT * FROM " else "Orders O2"),
+ table('"select" s'),
+ function("Func1() F"),
+ function("func2() f"),
+ ]
+ )
+
+
+@parametrize("completer", completers(aliasing=False, casing=True, filtr=True))
+@parametrize("text", texts)
+def test_table_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ cased_schemas
+ + [
+ table("users"),
+ table("Orders"),
+ table('"select"'),
+ function("Func1()"),
+ function("func2()"),
+ ]
+ )
+
+
+@parametrize("completer", completers(aliasing=False, casing=True))
+def test_alias_search_without_aliases2(completer):
+ text = "SELECT * FROM blog.et"
+ result = get_result(completer, text)
+ assert result[0] == table("EntryTags", -2)
+
+
+@parametrize("completer", completers(aliasing=False, casing=True))
+def test_alias_search_without_aliases1(completer):
+ text = "SELECT * FROM blog.e"
+ result = get_result(completer, text)
+ assert result[0] == table("Entries", -1)
+
+
+@parametrize("completer", completers(aliasing=True, casing=True))
+def test_alias_search_with_aliases2(completer):
+ text = "SELECT * FROM blog.et"
+ result = get_result(completer, text)
+ assert result[0] == table("EntryTags ET", -2)
+
+
+@parametrize("completer", completers(aliasing=True, casing=True))
+def test_alias_search_with_aliases1(completer):
+ text = "SELECT * FROM blog.e"
+ result = get_result(completer, text)
+ assert result[0] == table("Entries E", -1)
+
+
+@parametrize("completer", completers(aliasing=True, casing=True))
+def test_join_alias_search_with_aliases1(completer):
+ text = "SELECT * FROM blog.Entries E JOIN blog.e"
+ result = get_result(completer, text)
+ assert result[:2] == [
+ table("Entries E2", -1),
+ join("EntAccLog EAL ON EAL.EntryID = E.EntryID", -1),
+ ]
+
+
+@parametrize("completer", completers(aliasing=False, casing=True))
+def test_join_alias_search_without_aliases1(completer):
+ text = "SELECT * FROM blog.Entries JOIN blog.e"
+ result = get_result(completer, text)
+ assert result[:2] == [
+ table("Entries", -1),
+ join("EntAccLog ON EntAccLog.EntryID = Entries.EntryID", -1),
+ ]
+
+
+@parametrize("completer", completers(aliasing=True, casing=True))
+def test_join_alias_search_with_aliases2(completer):
+ text = "SELECT * FROM blog.Entries E JOIN blog.et"
+ result = get_result(completer, text)
+ assert result[0] == join("EntryTags ET ON ET.EntryID = E.EntryID", -2)
+
+
+@parametrize("completer", completers(aliasing=False, casing=True))
+def test_join_alias_search_without_aliases2(completer):
+ text = "SELECT * FROM blog.Entries JOIN blog.et"
+ result = get_result(completer, text)
+ assert result[0] == join("EntryTags ON EntryTags.EntryID = Entries.EntryID", -2)
+
+
+@parametrize("completer", completers())
+def test_function_alias_search_without_aliases(completer):
+ text = "SELECT blog.ees"
+ result = get_result(completer, text)
+ first = result[0]
+ assert first.start_position == -3
+ assert first.text == "extract_entry_symbols()"
+ assert first.display_text == "extract_entry_symbols(_entryid)"
+
+
+@parametrize("completer", completers())
+def test_function_alias_search_with_aliases(completer):
+ text = "SELECT blog.ee"
+ result = get_result(completer, text)
+ first = result[0]
+ assert first.start_position == -2
+ assert first.text == "enter_entry(_title := , _text := )"
+ assert first.display_text == "enter_entry(_title, _text)"
+
+
+@parametrize("completer", completers(filtr=True, casing=True, qualify=no_qual))
+def test_column_alias_search(completer):
+ result = get_result(completer, "SELECT et FROM blog.Entries E", len("SELECT et"))
+ cols = ("EntryText", "EntryTitle", "EntryID")
+ assert result[:3] == [column(c, -2) for c in cols]
+
+
+@parametrize("completer", completers(casing=True))
+def test_column_alias_search_qualified(completer):
+ result = get_result(
+ completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei")
+ )
+ cols = ("EntryID", "EntryTitle")
+ assert result[:3] == [column(c, -2) for c in cols]
+
+
+@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
+def test_schema_object_order(completer):
+ result = get_result(completer, "SELECT * FROM u")
+ assert result[:3] == [
+ table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users")
+ ]
+
+
+@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
+def test_all_schema_objects(completer):
+ text = "SELECT * FROM "
+ result = get_result(completer, text)
+ assert completions_to_set(result) >= completions_to_set(
+ [table(x) for x in ("orders", '"select"', "custom.shipments")]
+ + [function(x + "()") for x in ("func2",)]
+ )
+
+
+@parametrize("completer", completers(filtr=False, aliasing=False, casing=True))
+def test_all_schema_objects_with_casing(completer):
+ text = "SELECT * FROM "
+ result = get_result(completer, text)
+ assert completions_to_set(result) >= completions_to_set(
+ [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")]
+ + [function(x + "()") for x in ("func2",)]
+ )
+
+
+@parametrize("completer", completers(casing=False, filtr=False, aliasing=True))
+def test_all_schema_objects_with_aliases(completer):
+ text = "SELECT * FROM "
+ result = get_result(completer, text)
+ assert completions_to_set(result) >= completions_to_set(
+ [table(x) for x in ("orders o", '"select" s', "custom.shipments s")]
+ + [function(x) for x in ("func2() f",)]
+ )
+
+
+@parametrize("completer", completers(casing=False, filtr=False, aliasing=True))
+def test_set_schema(completer):
+ text = "SET SCHEMA "
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")]
+ )
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
new file mode 100644
index 0000000..db1fe0a
--- /dev/null
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -0,0 +1,1112 @@
+from metadata import (
+ MetaData,
+ alias,
+ name_join,
+ fk_join,
+ join,
+ keyword,
+ schema,
+ table,
+ view,
+ function,
+ column,
+ wildcard_expansion,
+ get_result,
+ result_set,
+ qual,
+ no_qual,
+ parametrize,
+)
+from prompt_toolkit.completion import Completion
+from utils import completions_to_set
+
+
+metadata = {
+ "tables": {
+ "users": ["id", "parentid", "email", "first_name", "last_name"],
+ "Users": ["userid", "username"],
+ "orders": ["id", "ordered_date", "status", "email"],
+ "select": ["id", "insert", "ABC"],
+ },
+ "views": {"user_emails": ["id", "email"], "functions": ["function"]},
+ "functions": [
+ ["custom_fun", [], [], [], "", False, False, False, False],
+ ["_custom_fun", [], [], [], "", False, False, False, False],
+ ["custom_func1", [], [], [], "", False, False, False, False],
+ ["custom_func2", [], [], [], "", False, False, False, False],
+ [
+ "set_returning_func",
+ ["x", "y"],
+ ["integer", "integer"],
+ ["b", "b"],
+ "",
+ False,
+ False,
+ True,
+ False,
+ ],
+ ],
+ "datatypes": ["custom_type1", "custom_type2"],
+ "foreignkeys": [
+ ("public", "users", "id", "public", "users", "parentid"),
+ ("public", "users", "id", "public", "Users", "userid"),
+ ],
+}
+
+metadata = {k: {"public": v} for k, v in metadata.items()}
+
+testdata = MetaData(metadata)
+
+cased_users_col_names = ["ID", "PARENTID", "Email", "First_Name", "last_name"]
+cased_users2_col_names = ["UserID", "UserName"]
+cased_func_names = [
+ "Custom_Fun",
+ "_custom_fun",
+ "Custom_Func1",
+ "custom_func2",
+ "set_returning_func",
+]
+cased_tbls = ["Users", "Orders"]
+cased_views = ["User_Emails", "Functions"]
+casing = (
+ ["SELECT", "PUBLIC"]
+ + cased_func_names
+ + cased_tbls
+ + cased_views
+ + cased_users_col_names
+ + cased_users2_col_names
+)
+# Lists for use in assertions
+cased_funcs = [
+ function(f)
+ for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()")
+] + [function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")]
+cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])]
+cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls
+cased_users_cols = [column(c) for c in cased_users_col_names]
+aliased_rels = (
+ [table(t) for t in ("users u", '"Users" U', "orders o", '"select" s')]
+ + [view("user_emails ue"), view("functions f")]
+ + [
+ function(f)
+ for f in (
+ "_custom_fun() cf",
+ "custom_fun() cf",
+ "custom_func1() cf",
+ "custom_func2() cf",
+ )
+ ]
+ + [
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ )
+ ]
+)
+cased_aliased_rels = (
+ [table(t) for t in ("Users U", '"Users" U', "Orders O", '"select" s')]
+ + [view("User_Emails UE"), view("Functions F")]
+ + [
+ function(f)
+ for f in (
+ "_custom_fun() cf",
+ "Custom_Fun() CF",
+ "Custom_Func1() CF",
+ "custom_func2() cf",
+ )
+ ]
+ + [
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ )
+ ]
+)
+completers = testdata.get_completers(casing)
+
+
+# Just to make sure that this doesn't crash
+@parametrize("completer", completers())
+def test_function_column_name(completer):
+ for l in range(
+ len("SELECT * FROM Functions WHERE function:"),
+ len("SELECT * FROM Functions WHERE function:text") + 1,
+ ):
+ assert [] == get_result(
+ completer, "SELECT * FROM Functions WHERE function:text"[:l]
+ )
+
+
+@parametrize("action", ["ALTER", "DROP", "CREATE", "CREATE OR REPLACE"])
+@parametrize("completer", completers())
+def test_drop_alter_function(completer, action):
+ assert get_result(completer, action + " FUNCTION set_ret") == [
+ function("set_returning_func(x integer, y integer)", -len("set_ret"))
+ ]
+
+
+@parametrize("completer", completers())
+def test_empty_string_completion(completer):
+ result = get_result(completer, "")
+ assert completions_to_set(
+ testdata.keywords() + testdata.specials()
+ ) == completions_to_set(result)
+
+
+@parametrize("completer", completers())
+def test_select_keyword_completion(completer):
+ result = get_result(completer, "SEL")
+ assert completions_to_set(result) == completions_to_set([keyword("SELECT", -3)])
+
+
+@parametrize("completer", completers())
+def test_builtin_function_name_completion(completer):
+ result = get_result(completer, "SELECT MA")
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("MAKE_DATE", -2),
+ function("MAKE_INTERVAL", -2),
+ function("MAKE_TIME", -2),
+ function("MAKE_TIMESTAMP", -2),
+ function("MAKE_TIMESTAMPTZ", -2),
+ function("MASKLEN", -2),
+ function("MAX", -2),
+ keyword("MAXEXTENTS", -2),
+ keyword("MATERIALIZED VIEW", -2),
+ ]
+ )
+
+
+@parametrize("completer", completers())
+def test_builtin_function_matches_only_at_start(completer):
+ text = "SELECT IN"
+
+ result = [c.text for c in get_result(completer, text)]
+
+ assert "MIN" not in result
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+def test_user_function_name_completion(completer):
+ result = get_result(completer, "SELECT cu")
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("custom_fun()", -2),
+ function("_custom_fun()", -2),
+ function("custom_func1()", -2),
+ function("custom_func2()", -2),
+ function("CURRENT_DATE", -2),
+ function("CURRENT_TIMESTAMP", -2),
+ function("CUME_DIST", -2),
+ function("CURRENT_TIME", -2),
+ keyword("CURRENT", -2),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+def test_user_function_name_completion_matches_anywhere(completer):
+ result = get_result(completer, "SELECT om")
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("custom_fun()", -2),
+ function("_custom_fun()", -2),
+ function("custom_func1()", -2),
+ function("custom_func2()", -2),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=True))
+def test_list_functions_for_special(completer):
+ result = get_result(completer, r"\df ")
+ assert completions_to_set(result) == completions_to_set(
+ [schema("PUBLIC")] + [function(f) for f in cased_func_names]
+ )
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggested_column_names_from_visible_table(completer):
+ result = get_result(completer, "SELECT from users", len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users")
+ )
+
+
+@parametrize("completer", completers(casing=True, qualify=no_qual))
+def test_suggested_cased_column_names(completer):
+ result = get_result(completer, "SELECT from users", len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ cased_funcs
+ + cased_users_cols
+ + testdata.builtin_functions()
+ + testdata.keywords()
+ )
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+@parametrize("text", ["SELECT from users", "INSERT INTO Orders SELECT from users"])
+def test_suggested_auto_qualified_column_names(text, completer):
+ position = text.index(" ") + 1
+ cols = [column(c.lower()) for c in cased_users_col_names]
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ cols + testdata.functions_and_keywords()
+ )
+
+
+@parametrize("completer", completers(casing=False, qualify=qual))
+@parametrize(
+ "text",
+ [
+ 'SELECT from users U NATURAL JOIN "Users"',
+ 'INSERT INTO Orders SELECT from users U NATURAL JOIN "Users"',
+ ],
+)
+def test_suggested_auto_qualified_column_names_two_tables(text, completer):
+ position = text.index(" ") + 1
+ cols = [column("U." + c.lower()) for c in cased_users_col_names]
+ cols += [column('"Users".' + c.lower()) for c in cased_users2_col_names]
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ cols + testdata.functions_and_keywords()
+ )
+
+
+@parametrize("completer", completers(casing=True, qualify=["always"]))
+@parametrize("text", ["UPDATE users SET ", "INSERT INTO users("])
+def test_no_column_qualification(text, completer):
+ cols = [column(c) for c in cased_users_col_names]
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(cols)
+
+
+@parametrize("completer", completers(casing=True, qualify=["always"]))
+def test_suggested_cased_always_qualified_column_names(completer):
+ text = "SELECT from users"
+ position = len("SELECT ")
+ cols = [column("users." + c) for c in cased_users_col_names]
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ cased_funcs + cols + testdata.builtin_functions() + testdata.keywords()
+ )
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggested_column_names_in_function(completer):
+ result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX("))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggested_column_names_with_table_dot(completer):
+ result = get_result(completer, "SELECT users. from users", len("SELECT users."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggested_column_names_with_alias(completer):
+ result = get_result(completer, "SELECT u. from users u", len("SELECT u."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggested_multiple_column_names(completer):
+ result = get_result(completer, "SELECT id, from users u", len("SELECT id, "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggested_multiple_column_names_with_alias(completer):
+ result = get_result(
+ completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.")
+ )
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=True))
+def test_suggested_cased_column_names_with_alias(completer):
+ result = get_result(
+ completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.")
+ )
+ assert completions_to_set(result) == completions_to_set(cased_users_cols)
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggested_multiple_column_names_with_dot(completer):
+ result = get_result(
+ completer,
+ "SELECT users.id, users. from users u",
+ len("SELECT users.id, users."),
+ )
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggest_columns_after_three_way_join(completer):
+ text = """SELECT * FROM users u1
+ INNER JOIN users u2 ON u1.id = u2.id
+ INNER JOIN users u3 ON u2.id = u3."""
+ result = get_result(completer, text)
+ assert column("id") in result
+
+
+join_condition_texts = [
+ 'INSERT INTO orders SELECT * FROM users U JOIN "Users" U2 ON ',
+ """INSERT INTO public.orders(orderid)
+ SELECT * FROM users U JOIN "Users" U2 ON """,
+ 'SELECT * FROM users U JOIN "Users" U2 ON ',
+ 'SELECT * FROM users U INNER join "Users" U2 ON ',
+ 'SELECT * FROM USERS U right JOIN "Users" U2 ON ',
+ 'SELECT * FROM users U LEFT JOIN "Users" U2 ON ',
+ 'SELECT * FROM Users U FULL JOIN "Users" U2 ON ',
+ 'SELECT * FROM users U right outer join "Users" U2 ON ',
+ 'SELECT * FROM Users U LEFT OUTER JOIN "Users" U2 ON ',
+ 'SELECT * FROM users U FULL OUTER JOIN "Users" U2 ON ',
+ """SELECT *
+ FROM users U
+ FULL OUTER JOIN "Users" U2 ON
+ """,
+]
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize("text", join_condition_texts)
+def test_suggested_join_conditions(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [alias("U"), alias("U2"), fk_join("U2.userid = U.id")]
+ )
+
+
+@parametrize("completer", completers(casing=True))
+@parametrize("text", join_condition_texts)
+def test_cased_join_conditions(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ """SELECT *
+ FROM users
+ CROSS JOIN "Users"
+ NATURAL JOIN users u
+ JOIN "Users" u2 ON
+ """
+ ],
+)
+def test_suggested_join_conditions_with_same_table_twice(completer, text):
+ result = get_result(completer, text)
+ assert result == [
+ fk_join("u2.userid = u.id"),
+ fk_join("u2.userid = users.id"),
+ name_join('u2.userid = "Users".userid'),
+ name_join('u2.username = "Users".username'),
+ alias("u"),
+ alias("u2"),
+ alias("users"),
+ alias('"Users"'),
+ ]
+
+
+@parametrize("completer", completers())
+@parametrize("text", ["SELECT * FROM users JOIN users u2 on foo."])
+def test_suggested_join_conditions_with_invalid_qualifier(completer, text):
+ result = get_result(completer, text)
+ assert result == []
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ ("text", "ref"),
+ [
+ ("SELECT * FROM users JOIN NonTable on ", "NonTable"),
+ ("SELECT * FROM users JOIN nontable nt on ", "nt"),
+ ],
+)
+def test_suggested_join_conditions_with_invalid_table(completer, text, ref):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [alias("users"), alias(ref)]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize(
+ "text",
+ [
+ 'SELECT * FROM "Users" u JOIN u',
+ 'SELECT * FROM "Users" u JOIN uid',
+ 'SELECT * FROM "Users" u JOIN userid',
+ 'SELECT * FROM "Users" u JOIN id',
+ ],
+)
+def test_suggested_joins_fuzzy(completer, text):
+ result = get_result(completer, text)
+ last_word = text.split()[-1]
+ expected = join("users ON users.id = u.userid", -len(last_word))
+ assert expected in result
+
+
+join_texts = [
+ "SELECT * FROM Users JOIN ",
+ """INSERT INTO "Users"
+ SELECT *
+ FROM Users
+ INNER JOIN """,
+ """INSERT INTO public."Users"(username)
+ SELECT *
+ FROM Users
+ INNER JOIN """,
+ """SELECT *
+ FROM Users
+ INNER JOIN """,
+]
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize("text", join_texts)
+def test_suggested_joins(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ + [
+ join('"Users" ON "Users".userid = Users.id'),
+ join("users users2 ON users2.id = Users.parentid"),
+ join("users users2 ON users2.parentid = Users.id"),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=True, aliasing=False))
+@parametrize("text", join_texts)
+def test_cased_joins(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [schema("PUBLIC")]
+ + cased_rels
+ + [
+ join('"Users" ON "Users".UserID = Users.ID'),
+ join("Users Users2 ON Users2.ID = Users.PARENTID"),
+ join("Users Users2 ON Users2.PARENTID = Users.ID"),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=True))
+@parametrize("text", join_texts)
+def test_aliased_joins(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas()
+ + aliased_rels
+ + [
+ join('"Users" U ON U.userid = Users.id'),
+ join("users u ON u.id = Users.parentid"),
+ join("users u ON u.parentid = Users.id"),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize(
+ "text",
+ [
+ 'SELECT * FROM public."Users" JOIN ',
+ 'SELECT * FROM public."Users" RIGHT OUTER JOIN ',
+ """SELECT *
+ FROM public."Users"
+ LEFT JOIN """,
+ ],
+)
+def test_suggested_joins_quoted_schema_qualified_table(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ + [join('public.users ON users.id = "Users".userid')]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT u.name, o.id FROM users u JOIN orders o ON ",
+ "SELECT u.name, o.id FROM users u JOIN orders o ON JOIN orders o2 ON",
+ ],
+)
+def test_suggested_aliases_after_on(completer, text):
+ position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ alias("u"),
+ name_join("o.id = u.id"),
+ name_join("o.email = u.email"),
+ alias("o"),
+ ]
+ )
+
+
+@parametrize("completer", completers())
+@parametrize(
+ "text",
+ [
+ "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ",
+ "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = JOIN orders o2 ON",
+ ],
+)
+def test_suggested_aliases_after_on_right_side(completer, text):
+ position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set([alias("u"), alias("o")])
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT users.name, orders.id FROM users JOIN orders ON ",
+ "SELECT users.name, orders.id FROM users JOIN orders ON JOIN orders orders2 ON",
+ ],
+)
+def test_suggested_tables_after_on(completer, text):
+ position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ name_join("orders.id = users.id"),
+ name_join("orders.email = users.email"),
+ alias("users"),
+ alias("orders"),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = JOIN orders orders2 ON",
+ "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ",
+ ],
+)
+def test_suggested_tables_after_on_right_side(completer, text):
+ position = len(
+ "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
+ )
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [alias("users"), alias("orders")]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users INNER JOIN orders USING (",
+ "SELECT * FROM users INNER JOIN orders USING(",
+ ],
+)
+def test_join_using_suggests_common_columns(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [column("id"), column("email")]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users u1 JOIN users u2 USING (email) JOIN user_emails ue USING()",
+ "SELECT * FROM users u1 JOIN users u2 USING(email) JOIN user_emails ue USING ()",
+ "SELECT * FROM users u1 JOIN user_emails ue USING () JOIN users u2 ue USING(first_name, last_name)",
+ "SELECT * FROM users u1 JOIN user_emails ue USING() JOIN users u2 ue USING (first_name, last_name)",
+ ],
+)
+def test_join_using_suggests_from_last_table(completer, text):
+ position = text.index("()") + 1
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [column("id"), column("email")]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users INNER JOIN orders USING (id,",
+ "SELECT * FROM users INNER JOIN orders USING(id,",
+ ],
+)
+def test_join_using_suggests_columns_after_first_column(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [column("id"), column("email")]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM ",
+ "SELECT * FROM users CROSS JOIN ",
+ "SELECT * FROM users natural join ",
+ ],
+)
+def test_table_names_after_from(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ )
+ assert [c.text for c in result] == [
+ "public",
+ "orders",
+ '"select"',
+ "users",
+ '"Users"',
+ "functions",
+ "user_emails",
+ "_custom_fun()",
+ "custom_fun()",
+ "custom_func1()",
+ "custom_func2()",
+ "set_returning_func(x := , y := )",
+ ]
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_auto_escaped_col_names(completer):
+ result = get_result(completer, 'SELECT from "select"', len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("select")
+ )
+
+
+@parametrize("completer", completers(aliasing=False))
+def test_allow_leading_double_quote_in_last_word(completer):
+ result = get_result(completer, 'SELECT * from "sele')
+
+ expected = table('"select"', -5)
+
+ assert expected in result
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT 1::",
+ "CREATE TABLE foo (bar ",
+ "CREATE FUNCTION foo (bar INT, baz ",
+ "ALTER TABLE foo ALTER COLUMN bar TYPE ",
+ ],
+)
+def test_suggest_datatype(text, completer):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas() + testdata.types() + testdata.builtin_datatypes()
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggest_columns_from_escaped_table_alias(completer):
+ result = get_result(completer, 'select * from "select" s where s.')
+ assert completions_to_set(result) == completions_to_set(testdata.columns("select"))
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggest_columns_from_set_returning_function(completer):
+ result = get_result(completer, "select from set_returning_func()", len("select "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("set_returning_func", typ="functions")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggest_columns_from_aliased_set_returning_function(completer):
+ result = get_result(
+ completer, "select f. from set_returning_func() f", len("select f.")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("set_returning_func", typ="functions")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_join_functions_using_suggests_common_columns(completer):
+ text = """SELECT * FROM set_returning_func() f1
+ INNER JOIN set_returning_func() f2 USING ("""
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("set_returning_func", typ="functions")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_join_functions_on_suggests_columns_and_join_conditions(completer):
+ text = """SELECT * FROM set_returning_func() f1
+ INNER JOIN set_returning_func() f2 ON f1."""
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [name_join("y = f2.y"), name_join("x = f2.x")]
+ + testdata.columns("set_returning_func", typ="functions")
+ )
+
+
+@parametrize("completer", completers())
+def test_learn_keywords(completer):
+ history = "CREATE VIEW v AS SELECT 1"
+ completer.extend_query_history(history)
+
+ # Now that we've used `VIEW` once, it should be suggested ahead of other
+ # keywords starting with v.
+ text = "create v"
+ completions = get_result(completer, text)
+ assert completions[0].text == "VIEW"
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+def test_learn_table_names(completer):
+ history = "SELECT * FROM users; SELECT * FROM orders; SELECT * FROM users"
+ completer.extend_query_history(history)
+
+ text = "SELECT * FROM "
+ completions = get_result(completer, text)
+
+ # `users` should be higher priority than `orders` (used more often)
+ users = table("users")
+ orders = table("orders")
+
+ assert completions.index(users) < completions.index(orders)
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_columns_before_keywords(completer):
+ text = "SELECT * FROM orders WHERE s"
+ completions = get_result(completer, text)
+
+ col = column("status", -1)
+ kw = keyword("SELECT", -1)
+
+ assert completions.index(col) < completions.index(kw)
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users",
+ "INSERT INTO users SELECT * FROM users u",
+ """INSERT INTO users(id, parentid, email, first_name, last_name)
+ SELECT *
+ FROM users u""",
+ ],
+)
+def test_wildcard_column_expansion(completer, text):
+ position = text.find("*") + 1
+
+ completions = get_result(completer, text, position)
+
+ col_list = "id, parentid, email, first_name, last_name"
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT u.* FROM users u",
+ "INSERT INTO public.users SELECT u.* FROM users u",
+ """INSERT INTO users(id, parentid, email, first_name, last_name)
+ SELECT u.*
+ FROM users u""",
+ ],
+)
+def test_wildcard_column_expansion_with_alias(completer, text):
+ position = text.find("*") + 1
+
+ completions = get_result(completer, text, position)
+
+ col_list = "id, u.parentid, u.email, u.first_name, u.last_name"
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text,expected",
+ [
+ (
+ "SELECT users.* FROM users",
+ "id, users.parentid, users.email, users.first_name, users.last_name",
+ ),
+ (
+ "SELECT Users.* FROM Users",
+ "id, Users.parentid, Users.email, Users.first_name, Users.last_name",
+ ),
+ ],
+)
+def test_wildcard_column_expansion_with_table_qualifier(completer, text, expected):
+ position = len("SELECT users.*")
+
+ completions = get_result(completer, text, position)
+
+ expected = [wildcard_expansion(expected)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(casing=False, qualify=qual))
+def test_wildcard_column_expansion_with_two_tables(completer):
+ text = 'SELECT * FROM "select" JOIN users u ON true'
+ position = len("SELECT *")
+
+ completions = get_result(completer, text, position)
+
+ cols = (
+ '"select".id, "select".insert, "select"."ABC", '
+ "u.id, u.parentid, u.email, u.first_name, u.last_name"
+ )
+ expected = [wildcard_expansion(cols)]
+ assert completions == expected
+
+
+@parametrize("completer", completers(casing=False))
+def test_wildcard_column_expansion_with_two_tables_and_parent(completer):
+ text = 'SELECT "select".* FROM "select" JOIN users u ON true'
+ position = len('SELECT "select".*')
+
+ completions = get_result(completer, text, position)
+
+ col_list = 'id, "select".insert, "select"."ABC"'
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ ["SELECT U. FROM Users U", "SELECT U. FROM USERS U", "SELECT U. FROM users U"],
+)
+def test_suggest_columns_from_unquoted_table(completer, text):
+ position = len("SELECT U.")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggest_columns_from_quoted_table(completer):
+ result = get_result(completer, 'SELECT U. FROM "Users" U', len("SELECT U."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("Users"))
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize("text", ["SELECT * FROM ", "SELECT * FROM Orders o CROSS JOIN "])
+def test_schema_or_visible_table_completion(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=True))
+@parametrize("text", ["SELECT * FROM "])
+def test_table_aliases(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas() + aliased_rels
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=True))
+@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "])
+def test_duplicate_table_aliases(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas()
+ + [
+ table("orders o2"),
+ table("users u"),
+ table('"Users" U'),
+ table('"select" s'),
+ view("user_emails ue"),
+ view("functions f"),
+ function("_custom_fun() cf"),
+ function("custom_fun() cf"),
+ function("custom_func1() cf"),
+ function("custom_func2() cf"),
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ ),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=True, aliasing=True))
+@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "])
+def test_duplicate_aliases_with_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ schema("PUBLIC"),
+ table("Orders O2"),
+ table("Users U"),
+ table('"Users" U'),
+ table('"select" s'),
+ view("User_Emails UE"),
+ view("Functions F"),
+ function("_custom_fun() cf"),
+ function("Custom_Fun() CF"),
+ function("Custom_Func1() CF"),
+ function("custom_func2() cf"),
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ ),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=True, aliasing=True))
+@parametrize("text", ["SELECT * FROM "])
+def test_aliases_with_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [schema("PUBLIC")] + cased_aliased_rels
+ )
+
+
+@parametrize("completer", completers(casing=True, aliasing=False))
+@parametrize("text", ["SELECT * FROM "])
+def test_table_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [schema("PUBLIC")] + cased_rels
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "INSERT INTO users ()",
+ "INSERT INTO users()",
+ "INSERT INTO users () SELECT * FROM orders;",
+ "INSERT INTO users() SELECT * FROM users u cross join orders o",
+ ],
+)
+def test_insert(completer, text):
+ position = text.find("(") + 1
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+def test_suggest_cte_names(completer):
+ text = """
+ WITH cte1 AS (SELECT a, b, c FROM foo),
+ cte2 AS (SELECT d, e, f FROM bar)
+ SELECT * FROM
+ """
+ result = get_result(completer, text)
+ expected = completions_to_set(
+ [
+ Completion("cte1", 0, display_meta="table"),
+ Completion("cte2", 0, display_meta="table"),
+ ]
+ )
+ assert expected <= completions_to_set(result)
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggest_columns_from_cte(completer):
+ result = get_result(
+ completer,
+ "WITH cte AS (SELECT foo, bar FROM baz) SELECT FROM cte",
+ len("WITH cte AS (SELECT foo, bar FROM baz) SELECT "),
+ )
+ expected = [
+ Completion("foo", 0, display_meta="column"),
+ Completion("bar", 0, display_meta="column"),
+ ] + testdata.functions_and_keywords()
+
+ assert completions_to_set(expected) == completions_to_set(result)
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+@parametrize(
+ "text",
+ [
+ "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte WHERE cte.",
+ "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte c WHERE c.",
+ ],
+)
+def test_cte_qualified_columns(completer, text):
+ result = get_result(completer, text)
+ expected = [Completion("foo", 0, display_meta="column")]
+ assert completions_to_set(expected) == completions_to_set(result)
+
+
+@parametrize(
+ "keyword_casing,expected,texts",
+ [
+ ("upper", "SELECT", ("", "s", "S", "Sel")),
+ ("lower", "select", ("", "s", "S", "Sel")),
+ ("auto", "SELECT", ("", "S", "SEL", "seL")),
+ ("auto", "select", ("s", "sel", "SEl")),
+ ],
+)
+def test_keyword_casing_upper(keyword_casing, expected, texts):
+ for text in texts:
+ completer = testdata.get_completer({"keyword_casing": keyword_casing})
+ completions = get_result(completer, text)
+ assert expected in [cpl.text for cpl in completions]
+
+
+@parametrize("completer", completers())
+def test_keyword_after_alter(completer):
+ text = "ALTER TABLE users ALTER "
+ expected = Completion("COLUMN", start_position=0, display_meta="keyword")
+ completions = get_result(completer, text)
+ assert expected in completions
+
+
+@parametrize("completer", completers())
+def test_set_schema(completer):
+ text = "SET SCHEMA "
+ result = get_result(completer, text)
+ expected = completions_to_set([schema("'public'")])
+ assert completions_to_set(result) == expected
+
+
+@parametrize("completer", completers())
+def test_special_name_completion(completer):
+ result = get_result(completer, "\\t")
+ assert completions_to_set(result) == completions_to_set(
+ [
+ Completion(
+ text="\\timing",
+ start_position=-2,
+ display_meta="Toggle timing of commands.",
+ )
+ ]
+ )
diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py
new file mode 100644
index 0000000..1034bbe
--- /dev/null
+++ b/tests/test_sqlcompletion.py
@@ -0,0 +1,964 @@
+from pgcli.packages.sqlcompletion import (
+ suggest_type,
+ Special,
+ Database,
+ Schema,
+ Table,
+ Column,
+ View,
+ Keyword,
+ FromClauseItem,
+ Function,
+ Datatype,
+ Alias,
+ JoinCondition,
+ Join,
+)
+from pgcli.packages.parseutils.tables import TableReference
+import pytest
+
+
+def cols_etc(
+ table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None
+):
+ """Returns the expected select-clause suggestions for a single-table
+ select."""
+ return {
+ Column(
+ table_refs=(TableReference(schema, table, alias, is_function),),
+ qualifiable=True,
+ ),
+ Function(schema=parent),
+ Keyword(last_keyword),
+ }
+
+
+def test_select_suggests_cols_with_visible_table_scope():
+ suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
+ assert set(suggestions) == cols_etc("tabl", last_keyword="SELECT")
+
+
+def test_select_suggests_cols_with_qualified_table_scope():
+ suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
+ assert set(suggestions) == cols_etc("tabl", "sch", last_keyword="SELECT")
+
+
+def test_cte_does_not_crash():
+ sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;"
+ for i in range(len(sql)):
+ suggestions = suggest_type(sql[: i + 1], sql[: i + 1])
+
+
+@pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE '])
+def test_where_suggests_columns_functions_quoted_table(expression):
+ expected = cols_etc("tabl", alias='"tabl"', last_keyword="WHERE")
+ suggestions = suggest_type(expression, expression)
+ assert expected == set(suggestions)
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ",
+ "INSERT INTO OtherTabl SELECT * FROM tabl WHERE ",
+ "SELECT * FROM tabl WHERE ",
+ "SELECT * FROM tabl WHERE (",
+ "SELECT * FROM tabl WHERE foo = ",
+ "SELECT * FROM tabl WHERE bar OR ",
+ "SELECT * FROM tabl WHERE foo = 1 AND ",
+ "SELECT * FROM tabl WHERE (bar > 10 AND ",
+ "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
+ "SELECT * FROM tabl WHERE 10 < ",
+ "SELECT * FROM tabl WHERE foo BETWEEN ",
+ "SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
+ ],
+)
+def test_where_suggests_columns_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
+
+
+@pytest.mark.parametrize(
+ "expression",
+ ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "],
+)
+def test_where_in_suggests_columns(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
+
+
+@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "])
+def test_after_as(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == set()
+
+
+def test_where_equals_any_suggests_columns_or_keywords():
+ text = "SELECT * FROM tabl WHERE foo = ANY("
+ suggestions = suggest_type(text, text)
+ assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
+
+
+def test_lparen_suggests_cols_and_funcs():
+ suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
+ assert set(suggestion) == {
+ Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("("),
+ }
+
+
+def test_select_suggests_cols_and_funcs():
+ suggestions = suggest_type("SELECT ", "SELECT ")
+ assert set(suggestions) == {
+ Column(table_refs=(), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+
+@pytest.mark.parametrize(
+ "expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "]
+)
+def test_suggests_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()}
+
+
+@pytest.mark.parametrize("expression", ["SELECT * FROM "])
+def test_suggest_tables_views_schemas_and_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM foo JOIN bar on bar.barid = foo.barid JOIN ",
+ "SELECT * FROM foo JOIN bar USING (barid) JOIN ",
+ ],
+)
+def test_suggest_after_join_with_two_tables(expression):
+ suggestions = suggest_type(expression, expression)
+ tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
+ assert set(suggestions) == {
+ FromClauseItem(schema=None, table_refs=tables),
+ Join(tables, None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"]
+)
+def test_suggest_after_join_with_one_table(expression):
+ suggestions = suggest_type(expression, expression)
+ tables = ((None, "foo", None, False),)
+ assert set(suggestions) == {
+ FromClauseItem(schema=None, table_refs=tables),
+ Join(((None, "foo", None, False),), None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."]
+)
+def test_suggest_qualified_tables_and_views(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
+
+
+@pytest.mark.parametrize("expression", ["UPDATE sch."])
+def test_suggest_qualified_aliasable_tables_and_views(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM sch.",
+ 'SELECT * FROM sch."',
+ 'SELECT * FROM sch."foo',
+ 'SELECT * FROM "sch".',
+ 'SELECT * FROM "sch"."',
+ ],
+)
+def test_suggest_qualified_tables_views_and_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {FromClauseItem(schema="sch")}
+
+
+@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
+def test_suggest_qualified_tables_views_functions_and_joins(expression):
+ suggestions = suggest_type(expression, expression)
+ tbls = tuple([(None, "foo", None, False)])
+ assert set(suggestions) == {
+ FromClauseItem(schema="sch", table_refs=tbls),
+ Join(tbls, "sch"),
+ }
+
+
+def test_truncate_suggests_tables_and_schemas():
+ suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
+ assert set(suggestions) == {Table(schema=None), Schema()}
+
+
+def test_truncate_suggests_qualified_tables():
+ suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
+ assert set(suggestions) == {Table(schema="sch")}
+
+
+@pytest.mark.parametrize(
+ "text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "]
+)
+def test_distinct_suggests_cols(text):
+ suggestions = suggest_type(text, text)
+ assert set(suggestions) == {
+ Column(table_refs=(), local_tables=(), qualifiable=True),
+ Function(schema=None),
+ Keyword("DISTINCT"),
+ }
+
+
+@pytest.mark.parametrize(
+ "text, text_before, last_keyword",
+ [
+ ("SELECT DISTINCT FROM tbl x JOIN tbl1 y", "SELECT DISTINCT", "SELECT"),
+ (
+ "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ",
+ "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ",
+ "ORDER BY",
+ ),
+ ],
+)
+def test_distinct_and_order_by_suggestions_with_aliases(
+ text, text_before, last_keyword
+):
+ suggestions = suggest_type(text, text_before)
+ assert set(suggestions) == {
+ Column(
+ table_refs=(
+ TableReference(None, "tbl", "x", False),
+ TableReference(None, "tbl1", "y", False),
+ ),
+ local_tables=(),
+ qualifiable=True,
+ ),
+ Function(schema=None),
+ Keyword(last_keyword),
+ }
+
+
+@pytest.mark.parametrize(
+ "text, text_before",
+ [
+ ("SELECT DISTINCT x. FROM tbl x JOIN tbl1 y", "SELECT DISTINCT x."),
+ (
+ "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.",
+ "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.",
+ ),
+ ],
+)
+def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
+ suggestions = suggest_type(text, text_before)
+ assert set(suggestions) == {
+ Column(
+ table_refs=(TableReference(None, "tbl", "x", False),),
+ local_tables=(),
+ qualifiable=False,
+ ),
+ Table(schema="x"),
+ View(schema="x"),
+ Function(schema="x"),
+ }
+
+
+def test_function_arguments_with_alias_given():
+ suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.")
+
+ assert set(suggestions) == {
+ Column(
+ table_refs=(TableReference(None, "tbl", "x", False),),
+ local_tables=(),
+ qualifiable=False,
+ ),
+ Table(schema="x"),
+ View(schema="x"),
+ Function(schema="x"),
+ }
+
+
+def test_col_comma_suggests_cols():
+ suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+
+def test_table_comma_suggests_tables_and_schemas():
+ suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+
+def test_into_suggests_tables_and_schemas():
+ suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
+ assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()}
+
+
+@pytest.mark.parametrize(
+ "text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"]
+)
+def test_insert_into_lparen_suggests_cols(text):
+ suggestions = suggest_type(text, "INSERT INTO abc (")
+ assert suggestions == (
+ Column(table_refs=((None, "abc", None, False),), context="insert"),
+ )
+
+
+def test_insert_into_lparen_partial_text_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
+ assert suggestions == (
+ Column(table_refs=((None, "abc", None, False),), context="insert"),
+ )
+
+
+def test_insert_into_lparen_comma_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
+ assert suggestions == (
+ Column(table_refs=((None, "abc", None, False),), context="insert"),
+ )
+
+
+def test_partially_typed_col_name_suggests_col_names():
+ suggestions = suggest_type(
+ "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n"
+ )
+ assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
+
+
+def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
+ suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl", None, False),)),
+ Table(schema="tabl"),
+ View(schema="tabl"),
+ Function(schema="tabl"),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT t1. FROM tabl1 t1",
+ "SELECT t1. FROM tabl1 t1, tabl2 t2",
+ 'SELECT t1. FROM "tabl1" t1',
+ 'SELECT t1. FROM "tabl1" t1, "tabl2" t2',
+ ],
+)
+def test_dot_suggests_cols_of_an_alias(sql):
+ suggestions = suggest_type(sql, "SELECT t1.")
+ assert set(suggestions) == {
+ Table(schema="t1"),
+ View(schema="t1"),
+ Column(table_refs=((None, "tabl1", "t1", False),)),
+ Function(schema="t1"),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM tabl1 t1 WHERE t1.",
+ "SELECT * FROM tabl1 t1, tabl2 t2 WHERE t1.",
+ 'SELECT * FROM "tabl1" t1 WHERE t1.',
+ 'SELECT * FROM "tabl1" t1, tabl2 t2 WHERE t1.',
+ ],
+)
+def test_dot_suggests_cols_of_an_alias_where(sql):
+ suggestions = suggest_type(sql, sql)
+ assert set(suggestions) == {
+ Table(schema="t1"),
+ View(schema="t1"),
+ Column(table_refs=((None, "tabl1", "t1", False),)),
+ Function(schema="t1"),
+ }
+
+
+def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
+ suggestions = suggest_type(
+ "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
+ )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl2", "t2", False),)),
+ Table(schema="t2"),
+ View(schema="t2"),
+ Function(schema="t2"),
+ }
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (",
+ "SELECT * FROM foo WHERE EXISTS (",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (",
+ ],
+)
+def test_sub_select_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == (Keyword(),)
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (S",
+ "SELECT * FROM foo WHERE EXISTS (S",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
+ ],
+)
+def test_sub_select_partial_text_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == (Keyword(),)
+
+
+def test_outer_table_reference_in_exists_subquery_suggests_columns():
+ q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
+ suggestions = suggest_type(q, q)
+ assert set(suggestions) == {
+ Column(table_refs=((None, "foo", "f", False),)),
+ Table(schema="f"),
+ View(schema="f"),
+ Function(schema="f"),
+ }
+
+
+@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "])
+def test_sub_select_table_name_completion(expression):
+ suggestion = suggest_type(expression, expression)
+ assert set(suggestion) == {FromClauseItem(schema=None), Schema()}
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
+ ],
+)
+def test_sub_select_table_name_completion_with_outer_table(expression):
+ suggestion = suggest_type(expression, expression)
+ tbls = tuple([(None, "foo", None, False)])
+ assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
+
+
+def test_sub_select_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
+ )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "abc", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+
+@pytest.mark.xfail
+def test_sub_select_multiple_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, "
+ )
+ assert set(suggestions) == cols_etc("abc")
+
+
+def test_sub_select_dot_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
+ )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl", "t", False),)),
+ Table(schema="t"),
+ View(schema="t"),
+ Function(schema="t"),
+ }
+
+
+@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER"))
+@pytest.mark.parametrize("tbl_alias", ("", "foo"))
+def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
+ text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN "
+ suggestion = suggest_type(text, text)
+ tbls = tuple([(None, "abc", tbl_alias or None, False)])
+ assert set(suggestion) == {
+ FromClauseItem(schema=None, table_refs=tbls),
+ Schema(),
+ Join(tbls, None),
+ }
+
+
+def test_left_join_with_comma():
+ text = "select * from foo f left join bar b,"
+ suggestions = suggest_type(text, text)
+ # tbls should also include (None, 'bar', 'b', False)
+ # but there's a bug with commas
+ tbls = tuple([(None, "foo", "f", False)])
+ assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
+ ],
+)
+def test_join_alias_dot_suggests_cols1(sql):
+ suggestions = suggest_type(sql, sql)
+ tables = ((None, "abc", "a", False), (None, "def", "d", False))
+ assert set(suggestions) == {
+ Column(table_refs=((None, "abc", "a", False),)),
+ Table(schema="a"),
+ View(schema="a"),
+ Function(schema="a"),
+ JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.id = d.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
+ ],
+)
+def test_join_alias_dot_suggests_cols2(sql):
+ suggestion = suggest_type(sql, sql)
+ assert set(suggestion) == {
+ Column(table_refs=((None, "def", "d", False),)),
+ Table(schema="d"),
+ View(schema="d"),
+ Function(schema="d"),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on ",
+ """select a.x, b.y
+from abc a
+join bcd b on
+""",
+ """select a.x, b.y
+from abc a
+join bcd b
+on """,
+ "select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
+ ],
+)
+def test_on_suggests_aliases_and_join_conditions(sql):
+ suggestions = suggest_type(sql, sql)
+ tables = ((None, "abc", "a", False), (None, "bcd", "b", False))
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("a", "b")),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
+ "select abc.x, bcd.y from abc join bcd on ",
+ ],
+)
+def test_on_suggests_tables_and_join_conditions(sql):
+ suggestions = suggest_type(sql, sql)
+ tables = ((None, "abc", None, False), (None, "bcd", None, False))
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("abc", "bcd")),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on a.id = ",
+ "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
+ ],
+)
+def test_on_suggests_aliases_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == (Alias(aliases=("a", "b")),)
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
+ "select abc.x, bcd.y from abc join bcd on ",
+ ],
+)
+def test_on_suggests_tables_and_join_conditions_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ tables = ((None, "abc", None, False), (None, "bcd", None, False))
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("abc", "bcd")),
+ }
+
+
+@pytest.mark.parametrize(
+ "text",
+ (
+ "select * from abc inner join def using (",
+ "select * from abc inner join def using (col1, ",
+ "insert into hij select * from abc inner join def using (",
+ """insert into hij(x, y, z)
+ select * from abc inner join def using (col1, """,
+ """insert into hij (a,b,c)
+ select * from abc inner join def using (col1, """,
+ ),
+)
+def test_join_using_suggests_common_columns(text):
+ tables = ((None, "abc", None, False), (None, "def", None, False))
+ assert set(suggest_type(text, text)) == {
+ Column(table_refs=tables, require_last_table=True)
+ }
+
+
+def test_suggest_columns_after_multiple_joins():
+ sql = """select * from t1
+ inner join t2 ON
+ t1.id = t2.t1_id
+ inner join t3 ON
+ t2.id = t3."""
+ suggestions = suggest_type(sql, sql)
+ assert Column(table_refs=((None, "t3", None, False),)) in set(suggestions)
+
+
+def test_2_statements_2nd_current():
+ suggestions = suggest_type(
+ "select * from a; select * from ", "select * from a; select * from "
+ )
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+ suggestions = suggest_type(
+ "select * from a; select from b", "select * from a; select "
+ )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "b", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+ # Should work even if first statement is invalid
+ suggestions = suggest_type(
+ "select * from; select * from ", "select * from; select * from "
+ )
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+
+def test_2_statements_1st_current():
+ suggestions = suggest_type("select * from ; select * from b", "select * from ")
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+ suggestions = suggest_type("select from a; select * from b", "select ")
+ assert set(suggestions) == cols_etc("a", last_keyword="SELECT")
+
+
+def test_3_statements_2nd_current():
+ suggestions = suggest_type(
+ "select * from a; select * from ; select * from c",
+ "select * from a; select * from ",
+ )
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+ suggestions = suggest_type(
+ "select * from a; select from b; select * from c", "select * from a; select "
+ )
+ assert set(suggestions) == cols_etc("b", last_keyword="SELECT")
+
+
+@pytest.mark.parametrize(
+ "text",
+ [
+ """
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$
+SELECT FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+ """,
+ """create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 2 FROM bar;
+SELECT FROM foo;
+$func$
+ """,
+ """
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$
+SELECT 3 FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 2 FROM bar;
+SELECT FROM foo;
+$func$
+ """,
+ """
+SELECT * FROM baz;
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$
+SELECT FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 3 FROM bar;
+SELECT FROM foo;
+$func$
+SELECT * FROM qux;
+ """,
+ ],
+)
+def test_statements_in_function_body(text):
+ suggestions = suggest_type(text, text[: text.find(" ") + 1])
+ assert set(suggestions) == {
+ Column(table_refs=((None, "foo", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+
+functions = [
+ """
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$
+SELECT 1 FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+ """,
+ """
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+'
+SELECT 2 FROM bar;
+SELECT 1 FROM foo;
+';
+ """,
+]
+
+
+@pytest.mark.parametrize("text", functions)
+def test_statements_with_cursor_after_function_body(text):
+ suggestions = suggest_type(text, text[: text.find("; ") + 1])
+ assert set(suggestions) == {Keyword(), Special()}
+
+
+@pytest.mark.parametrize("text", functions)
+def test_statements_with_cursor_before_function_body(text):
+ suggestions = suggest_type(text, "")
+ assert set(suggestions) == {Keyword(), Special()}
+
+
+def test_create_db_with_template():
+ suggestions = suggest_type(
+ "create database foo with template ", "create database foo with template "
+ )
+
+ assert set(suggestions) == {Database()}
+
+
+@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n"))
+def test_specials_included_for_initial_completion(initial_text):
+ suggestions = suggest_type(initial_text, initial_text)
+
+ assert set(suggestions) == {Keyword(), Special()}
+
+
+def test_drop_schema_qualified_table_suggests_only_tables():
+ text = "DROP TABLE schema_name.table_name"
+ suggestions = suggest_type(text, text)
+ assert suggestions == (Table(schema="schema_name"),)
+
+
+@pytest.mark.parametrize("text", (",", " ,", "sel ,"))
+def test_handle_pre_completion_comma_gracefully(text):
+ suggestions = suggest_type(text, text)
+
+ assert iter(suggestions)
+
+
+def test_drop_schema_suggests_schemas():
+ sql = "DROP SCHEMA "
+ assert suggest_type(sql, sql) == (Schema(),)
+
+
+@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"])
+def test_cast_operator_suggests_types(text):
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]
+)
+def test_cast_operator_suggests_schema_qualified_types(text):
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema="bar"),
+ Table(schema="bar"),
+ }
+
+
+def test_alter_column_type_suggests_types():
+ q = "ALTER TABLE foo ALTER COLUMN bar TYPE "
+ assert set(suggest_type(q, q)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "text",
+ [
+ "CREATE TABLE foo (bar ",
+ "CREATE TABLE foo (bar DOU",
+ "CREATE TABLE foo (bar INT, baz ",
+ "CREATE TABLE foo (bar INT, baz TEXT, qux ",
+ "CREATE FUNCTION foo (bar ",
+ "CREATE FUNCTION foo (bar INT, baz ",
+ "SELECT * FROM foo() AS bar (baz ",
+ "SELECT * FROM foo() AS bar (baz INT, qux ",
+ # make sure this doesn't trigger special completion
+ "CREATE TABLE foo (dt d",
+ ],
+)
+def test_identifier_suggests_types_in_parentheses(text):
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "text",
+ [
+ "SELECT foo ",
+ "SELECT foo FROM bar ",
+ "SELECT foo AS bar ",
+ "SELECT foo bar ",
+ "SELECT * FROM foo AS bar ",
+ "SELECT * FROM foo bar ",
+ "SELECT foo FROM (SELECT bar ",
+ ],
+)
+def test_alias_suggests_keywords(text):
+ suggestions = suggest_type(text, text)
+ assert suggestions == (Keyword(),)
+
+
+def test_invalid_sql():
+ # issue 317
+ text = "selt *"
+ suggestions = suggest_type(text, text)
+ assert suggestions == (Keyword(),)
+
+
+@pytest.mark.parametrize(
+ "text",
+ ["SELECT * FROM foo where created > now() - ", "select * from foo where bar "],
+)
+def test_suggest_where_keyword(text):
+ # https://github.com/dbcli/mycli/issues/135
+ suggestions = suggest_type(text, text)
+ assert set(suggestions) == cols_etc("foo", last_keyword="WHERE")
+
+
+@pytest.mark.parametrize(
+ "text, before, expected",
+ [
+ (
+ "\\ns abc SELECT ",
+ "SELECT ",
+ [
+ Column(table_refs=(), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ ],
+ ),
+ ("\\ns abc SELECT foo ", "SELECT foo ", (Keyword(),)),
+ (
+ "\\ns abc SELECT t1. FROM tabl1 t1",
+ "SELECT t1.",
+ [
+ Table(schema="t1"),
+ View(schema="t1"),
+ Column(table_refs=((None, "tabl1", "t1", False),)),
+ Function(schema="t1"),
+ ],
+ ),
+ ],
+)
+def test_named_query_completion(text, before, expected):
+ suggestions = suggest_type(text, before)
+ assert set(expected) == set(suggestions)
+
+
+def test_select_suggests_fields_from_function():
+ suggestions = suggest_type("SELECT FROM func()", "SELECT ")
+ assert set(suggestions) == cols_etc("func", is_function=True, last_keyword="SELECT")
+
+
+@pytest.mark.parametrize("sql", ["("])
+def test_leading_parenthesis(sql):
+ # No assertion for now; just make sure it doesn't crash
+ suggest_type(sql, sql)
+
+
+@pytest.mark.parametrize("sql", ['select * from "', 'select * from "foo'])
+def test_ignore_leading_double_quotes(sql):
+ suggestions = suggest_type(sql, sql)
+ assert FromClauseItem(schema=None) in set(suggestions)
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "ALTER TABLE foo ALTER COLUMN ",
+ "ALTER TABLE foo ALTER COLUMN bar",
+ "ALTER TABLE foo DROP COLUMN ",
+ "ALTER TABLE foo DROP COLUMN bar",
+ ],
+)
+def test_column_keyword_suggests_columns(sql):
+ suggestions = suggest_type(sql, sql)
+ assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))}
+
+
+def test_handle_unrecognized_kw_generously():
+ sql = "SELECT * FROM sessions WHERE session = 1 AND "
+ suggestions = suggest_type(sql, sql)
+ expected = Column(table_refs=((None, "sessions", None, False),), qualifiable=True)
+
+ assert expected in set(suggestions)
+
+
+@pytest.mark.parametrize("sql", ["ALTER ", "ALTER TABLE foo ALTER "])
+def test_keyword_after_alter(sql):
+ assert Keyword("ALTER") in set(suggest_type(sql, sql))
diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py
new file mode 100644
index 0000000..ae865f4
--- /dev/null
+++ b/tests/test_ssh_tunnel.py
@@ -0,0 +1,188 @@
+import os
+from unittest.mock import patch, MagicMock, ANY
+
+import pytest
+from configobj import ConfigObj
+from click.testing import CliRunner
+from sshtunnel import SSHTunnelForwarder
+
+from pgcli.main import cli, PGCli
+from pgcli.pgexecute import PGExecute
+
+
+@pytest.fixture
+def mock_ssh_tunnel_forwarder() -> MagicMock:
+ mock_ssh_tunnel_forwarder = MagicMock(
+ SSHTunnelForwarder, local_bind_ports=[1111], autospec=True
+ )
+ with patch(
+ "pgcli.main.sshtunnel.SSHTunnelForwarder",
+ return_value=mock_ssh_tunnel_forwarder,
+ ) as mock:
+ yield mock
+
+
+@pytest.fixture
+def mock_pgexecute() -> MagicMock:
+ with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute:
+ yield mock_pgexecute
+
+
+def test_ssh_tunnel(
+ mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ # Test with just a host
+ tunnel_url = "some.host"
+ db_params = {
+ "database": "dbname",
+ "host": "db.host",
+ "user": "db_user",
+ "passwd": "db_passwd",
+ }
+ expected_tunnel_params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (db_params["host"], 5432),
+ "ssh_address_or_host": (tunnel_url, 22),
+ "logger": ANY,
+ }
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(**db_params)
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert call_args == (
+ db_params["database"],
+ db_params["user"],
+ db_params["passwd"],
+ "127.0.0.1",
+ pgcli.ssh_tunnel.local_bind_ports[0],
+ "",
+ )
+ mock_ssh_tunnel_forwarder.reset_mock()
+ mock_pgexecute.reset_mock()
+
+ # Test with a full url and with a specific db port
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "some.other.host"
+ tunnel_port = 1022
+ tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+ db_params["port"] = 1234
+
+ expected_tunnel_params["remote_bind_address"] = (
+ db_params["host"],
+ db_params["port"],
+ )
+ expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port)
+ expected_tunnel_params["ssh_username"] = tunnel_user
+ expected_tunnel_params["ssh_password"] = tunnel_passwd
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(**db_params)
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert call_args == (
+ db_params["database"],
+ db_params["user"],
+ db_params["passwd"],
+ "127.0.0.1",
+ pgcli.ssh_tunnel.local_bind_ports[0],
+ "",
+ )
+ mock_ssh_tunnel_forwarder.reset_mock()
+ mock_pgexecute.reset_mock()
+
+ # Test with DSN
+ dsn = (
+ f"user={db_params['user']} password={db_params['passwd']} "
+ f"host={db_params['host']} port={db_params['port']}"
+ )
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(dsn=dsn)
+
+ expected_dsn = (
+ f"user={db_params['user']} password={db_params['passwd']} "
+ f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
+ )
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert expected_dsn in call_args
+
+
+def test_cli_with_tunnel() -> None:
+ runner = CliRunner()
+ tunnel_url = "mytunnel"
+ with patch.object(
+ PGCli, "__init__", autospec=True, return_value=None
+ ) as mock_pgcli:
+ runner.invoke(cli, ["--ssh-tunnel", tunnel_url])
+ mock_pgcli.assert_called_once()
+ call_args, call_kwargs = mock_pgcli.call_args
+ assert call_kwargs["ssh_tunnel_url"] == tunnel_url
+
+
+def test_config(
+ tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ pgclirc = str(tmpdir.join("rcfile"))
+
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "tunnel.host"
+ tunnel_port = 1022
+ tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+
+ tunnel2_url = "tunnel2.host"
+
+ config = ConfigObj()
+ config.filename = pgclirc
+ config["ssh tunnels"] = {}
+ config["ssh tunnels"][r"\.com$"] = tunnel_url
+ config["ssh tunnels"][r"^hello-"] = tunnel2_url
+ config.write()
+
+ # Unmatched host
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="unmatched.host")
+ mock_ssh_tunnel_forwarder.assert_not_called()
+
+ # Host matching first tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="matched.host.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching second tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22)
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching both tunnels (will use the first one matched)
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..67d769f
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,92 @@
+import pytest
+import psycopg
+from pgcli.main import format_output, OutputSettings
+from os import getenv
+
+POSTGRES_USER = getenv("PGUSER", "postgres")
+POSTGRES_HOST = getenv("PGHOST", "localhost")
+POSTGRES_PORT = getenv("PGPORT", 5432)
+POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres")
+
+
+def db_connection(dbname=None):
+ conn = psycopg.connect(
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ port=POSTGRES_PORT,
+ dbname=dbname,
+ )
+ conn.autocommit = True
+ return conn
+
+
+try:
+ conn = db_connection()
+ CAN_CONNECT_TO_DB = True
+ SERVER_VERSION = conn.info.parameter_status("server_version")
+ JSON_AVAILABLE = True
+ JSONB_AVAILABLE = True
+except Exception as x:
+ CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False
+ SERVER_VERSION = 0
+
+
+dbtest = pytest.mark.skipif(
+ not CAN_CONNECT_TO_DB,
+ reason="Need a postgres instance at localhost accessible by user 'postgres'",
+)
+
+
+requires_json = pytest.mark.skipif(
+ not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined"
+)
+
+
+requires_jsonb = pytest.mark.skipif(
+ not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined"
+)
+
+
+def create_db(dbname):
+ with db_connection().cursor() as cur:
+ try:
+ cur.execute("""CREATE DATABASE _test_db""")
+ except:
+ pass
+
+
+def drop_tables(conn):
+ with conn.cursor() as cur:
+ cur.execute(
+ """
+ DROP SCHEMA public CASCADE;
+ CREATE SCHEMA public;
+ DROP SCHEMA IF EXISTS schema1 CASCADE;
+ DROP SCHEMA IF EXISTS schema2 CASCADE"""
+ )
+
+
+def run(
+ executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
+):
+ "Return string output for the sql to be run"
+
+ results = executor.run(sql, pgspecial, exception_formatter)
+ formatted = []
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded
+ )
+ for title, rows, headers, status, sql, success, is_special in results:
+ formatted.extend(format_output(title, rows, headers, status, settings))
+ if join:
+ formatted = "\n".join(formatted)
+
+ return formatted
+
+
+def completions_to_set(completions):
+ return {
+ (completion.display_text, completion.display_meta_text)
+ for completion in completions
+ }
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..7d33698
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,14 @@
+[tox]
+envlist = py38, py39, py310, py311, py312
+[testenv]
+deps = pytest>=2.7.0,<=3.0.7
+ mock>=1.0.1
+ behave>=1.2.4
+ pexpect==3.3
+ sshtunnel>=0.4.0
+commands = py.test
+ behave tests/features
+passenv = PGHOST
+ PGPORT
+ PGUSER
+ PGPASSWORD